2024-10-31 15:54:33 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Начало лабораторной\n",
"\n",
"Цены на кофе - https://www.kaggle.com/datasets/mayankanand2701/starbucks-stock-price-dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Атрибуты\n",
"\n",
"Date — Дата\n",
"\n",
"Open — Открытие\n",
"\n",
"High — Макс. цена\n",
"\n",
"Low — Мин. цена\n",
"\n",
"Close — Закрытие\n",
"\n",
"Adj Close — Скорректированная цена закрытия\n",
"\n",
"Volume — Объем торгов"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизнес-цели\n",
"\n",
"__1. Оценка волатильности акций:__\n",
"\n",
"\n",
"Описание: Прогнозировать волатильность акций на основе изменений в ценах открытий, максимума, минимума и объема торгов.\n",
"Целевая переменная: Разница между высокой и низкой ценой (High - Low). (среднее значение)\n",
"\n",
"__2. Прогнозирование цены закрытия акций:__\n",
"\n",
"\n",
"Описание: Оценить, какая будет цена закрытия акций Starbucks на следующий день или через несколько дней на основе исторических данных.\n",
"Целевая переменная: Цена закрытия (Close). (среднее значение)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Determining the achievable level of model quality for the first task"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__По дг о то вка да нных __"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Загрузка данных и создание целевой переменной"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля 'Volume': 14704589.99726232\n",
" Date Open High Low Close Adj Close Volume \\\n",
"0 1992-06-26 0.328125 0.347656 0.320313 0.335938 0.260703 224358400 \n",
"1 1992-06-29 0.339844 0.367188 0.332031 0.359375 0.278891 58732800 \n",
"2 1992-06-30 0.367188 0.371094 0.343750 0.347656 0.269797 34777600 \n",
"3 1992-07-01 0.351563 0.359375 0.339844 0.355469 0.275860 18316800 \n",
"4 1992-07-02 0.359375 0.359375 0.347656 0.355469 0.275860 13996800 \n",
"\n",
" above_average_volume volatility \n",
"0 1 0.027343 \n",
"1 1 0.035157 \n",
"2 1 0.027344 \n",
"3 1 0.019531 \n",
"4 0 0.011719 \n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"# Установим параметры для вывода\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Загружаем набор данных\n",
"df = pd.read_csv(\".//static//csv//Starbucks Dataset.csv\")\n",
"\n",
"# Устанавливаем случайное состояние\n",
"random_state = 42\n",
"\n",
"# Рассчитываем среднее значение объема\n",
"average_volume = df['Volume'].mean()\n",
"print(f\"Среднее значение поля 'Volume': {average_volume}\")\n",
"\n",
"# Создаем новую переменную, указывающую, превышает ли объем средний\n",
"df['above_average_volume'] = (df['Volume'] > average_volume).astype(int)\n",
"\n",
"# Рассчитываем волатильность (разницу между высокими и низкими значениями)\n",
"df['volatility'] = df['High'] - df['Low']\n",
"\n",
"# Выводим первые строки измененной таблицы для проверки\n",
"print(df.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n",
"\n",
"Целевой признак -- above_average_close"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_volume</th>\n",
" <th>volatility</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>7159</th>\n",
" <td>2020-11-27</td>\n",
" <td>98.480003</td>\n",
" <td>98.980003</td>\n",
" <td>98.279999</td>\n",
" <td>98.660004</td>\n",
" <td>91.604065</td>\n",
" <td>2169700</td>\n",
" <td>0</td>\n",
" <td>0.700004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4505</th>\n",
" <td>2010-05-14</td>\n",
" <td>13.630000</td>\n",
" <td>13.665000</td>\n",
" <td>13.090000</td>\n",
" <td>13.255000</td>\n",
" <td>10.329099</td>\n",
" <td>23081800</td>\n",
" <td>1</td>\n",
" <td>0.575000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>1994-02-24</td>\n",
" <td>0.710938</td>\n",
" <td>0.726563</td>\n",
" <td>0.695313</td>\n",
" <td>0.699219</td>\n",
" <td>0.542626</td>\n",
" <td>9264000</td>\n",
" <td>0</td>\n",
" <td>0.031250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1595</th>\n",
" <td>1998-10-19</td>\n",
" <td>2.371094</td>\n",
" <td>2.425781</td>\n",
" <td>2.277344</td>\n",
" <td>2.324219</td>\n",
" <td>1.803701</td>\n",
" <td>21284800</td>\n",
" <td>1</td>\n",
" <td>0.148437</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3676</th>\n",
" <td>2007-01-30</td>\n",
" <td>17.594999</td>\n",
" <td>17.680000</td>\n",
" <td>17.260000</td>\n",
" <td>17.280001</td>\n",
" <td>13.410076</td>\n",
" <td>28372200</td>\n",
" <td>1</td>\n",
" <td>0.420000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5976</th>\n",
" <td>2016-03-18</td>\n",
" <td>59.910000</td>\n",
" <td>60.450001</td>\n",
" <td>59.430000</td>\n",
" <td>59.700001</td>\n",
" <td>50.562347</td>\n",
" <td>14313600</td>\n",
" <td>0</td>\n",
" <td>1.020001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1305</th>\n",
" <td>1997-08-25</td>\n",
" <td>2.542969</td>\n",
" <td>2.703125</td>\n",
" <td>2.539063</td>\n",
" <td>2.679688</td>\n",
" <td>2.079561</td>\n",
" <td>28209600</td>\n",
" <td>1</td>\n",
" <td>0.164062</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6085</th>\n",
" <td>2016-08-23</td>\n",
" <td>56.169998</td>\n",
" <td>56.540001</td>\n",
" <td>56.000000</td>\n",
" <td>56.400002</td>\n",
" <td>48.101521</td>\n",
" <td>7827900</td>\n",
" <td>0</td>\n",
" <td>0.540001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5470</th>\n",
" <td>2014-03-17</td>\n",
" <td>37.404999</td>\n",
" <td>37.494999</td>\n",
" <td>36.910000</td>\n",
" <td>37.090000</td>\n",
" <td>30.569410</td>\n",
" <td>11019800</td>\n",
" <td>0</td>\n",
" <td>0.584999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5781</th>\n",
" <td>2015-06-10</td>\n",
" <td>51.799999</td>\n",
" <td>52.860001</td>\n",
" <td>51.660000</td>\n",
" <td>52.689999</td>\n",
" <td>44.214481</td>\n",
" <td>8003600</td>\n",
" <td>0</td>\n",
" <td>1.200001</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"7159 2020-11-27 98.480003 98.980003 98.279999 98.660004 91.604065 \n",
"4505 2010-05-14 13.630000 13.665000 13.090000 13.255000 10.329099 \n",
"421 1994-02-24 0.710938 0.726563 0.695313 0.699219 0.542626 \n",
"1595 1998-10-19 2.371094 2.425781 2.277344 2.324219 1.803701 \n",
"3676 2007-01-30 17.594999 17.680000 17.260000 17.280001 13.410076 \n",
"... ... ... ... ... ... ... \n",
"5976 2016-03-18 59.910000 60.450001 59.430000 59.700001 50.562347 \n",
"1305 1997-08-25 2.542969 2.703125 2.539063 2.679688 2.079561 \n",
"6085 2016-08-23 56.169998 56.540001 56.000000 56.400002 48.101521 \n",
"5470 2014-03-17 37.404999 37.494999 36.910000 37.090000 30.569410 \n",
"5781 2015-06-10 51.799999 52.860001 51.660000 52.689999 44.214481 \n",
"\n",
" Volume above_average_volume volatility \n",
"7159 2169700 0 0.700004 \n",
"4505 23081800 1 0.575000 \n",
"421 9264000 0 0.031250 \n",
"1595 21284800 1 0.148437 \n",
"3676 28372200 1 0.420000 \n",
"... ... ... ... \n",
"5976 14313600 0 1.020001 \n",
"1305 28209600 1 0.164062 \n",
"6085 7827900 0 0.540001 \n",
"5470 11019800 0 0.584999 \n",
"5781 8003600 0 1.200001 \n",
"\n",
"[6428 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_volume</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>7159</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4505</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1595</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3676</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5976</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1305</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6085</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5470</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5781</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_volume\n",
"7159 0\n",
"4505 1\n",
"421 0\n",
"1595 1\n",
"3676 1\n",
"... ...\n",
"5976 0\n",
"1305 1\n",
"6085 0\n",
"5470 0\n",
"5781 0\n",
"\n",
"[6428 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_volume</th>\n",
" <th>volatility</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>312</th>\n",
" <td>1993-09-21</td>\n",
" <td>0.746094</td>\n",
" <td>0.753906</td>\n",
" <td>0.726563</td>\n",
" <td>0.734375</td>\n",
" <td>0.569909</td>\n",
" <td>8051200</td>\n",
" <td>0</td>\n",
" <td>0.027343</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6118</th>\n",
" <td>2016-10-10</td>\n",
" <td>53.529999</td>\n",
" <td>53.599998</td>\n",
" <td>53.270000</td>\n",
" <td>53.299999</td>\n",
" <td>45.457634</td>\n",
" <td>7224300</td>\n",
" <td>0</td>\n",
" <td>0.329998</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1775</th>\n",
" <td>1999-07-08</td>\n",
" <td>3.132813</td>\n",
" <td>3.140625</td>\n",
" <td>3.046875</td>\n",
" <td>3.078125</td>\n",
" <td>2.388767</td>\n",
" <td>43104000</td>\n",
" <td>1</td>\n",
" <td>0.093750</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6621</th>\n",
" <td>2018-10-09</td>\n",
" <td>56.830002</td>\n",
" <td>59.700001</td>\n",
" <td>56.810001</td>\n",
" <td>57.709999</td>\n",
" <td>51.257065</td>\n",
" <td>24855700</td>\n",
" <td>1</td>\n",
" <td>2.890000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4363</th>\n",
" <td>2009-10-20</td>\n",
" <td>10.390000</td>\n",
" <td>10.475000</td>\n",
" <td>10.190000</td>\n",
" <td>10.265000</td>\n",
" <td>7.966110</td>\n",
" <td>11845000</td>\n",
" <td>0</td>\n",
" <td>0.285000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4472</th>\n",
" <td>2010-03-29</td>\n",
" <td>12.315000</td>\n",
" <td>12.385000</td>\n",
" <td>12.145000</td>\n",
" <td>12.305000</td>\n",
" <td>9.549243</td>\n",
" <td>13718000</td>\n",
" <td>0</td>\n",
" <td>0.240000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5944</th>\n",
" <td>2016-02-02</td>\n",
" <td>60.660000</td>\n",
" <td>60.900002</td>\n",
" <td>60.180000</td>\n",
" <td>60.700001</td>\n",
" <td>51.409283</td>\n",
" <td>9407400</td>\n",
" <td>0</td>\n",
" <td>0.720002</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6839</th>\n",
" <td>2019-08-22</td>\n",
" <td>96.589996</td>\n",
" <td>96.849998</td>\n",
" <td>95.699997</td>\n",
" <td>96.489998</td>\n",
" <td>87.342232</td>\n",
" <td>5146200</td>\n",
" <td>0</td>\n",
" <td>1.150001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>1992-08-05</td>\n",
" <td>0.425781</td>\n",
" <td>0.425781</td>\n",
" <td>0.402344</td>\n",
" <td>0.410156</td>\n",
" <td>0.318300</td>\n",
" <td>9516800</td>\n",
" <td>0</td>\n",
" <td>0.023437</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3902</th>\n",
" <td>2007-12-20</td>\n",
" <td>10.075000</td>\n",
" <td>10.280000</td>\n",
" <td>10.025000</td>\n",
" <td>10.265000</td>\n",
" <td>7.966110</td>\n",
" <td>22996200</td>\n",
" <td>1</td>\n",
" <td>0.255000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1608 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"312 1993-09-21 0.746094 0.753906 0.726563 0.734375 0.569909 \n",
"6118 2016-10-10 53.529999 53.599998 53.270000 53.299999 45.457634 \n",
"1775 1999-07-08 3.132813 3.140625 3.046875 3.078125 2.388767 \n",
"6621 2018-10-09 56.830002 59.700001 56.810001 57.709999 51.257065 \n",
"4363 2009-10-20 10.390000 10.475000 10.190000 10.265000 7.966110 \n",
"... ... ... ... ... ... ... \n",
"4472 2010-03-29 12.315000 12.385000 12.145000 12.305000 9.549243 \n",
"5944 2016-02-02 60.660000 60.900002 60.180000 60.700001 51.409283 \n",
"6839 2019-08-22 96.589996 96.849998 95.699997 96.489998 87.342232 \n",
"27 1992-08-05 0.425781 0.425781 0.402344 0.410156 0.318300 \n",
"3902 2007-12-20 10.075000 10.280000 10.025000 10.265000 7.966110 \n",
"\n",
" Volume above_average_volume volatility \n",
"312 8051200 0 0.027343 \n",
"6118 7224300 0 0.329998 \n",
"1775 43104000 1 0.093750 \n",
"6621 24855700 1 2.890000 \n",
"4363 11845000 0 0.285000 \n",
"... ... ... ... \n",
"4472 13718000 0 0.240000 \n",
"5944 9407400 0 0.720002 \n",
"6839 5146200 0 1.150001 \n",
"27 9516800 0 0.023437 \n",
"3902 22996200 1 0.255000 \n",
"\n",
"[1608 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_volume</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>312</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6118</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1775</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6621</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4363</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4472</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5944</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6839</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3902</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1608 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_volume\n",
"312 0\n",
"6118 0\n",
"1775 1\n",
"6621 1\n",
"4363 0\n",
"... ...\n",
"4472 0\n",
"5944 0\n",
"6839 0\n",
"27 0\n",
"3902 1\n",
"\n",
"[1608 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"above_average_volume\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=random_state\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
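{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of what stratification contributes to the split above, using synthetic data rather than the Starbucks file. The frame `toy` and its columns `feature` and `label` are assumptions made only for this illustration: with `stratify=` set, both parts should keep roughly the same share of the positive class.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Hypothetical toy data: 30 positive and 70 negative rows.\n",
"toy = pd.DataFrame({\"feature\": range(100), \"label\": [1] * 30 + [0] * 70})\n",
"\n",
"# Stratify on the label so the 80/20 split preserves the class balance.\n",
"train, test = train_test_split(\n",
"    toy, test_size=0.20, stratify=toy[\"label\"], random_state=42\n",
")\n",
"\n",
"# Share of the positive class in each part.\n",
"print(train[\"label\"].mean(), test[\"label\"].mean())\n",
"```\n",
"\n",
"Comparing the two printed shares is a quick way to confirm that a split such as the 80/20 one above did not distort the target distribution."
]
},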
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование конвейера для классификации данных\n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"\n",
"class StarbucksFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" def fit(self, X, y=None):\n",
" return self\n",
" def transform(self, X, y=None):\n",
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
" return X\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
" \n",
"\n",
"columns_to_drop = [\"Date\"]\n",
"num_columns = [\"Close\", \"Open\", \"Adj Close\", \"High\", \"Low\", \"Volume\", \"above_average_volume\"]\n",
"cat_columns = []\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
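{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch, assuming the aim is to predict `above_average_volume` without giving the model the label itself or the column it is computed from. The toy values are made up and `features` / `labels` are hypothetical names; only the column names follow the notebook.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Toy rows with the notebook's column names; the numbers are invented.\n",
"toy = pd.DataFrame(\n",
"    {\n",
"        \"Open\": [1.0, 2.0, 3.0, 4.0],\n",
"        \"High\": [1.5, 2.5, 3.5, 4.5],\n",
"        \"Low\": [0.5, 1.5, 2.5, 3.5],\n",
"        \"Volume\": [10, 40, 20, 50],\n",
"    }\n",
")\n",
"toy[\"above_average_volume\"] = (toy[\"Volume\"] > toy[\"Volume\"].mean()).astype(int)\n",
"\n",
"target = \"above_average_volume\"\n",
"# Drop the label and the column that defines it before fitting a model.\n",
"features = toy.drop(columns=[target, \"Volume\"])\n",
"labels = toy[target]\n",
"\n",
"print(list(features.columns))\n",
"```\n",
"\n",
"Whether `Volume` should also be excluded depends on the task definition; the sketch only shows one way to separate features from the label."
]
},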
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Де мо нс тр а ция работы ко нве йе р а __"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Close</th>\n",
" <th>Open</th>\n",
" <th>Adj Close</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Volume</th>\n",
" <th>above_average_volume</th>\n",
" <th>volatility</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>7159</th>\n",
" <td>2.052122</td>\n",
" <td>2.047553</td>\n",
" <td>2.057055</td>\n",
" <td>2.035800</td>\n",
" <td>2.068394</td>\n",
" <td>-1.046507</td>\n",
" <td>-0.733850</td>\n",
" <td>0.700004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4505</th>\n",
" <td>-0.493609</td>\n",
" <td>-0.482248</td>\n",
" <td>-0.509368</td>\n",
" <td>-0.485819</td>\n",
" <td>-0.493841</td>\n",
" <td>0.708938</td>\n",
" <td>1.362677</td>\n",
" <td>0.575000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>-0.867869</td>\n",
" <td>-0.867429</td>\n",
" <td>-0.818396</td>\n",
" <td>-0.868235</td>\n",
" <td>-0.866632</td>\n",
" <td>-0.450983</td>\n",
" <td>-0.733850</td>\n",
" <td>0.031250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1595</th>\n",
" <td>-0.819432</td>\n",
" <td>-0.817932</td>\n",
" <td>-0.778575</td>\n",
" <td>-0.818012</td>\n",
" <td>-0.819050</td>\n",
" <td>0.558091</td>\n",
" <td>1.362677</td>\n",
" <td>0.148437</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3676</th>\n",
" <td>-0.373633</td>\n",
" <td>-0.364031</td>\n",
" <td>-0.412080</td>\n",
" <td>-0.367150</td>\n",
" <td>-0.368421</td>\n",
" <td>1.153036</td>\n",
" <td>1.362677</td>\n",
" <td>0.420000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5976</th>\n",
" <td>0.890812</td>\n",
" <td>0.897589</td>\n",
" <td>0.761079</td>\n",
" <td>0.896985</td>\n",
" <td>0.899914</td>\n",
" <td>-0.027099</td>\n",
" <td>-0.733850</td>\n",
" <td>1.020001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1305</th>\n",
" <td>-0.808836</td>\n",
" <td>-0.812807</td>\n",
" <td>-0.769864</td>\n",
" <td>-0.809815</td>\n",
" <td>-0.811178</td>\n",
" <td>1.139386</td>\n",
" <td>1.362677</td>\n",
" <td>0.164062</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6085</th>\n",
" <td>0.792446</td>\n",
" <td>0.786081</td>\n",
" <td>0.683373</td>\n",
" <td>0.781419</td>\n",
" <td>0.796750</td>\n",
" <td>-0.571535</td>\n",
" <td>-0.733850</td>\n",
" <td>0.540001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5470</th>\n",
" <td>0.216858</td>\n",
" <td>0.226603</td>\n",
" <td>0.129761</td>\n",
" <td>0.218514</td>\n",
" <td>0.222586</td>\n",
" <td>-0.303594</td>\n",
" <td>-0.733850</td>\n",
" <td>0.584999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5781</th>\n",
" <td>0.681859</td>\n",
" <td>0.655790</td>\n",
" <td>0.560632</td>\n",
" <td>0.672651</td>\n",
" <td>0.666218</td>\n",
" <td>-0.556786</td>\n",
" <td>-0.733850</td>\n",
" <td>1.200001</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Close Open Adj Close High Low Volume \\\n",
"7159 2.052122 2.047553 2.057055 2.035800 2.068394 -1.046507 \n",
"4505 -0.493609 -0.482248 -0.509368 -0.485819 -0.493841 0.708938 \n",
"421 -0.867869 -0.867429 -0.818396 -0.868235 -0.866632 -0.450983 \n",
"1595 -0.819432 -0.817932 -0.778575 -0.818012 -0.819050 0.558091 \n",
"3676 -0.373633 -0.364031 -0.412080 -0.367150 -0.368421 1.153036 \n",
"... ... ... ... ... ... ... \n",
"5976 0.890812 0.897589 0.761079 0.896985 0.899914 -0.027099 \n",
"1305 -0.808836 -0.812807 -0.769864 -0.809815 -0.811178 1.139386 \n",
"6085 0.792446 0.786081 0.683373 0.781419 0.796750 -0.571535 \n",
"5470 0.216858 0.226603 0.129761 0.218514 0.222586 -0.303594 \n",
"5781 0.681859 0.655790 0.560632 0.672651 0.666218 -0.556786 \n",
"\n",
" above_average_volume volatility \n",
"7159 -0.733850 0.700004 \n",
"4505 1.362677 0.575000 \n",
"421 -0.733850 0.031250 \n",
"1595 1.362677 0.148437 \n",
"3676 1.362677 0.420000 \n",
"... ... ... \n",
"5976 -0.733850 1.020001 \n",
"1305 1.362677 0.164062 \n",
"6085 -0.733850 0.540001 \n",
"5470 -0.733850 0.584999 \n",
"5781 -0.733850 1.200001 \n",
"\n",
"[6428 rows x 8 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Формирование набора моделей для классификации\n",
"\n",
"logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = class_models[model_name][\"model\"]\n",
"\n",
"    # Fit preprocessing and the classifier together as one pipeline\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
"    model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
"    # Test labels come from the class 1 probability thresholded at 0.5\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
"    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
"    class_models[model_name][\"pipeline\"] = model_pipeline\n",
"    class_models[model_name][\"probs\"] = y_test_probs\n",
"    class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
"    # Quality metrics on the training and test sets\n",
"    class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
"        y_train, y_train_predict\n",
"    )\n",
"    class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
"        y_test, y_test_predict\n",
"    )\n",
"    class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
"        y_train, y_train_predict\n",
"    )\n",
"    class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
"        y_test, y_test_predict\n",
"    )\n",
"    class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
"        y_train, y_train_predict\n",
"    )\n",
"    class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
"        y_test, y_test_predict\n",
"    )\n",
"    class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
"        y_test, y_test_probs\n",
"    )\n",
"    class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
"    class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
"    class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
"        y_test, y_test_predict\n",
"    )\n",
"    class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
"        y_test, y_test_predict\n",
"    )\n",
"    class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
"        y_test, y_test_predict\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Summary of quality metrics for the classification models used"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Two matrices per row; round the row count up so an odd number of models still fits\n",
"n_rows = (len(class_models) + 1) // 2\n",
"_, ax = plt.subplots(n_rows, 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each matrix has the true classes in rows and the predicted classes in columns, in the order \"Less\", \"More\".\n",
"\n",
"1045: the number of true negatives, that is, test objects of class \"Less\" that the model correctly assigned to \"Less\".\n",
"\n",
"563: the number of true positives, that is, test objects of class \"More\" that the model correctly assigned to \"More\".\n",
"\n",
"For the seven models whose test metrics all equal 1.0, the off-diagonal cells are zero, so there are no false positives and no false negatives. Only MLP makes errors: its test Recall of 0.994671 at a Precision of 1.0 corresponds to 560 of the 563 \"More\" objects being found, which means 3 false negatives and no false positives.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Precision, recall, accuracy, F1 score"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_be49f_row0_col0, #T_be49f_row0_col1, #T_be49f_row1_col0, #T_be49f_row1_col1, #T_be49f_row2_col0, #T_be49f_row2_col1, #T_be49f_row3_col0, #T_be49f_row3_col1, #T_be49f_row4_col0, #T_be49f_row4_col1, #T_be49f_row5_col0, #T_be49f_row5_col1, #T_be49f_row6_col0, #T_be49f_row6_col1, #T_be49f_row7_col0, #T_be49f_row7_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_be49f_row0_col2, #T_be49f_row0_col3, #T_be49f_row1_col2, #T_be49f_row1_col3, #T_be49f_row2_col2, #T_be49f_row2_col3, #T_be49f_row3_col2, #T_be49f_row3_col3, #T_be49f_row4_col2, #T_be49f_row4_col3, #T_be49f_row5_col2, #T_be49f_row5_col3, #T_be49f_row6_col2, #T_be49f_row6_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_be49f_row0_col4, #T_be49f_row0_col5, #T_be49f_row0_col6, #T_be49f_row0_col7, #T_be49f_row1_col4, #T_be49f_row1_col5, #T_be49f_row1_col6, #T_be49f_row1_col7, #T_be49f_row2_col4, #T_be49f_row2_col5, #T_be49f_row2_col6, #T_be49f_row2_col7, #T_be49f_row3_col4, #T_be49f_row3_col5, #T_be49f_row3_col6, #T_be49f_row3_col7, #T_be49f_row4_col4, #T_be49f_row4_col5, #T_be49f_row4_col6, #T_be49f_row4_col7, #T_be49f_row5_col4, #T_be49f_row5_col5, #T_be49f_row5_col6, #T_be49f_row5_col7, #T_be49f_row6_col4, #T_be49f_row6_col5, #T_be49f_row6_col6, #T_be49f_row6_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_be49f_row7_col2, #T_be49f_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_be49f_row7_col4, #T_be49f_row7_col5, #T_be49f_row7_col6, #T_be49f_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_be49f\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_be49f_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_be49f_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_be49f_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_be49f_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_be49f_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_be49f_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_be49f_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_be49f_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_be49f_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_be49f_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_be49f_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_be49f_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row4\" class=\"row_heading level0 row4\" >naive_bayes</th>\n",
" <td id=\"T_be49f_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_be49f_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row6\" class=\"row_heading level0 row6\" >random_forest</th>\n",
" <td id=\"T_be49f_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col5\" class=\"data row6 col5\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col6\" class=\"data row6 col6\" >1.000000</td>\n",
" <td id=\"T_be49f_row6_col7\" class=\"data row6 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_be49f_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_be49f_row7_col0\" class=\"data row7 col0\" >1.000000</td>\n",
" <td id=\"T_be49f_row7_col1\" class=\"data row7 col1\" >1.000000</td>\n",
" <td id=\"T_be49f_row7_col2\" class=\"data row7 col2\" >0.994222</td>\n",
" <td id=\"T_be49f_row7_col3\" class=\"data row7 col3\" >0.994671</td>\n",
" <td id=\"T_be49f_row7_col4\" class=\"data row7 col4\" >0.997978</td>\n",
" <td id=\"T_be49f_row7_col5\" class=\"data row7 col5\" >0.998134</td>\n",
" <td id=\"T_be49f_row7_col6\" class=\"data row7 col6\" >0.997103</td>\n",
" <td id=\"T_be49f_row7_col7\" class=\"data row7 col7\" >0.997329</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x277bd4b9100>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
"    [\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"        \"Accuracy_train\",\n",
"        \"Accuracy_test\",\n",
"        \"F1_train\",\n",
"        \"F1_test\",\n",
"    ]\n",
"]\n",
"class_metrics.sort_values(\n",
"    by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"    ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Seven of the eight models (logistic regression, ridge classifier, decision tree, KNN, naive Bayes, gradient boosting and random forest) reach 1.0 for Precision, Recall, Accuracy and F1 on both the training and the test set, which means they classify every example without error.\n",
"\n",
"MLP is slightly lower: on the test set its Recall is 0.995, its Accuracy is 0.998 and its F1 is 0.997, while its Precision stays at 1.0. The gap is too small to be a concern, and the model remains close to perfect."
]
},
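{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MLP row can be checked by hand. The sketch below is an illustration and not part of the original analysis. It assumes that the test set holds 1045 \"Less\" and 563 \"More\" objects (the counts read from the confusion matrices) and that MLP misses exactly 3 \"More\" objects with no false positives, which is what a Recall of 0.994671 at a Precision of 1.0 implies. The arrays `y_true_demo` and `y_pred_demo` are hypothetical names introduced only for this check. If the assumption holds, the printed values should agree with the MLP row of the table."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Assumed class counts in the test set and the assumed number of missed \"More\" objects\n",
"n_less, n_more, n_missed = 1045, 563, 3\n",
"\n",
"y_true_demo = np.array([0] * n_less + [1] * n_more)\n",
"y_pred_demo = y_true_demo.copy()\n",
"# Flip the last 3 \"More\" objects to \"Less\": 3 false negatives, 0 false positives\n",
"y_pred_demo[-n_missed:] = 0\n",
"\n",
"# For binary labels ravel() returns the cells in the order TN, FP, FN, TP\n",
"tn, fp, fn, tp = metrics.confusion_matrix(y_true_demo, y_pred_demo).ravel()\n",
"print(f\"TN={tn} FP={fp} FN={fn} TP={tp}\")\n",
"print(f\"Precision: {metrics.precision_score(y_true_demo, y_pred_demo):.6f}\")\n",
"print(f\"Recall:    {metrics.recall_score(y_true_demo, y_pred_demo):.6f}\")\n",
"print(f\"Accuracy:  {metrics.accuracy_score(y_true_demo, y_pred_demo):.6f}\")\n",
"print(f\"F1:        {metrics.f1_score(y_true_demo, y_pred_demo):.6f}\")"
]
},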
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC AUC, Cohen's kappa, Matthews correlation coefficient"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_b9c71_row0_col0, #T_b9c71_row0_col1, #T_b9c71_row1_col0, #T_b9c71_row1_col1, #T_b9c71_row2_col0, #T_b9c71_row2_col1, #T_b9c71_row3_col0, #T_b9c71_row3_col1, #T_b9c71_row4_col0, #T_b9c71_row4_col1, #T_b9c71_row5_col0, #T_b9c71_row5_col1, #T_b9c71_row6_col0, #T_b9c71_row6_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_b9c71_row0_col2, #T_b9c71_row1_col2, #T_b9c71_row2_col2, #T_b9c71_row3_col2, #T_b9c71_row4_col2, #T_b9c71_row5_col2, #T_b9c71_row6_col2, #T_b9c71_row7_col2 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b9c71_row0_col3, #T_b9c71_row0_col4, #T_b9c71_row1_col3, #T_b9c71_row1_col4, #T_b9c71_row2_col3, #T_b9c71_row2_col4, #T_b9c71_row3_col3, #T_b9c71_row3_col4, #T_b9c71_row4_col3, #T_b9c71_row4_col4, #T_b9c71_row5_col3, #T_b9c71_row5_col4, #T_b9c71_row6_col3, #T_b9c71_row6_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b9c71_row7_col0, #T_b9c71_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_b9c71_row7_col3, #T_b9c71_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_b9c71\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_b9c71_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_b9c71_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_b9c71_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_b9c71_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_b9c71_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_b9c71_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_b9c71_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_b9c71_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_b9c71_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row4\" class=\"row_heading level0 row4\" >naive_bayes</th>\n",
" <td id=\"T_b9c71_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_b9c71_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row6\" class=\"row_heading level0 row6\" >random_forest</th>\n",
" <td id=\"T_b9c71_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_b9c71_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_b9c71_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_b9c71_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_b9c71_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_b9c71_row7_col0\" class=\"data row7 col0\" >0.998134</td>\n",
" <td id=\"T_b9c71_row7_col1\" class=\"data row7 col1\" >0.997329</td>\n",
" <td id=\"T_b9c71_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
" <td id=\"T_b9c71_row7_col3\" class=\"data row7 col3\" >0.995895</td>\n",
" <td id=\"T_b9c71_row7_col4\" class=\"data row7 col4\" >0.995904</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x277bd51d8e0>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
"    [\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"        \"ROC_AUC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"        \"MCC_test\",\n",
"    ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
"    cmap=\"plasma\",\n",
"    low=0.3,\n",
"    high=1,\n",
"    subset=[\n",
"        \"ROC_AUC_test\",\n",
"        \"MCC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"    ],\n",
").background_gradient(\n",
"    cmap=\"viridis\",\n",
"    low=1,\n",
"    high=0.3,\n",
"    subset=[\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"    ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Seven models (logistic regression, ridge classifier, decision tree, KNN, naive Bayes, gradient boosting and random forest) reach the maximum value of 1 for every metric in the table: Accuracy, F1, ROC AUC, Cohen's kappa and MCC.\n",
"\n",
"MLP scores slightly lower on Accuracy (0.998) and F1 (0.997), but its ROC AUC still rounds to 1.000, so its predicted probabilities separate the classes almost perfectly and the few errors come from the 0.5 threshold. Its Cohen's kappa (0.996) and MCC (0.996) confirm that the agreement with the true labels is far above chance.\n",
"\n",
"Scores this close to perfect should be read with caution. The data frames displayed later in the notebook show the target `above_average_volume` among the columns of `X_test`, so the models most likely read the label directly from the features rather than learn it."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Test rows with prediction errors, for review"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Predicted</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_volume</th>\n",
" <th>volatility</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Date, Predicted, Open, High, Low, Close, Adj Close, Volume, above_average_volume, volatility]\n",
"Index: []"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"above_average_volume\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
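{
"cell_type": "markdown",
"metadata": {},
"source": [
"An empty error table is worth a second look. The column list above shows that `X_test` still contains `above_average_volume`, which is also the target, so a classifier can copy the label from its input. The sketch below uses synthetic data, not the Starbucks dataset, to show how large this effect can be. The frame `demo`, its columns and the helper `test_accuracy` are hypothetical names introduced for illustration only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"\n",
"rng = np.random.default_rng(42)\n",
"demo = pd.DataFrame(\n",
"    {\"Open\": rng.normal(50, 10, 1000), \"Close\": rng.normal(50, 10, 1000)}\n",
")\n",
"# Synthetic target that is unrelated to the two price columns\n",
"demo[\"target\"] = rng.integers(0, 2, 1000)\n",
"\n",
"\n",
"def test_accuracy(feature_names):\n",
"    # Train a tree on the chosen columns and return its accuracy on held-out rows\n",
"    X_tr, X_te, y_tr, y_te = train_test_split(\n",
"        demo[feature_names], demo[\"target\"], test_size=0.2, random_state=42\n",
"    )\n",
"    tree = DecisionTreeClassifier(random_state=42).fit(X_tr, y_tr)\n",
"    return tree.score(X_te, y_te)\n",
"\n",
"\n",
"# Target left among the features: the tree can split on it directly\n",
"print(\"with target column:   \", test_accuracy([\"Open\", \"Close\", \"target\"]))\n",
"# Target excluded: the score reflects what the remaining features carry\n",
"print(\"without target column:\", test_accuracy([\"Open\", \"Close\"]))"
]
},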
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example of using the trained model (pipeline) for prediction"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_volume</th>\n",
" <th>volatility</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6621</th>\n",
" <td>2018-10-09</td>\n",
" <td>56.830002</td>\n",
" <td>59.700001</td>\n",
" <td>56.810001</td>\n",
" <td>57.709999</td>\n",
" <td>51.257065</td>\n",
" <td>24855700</td>\n",
" <td>1</td>\n",
" <td>2.89</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"6621 2018-10-09 56.830002 59.700001 56.810001 57.709999 51.257065 \n",
"\n",
" Volume above_average_volume volatility \n",
"6621 24855700 1 2.89 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Close</th>\n",
" <th>Open</th>\n",
" <th>Adj Close</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Volume</th>\n",
" <th>above_average_volume</th>\n",
" <th>volatility</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6621</th>\n",
" <td>0.831494</td>\n",
" <td>0.805759</td>\n",
" <td>0.783016</td>\n",
" <td>0.874818</td>\n",
" <td>0.821113</td>\n",
" <td>0.857847</td>\n",
" <td>1.362677</td>\n",
" <td>2.89</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Close Open Adj Close High Low Volume \\\n",
"6621 0.831494 0.805759 0.783016 0.874818 0.821113 0.857847 \n",
"\n",
" above_average_volume volatility \n",
"6621 1.362677 2.89 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [9.31850788e-04 9.99068149e-01])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 6621\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Hyperparameter tuning with grid search"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\a3012\\AIM-PIbd-31-Zhirnova-A-E\\aimenv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Training the model with the new hyperparameters__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"import pandas as pd\n",
"\n",
"\n",
"# Select the numeric features\n",
"numeric_features = X_train.select_dtypes(include=['float64', 'int64']).columns.tolist()\n",
"\n",
"# Fix the random seed for reproducibility\n",
"random_state = 42\n",
"\n",
"# Preprocessing: standardize the numeric features\n",
"pipeline_end = ColumnTransformer([\n",
"    ('numeric', StandardScaler(), numeric_features),\n",
"    # Add other transformers here if needed\n",
"])\n",
"\n",
"# Model with the hyperparameters found by the grid search\n",
"optimized_model = RandomForestClassifier(\n",
"    random_state=random_state,\n",
"    criterion=\"gini\",\n",
"    max_depth=5,\n",
"    max_features=\"sqrt\",\n",
"    n_estimators=10,\n",
")\n",
"\n",
"# Container for the fitted pipeline, predictions, and metrics\n",
"result = {}\n",
"\n",
"# Build the pipeline and fit it on the training data\n",
"result[\"pipeline\"] = Pipeline([\n",
"    (\"pipeline\", pipeline_end),\n",
"    (\"model\", optimized_model)\n",
"]).fit(X_train, y_train.values.ravel())\n",
"\n",
"# Predictions on the training set, class-1 probabilities and predictions on the test set\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"# Metrics for evaluating the model\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Building the data for comparing the old and new versions of the model"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluating the metrics of the old and new models"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_6e790_row0_col0, #T_6e790_row0_col1, #T_6e790_row0_col2, #T_6e790_row0_col3, #T_6e790_row1_col0, #T_6e790_row1_col1, #T_6e790_row1_col2, #T_6e790_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6e790_row0_col4, #T_6e790_row0_col5, #T_6e790_row0_col6, #T_6e790_row0_col7, #T_6e790_row1_col4, #T_6e790_row1_col5, #T_6e790_row1_col6, #T_6e790_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_6e790\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_6e790_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_6e790_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_6e790_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_6e790_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_6e790_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_6e790_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_6e790_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_6e790_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" <th class=\"blank col5\" > </th>\n",
" <th class=\"blank col6\" > </th>\n",
" <th class=\"blank col7\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_6e790_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_6e790_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_6e790_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6e790_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_6e790_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_6e790_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2779cfe9880>"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both models, \"Old\" and \"New\", achieve perfect scores on every key metric (Precision, Recall, Accuracy, and F1) on both the training and test sets. All values equal 1.000000, which means no classification errors were made. Because the two rows are identical, the grid search produced no measurable change on these metrics."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_376e9_row0_col0, #T_376e9_row0_col1, #T_376e9_row1_col0, #T_376e9_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_376e9_row0_col2, #T_376e9_row0_col3, #T_376e9_row0_col4, #T_376e9_row1_col2, #T_376e9_row1_col3, #T_376e9_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_376e9\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_376e9_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_376e9_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_376e9_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_376e9_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_376e9_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_376e9_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_376e9_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_376e9_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_376e9_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_376e9_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_376e9_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_376e9_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_376e9_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_376e9_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_376e9_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_376e9_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_376e9_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2779cfe8b60>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both models, \"Old\" and \"New\", also show perfect results on all the selected metrics: Accuracy, F1, ROC AUC, Cohen's kappa, and MCC. Every metric equals 1.000000 on the test set, which indicates error-free classification by both models."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2kAAAGsCAYAAABHMu+IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJ10lEQVR4nO3de1xUdf7H8fcAchEZEBUQQ/JSKmVeNyMttyLR7OLqbj9bLUxXdwszNTXbylJLys1ydU0r8/Zb/VnbxS0ry2zVVDK1LDO8W9oFtAgQjOvM7w/WqUmxM3KGOTO8no/HeWycc/jOd1gevP2c7+ecsTmdTqcAAAAAAJYQ5OsJAAAAAAB+QpEGAAAAABZCkQYAAAAAFkKRBgAAAAAWQpEGAAAAABZCkQYAAAAAFkKRBgAAAAAWQpEGAAAAABYS4usJAADOrLS0VOXl5aaNFxoaqvDwcNPGAwDAE+SacRRpAGBBpaWlapXcSLnHqkwbMyEhQYcPHw7YQAMAWBe55hmKNACwoPLycuUeq9LhHcmyR9W+M73ohEOtun2p8vLygAwzAIC1kWueoUgDAAuzRwWZEmYAAFgBuWYMRRoAWFiV06EqpznjAADga+SaMRRpAGBhDjnlUO3TzIwxAACoLXLNGNYaAQAAAMBCWEkDAAtzyCEzGjrMGQUAgNoh14yhSAMAC6tyOlXlrH1LhxljAABQW+SaMbQ7AgAAAICFsJIGABbGDdYAgEBCrhlDkQYAFuaQU1WEGQAgQJBrxtDuCAAAAAAWwkoaAFgYbSEAgEBCrhnDShoA4DQbN27UDTfcoMTERNlsNq1atcrtuNPp1JQpU9S8eXNFREQoLS1N+/fvdzsnPz9fQ4YMkd1uV0xMjEaMGKHi4mK3cz799FNdccUVCg8PV1JSkmbOnOnttwYAqKf8Kdso0gDAwk49qtiMzRMlJSXq1KmT5s2bd8bjM2fO1Jw5c7RgwQJt3bpVkZGRSk9PV2lpqeucIUOGaPfu3Vq7dq1Wr16tjRs3atSoUa7jRUVF6tOnj5KTk7Vjxw797W9/08MPP6xnn3323H5YAADL81WuSf6VbTanM8A/ZAAA/FBRUZGio6O1JydeUVG1v5524oRD7TvkqbCwUHa73aPvtdlsevXVVzVgwABJ1VcaExMTdc8992jChAmSpMLCQsXHx2vJkiUaPHiwcnJylJKSom3btql79+6SpDVr1ui6667TV199pcTERM2fP1/333+/cnNzFRoaKkmaPHmyVq1apT179tT6PQMArMNKuSZZP9tYSQOAeqSoqMhtKysr83iMw4cPKzc3V2lpaa590dHR6tGjh7KzsyVJ2dnZiomJcYWYJKWlpSkoKEhbt251nXPllVe6QkyS0tPTtXfvXv3www/n+hYBAPWIGbkmWS/bKNIAwMKq/vuoYjM2SUpKSlJ0dLRry8rK8nhOubm5kqT4+Hi3/fHx8a5jubm5iouLczseEhKi2NhYt3PONMbPXwMAEFismGuS9bKNpzsCgIVVOas3M8aRpKNHj7q1hYSFhdV+cAAADCLXjGElDQDqEbvd7radS5glJCRIkvLy8tz25+XluY4lJCTo2LFjbscrKyuVn5/vds6Zxvj5awAAcDZm5JpkvWyjSAMAC3OYuJmlVatWSkhI0Lp161z7ioqKtHXrVqWmpkqSUlNTVVBQoB07drjOee+99+RwONSjRw/XORs3blRFRYXrnLVr16pdu3Zq3LixiTMGAFiFFXNNsl62UaQBgIU5ZFOVCZtDNo9et7i4WDt37tTOnTslVd9QvXPnTh05ckQ2m01jx47VI488otdee027du3SbbfdpsTERNdTsjp06KC+fftq5MiR+vDDD7V582aNHj1agwcPVmJioiTpj3/8o0JDQzVixAjt3r1bL7zwgv7+979r/PjxZv4IAQAW4qtck/wr27gnDQBwmu3bt+uqq65yfX0qXDIyMrRkyRJNmjRJJSUlGj
VqlAoKCtSrVy+tWbNG4eHhru9Zvny5Ro8erWuuuUZBQUEaNGiQ5syZ4zoeHR2td955R5mZmerWrZuaNm2qKVOmuH3eDAAAZvGnbONz0gDAgk59nsz23fFqZMLnyRSfcKj7Ref+eTIAANQGueYZ2h0BAAAAwEJodwQACzvVe2/GOAAA+Bq5ZgxFGgBYGGEGAAgk5JoxtDsCAAAAgIWwkgYAFuZw2uRw1v5qoRljAABQW+SaMRRpAGBhtIUAAAIJuWYM7Y4AAAAAYCGspAGAhVUpSFUmXE+rMmEuAADUFrlmDEUaAFiY06TefWeA9+4DAPwDuWYM7Y4AAAAAYCGspAGAhXGDNQAgkJBrxlCkAYCFVTmDVOU0oXffacJkAACoJXLNGNodAQAAAMBCWEkDAAtzyCaHCdfTHArwS44AAL9ArhnDShoAAAAAWAgraQBgYdxgDQAIJOSaMRRpAGBh5t1gHdhtIQAA/0CuGUO7IwAAAABYCCtpAGBh1TdY176lw4wxAACoLXLNGIo0ALAwh4JUxVOwAAABglwzhnZHAAAAALAQVtIAwMK4wRoAEEjINWMo0gDAwhwK4kM/AQABg1wzhnZHAAAAALAQVtIAwMKqnDZVOU340E8TxgAAoLbINWNYSQMAAAAAC2ElDQAsrMqkRxVXBXjvPgDAP5BrxlCkAYCFOZxBcpjwFCxHgD8FCwDgH8g1Y2h3BAAAAAALYSUNACyMthAAQCAh14yhSAMAC3PInCdYOWo/FQAAao1cM4Z2RwAAAACwEFbSAMDCHAqSw4TraWaMAQBAbZFrxlCkAYCFVTmDVGXCU7DMGAMAgNoi14wJ7HcHAAAAAH6GlTQAsDCHbHLIjBusaz8GAAC1Ra4ZQ5EGABZGWwgAIJCQa8YE9rsDAAAAAD/DShoAWJh5H/rJNTkAgO+Ra8YE9rsDAAAAAD/DSpoBDodD33zzjaKiomSzBfZNigBqz+l06sSJE0pMTFRQUO2uhTmcNjmcJtxgbcIYCBzkGgBPkGt1jyLNgG+++UZJSUm+ngYAP3P06FGdd955tRrDYVJbSKB/6Cc8Q64BOBfkWt2hSDMgKipKkvTlR+fL3iiwfyHgud9d2NHXU4DFVKpCm/Sm628HYDXkGs6GXMMvkWt1jyLNgFOtIPZGQbJHEWZwF2Jr4OspwGqc1f9jRhuZwxkkhwmPGTZjDAQOcg1nQ67hNORanaNIAwALq5JNVSZ8YKcZYwAAUFvkmjGBXYICAAAAgJ9hJQ0ALIy2EABAICHXjKFIAwALq5I5LR1VtZ8KAAC1Rq4ZE9glKAAAAAD4GVbSAMDCaAsBAAQScs2YwH53AAAAAOBnWEkDAAurcgapyoSrhWaMAQBAbZFrxlCkAYCFOWWTw4QbrJ0B/nkyAAD/QK4ZE9glKAAAAAD4GVbSAMDCaAsBAAQScs0YijQAsDCH0yaHs/YtHWaMAQBAbZFrxgR2CQoAAAAAfoaVNACwsCoFqcqE62lmjAEAQG2Ra8ZQpAGAhdEWAgAIJOSaMYFdggIAAACAn2ElDQAszKEgOUy4nmbGGAAA1Ba5ZgxFGgBYWJXTpioTWjrMGAMAgNoi14wJ7BIUAAAAAPwMK2kAYGHcYA0ACCTkmjGspAEAAACAhVCkAYCFOZ1BcpiwOZ2e/bmvqqrSgw8+qFatWikiIkJt2rTR9OnT5XQ6fzY3p6ZMmaLmzZsrIiJCaWlp2r9/v9s4+fn5GjJkiOx2u2JiYjRixAgVFxeb8rMBAPgfcs0YijQAsLAq2UzbPPH4449r/vz5+sc//qGcnBw9/vjjmjlzpubOnes6Z+bMmZozZ44WLFigrVu3KjIyUunp6SotLXWdM2TIEO3evVtr167V6tWrtXHjRo0aNcq0nw8AwL+Qa8ZwTxoA4DRbtmzRTTfdpP79+0uSzj//fP3f//2fPvzwQ0nVVxtnz56tBx54QDfddJMkadmyZYqPj9eqVas0ePBg5eTkaM2aNdq2bZu6d+
8uSZo7d66uu+46PfHEE0pMTPTNmwMA1Dv+lmuspAGAhTmcP91kXbuteryioiK3rays7Iyve/nll2vdunXat2+fJOm
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The yellow square shows the value 1049, the number of correctly classified objects assigned to the \"Less\" class. This suggests that the model identifies objects of this class reliably and keeps false positives to a minimum.\n",
"\n",
"The green square shows the value 558, the number of correctly classified objects assigned to the \"More\" class. This likewise indicates high accuracy in recognizing objects of this class."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Determining the achievable model quality for the second task (adding a pipeline for the regression task)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Data preparation__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loading the data and creating the target variable"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля 'Close': 30.058856538825285\n",
" Date Open High Low Close Adj Close Volume \\\n",
"0 1992-06-26 0.328125 0.347656 0.320313 0.335938 0.260703 224358400 \n",
"1 1992-06-29 0.339844 0.367188 0.332031 0.359375 0.278891 58732800 \n",
"2 1992-06-30 0.367188 0.371094 0.343750 0.347656 0.269797 34777600 \n",
"3 1992-07-01 0.351563 0.359375 0.339844 0.355469 0.275860 18316800 \n",
"4 1992-07-02 0.359375 0.359375 0.347656 0.355469 0.275860 13996800 \n",
"\n",
" above_average_close Close_Next_Day \n",
"0 0 0.359375 \n",
"1 0 0.347656 \n",
"2 0 0.355469 \n",
"3 0 0.355469 \n",
"4 0 0.355469 \n",
"Статистическое описание DataFrame:\n",
" Open High Low Close Adj Close \\\n",
"count 8035.000000 8035.000000 8035.000000 8035.000000 8035.000000 \n",
"mean 30.048051 30.345221 29.745172 30.052733 26.667480 \n",
"std 33.613031 33.904070 33.312079 33.613521 31.724640 \n",
"min 0.328125 0.347656 0.320313 0.335938 0.260703 \n",
"25% 4.391563 4.531250 4.304844 4.399219 3.413997 \n",
"50% 13.325000 13.485000 13.150000 13.330000 10.352452 \n",
"75% 55.250000 55.715000 54.829999 55.254999 47.461098 \n",
"max 126.080002 126.320000 124.809998 126.059998 118.010414 \n",
"\n",
" Volume above_average_close Close_Next_Day \n",
"count 8.035000e+03 8035.000000 8035.000000 \n",
"mean 1.470584e+07 0.347480 30.062556 \n",
"std 1.340058e+07 0.476199 33.616368 \n",
"min 1.504000e+06 0.000000 0.347656 \n",
"25% 7.818550e+06 0.000000 4.403125 \n",
"50% 1.170240e+07 0.000000 13.330000 \n",
"75% 1.778850e+07 1.000000 55.274999 \n",
"max 5.855088e+08 1.000000 126.059998 \n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Load the Starbucks stock price data from the CSV file\n",
"df = pd.read_csv(\".//static//csv//Starbucks Dataset.csv\")\n",
"\n",
"# Random seed (kept here in case later cells need it)\n",
"random_state = 42\n",
"\n",
"# Compute the mean of the \"Close\" field\n",
"average_close = df['Close'].mean()\n",
"print(f\"Среднее значение поля 'Close': {average_close}\")\n",
"\n",
"# New column: 1 if the closing price is above the mean, 0 otherwise\n",
"df['above_average_close'] = (df['Close'] > average_close).astype(int)\n",
"\n",
"# Target variable for forecasting: the next day's closing price\n",
"df['Close_Next_Day'] = df['Close'].shift(-1)\n",
"\n",
"# Drop the last row, which has no next-day value\n",
"df.dropna(inplace=True)\n",
"\n",
"# Print the DataFrame with the new columns\n",
"print(df.head())\n",
"\n",
"# Basic descriptive statistics\n",
"print(\"Статистическое описание DataFrame:\")\n",
"print(df.describe())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Splitting the dataset into training and test sets (80/20) for the classification task\n",
"\n",
"Target feature: above_average_close"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_close</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2484</th>\n",
" <td>2002-05-06</td>\n",
" <td>5.867500</td>\n",
" <td>5.897500</td>\n",
" <td>5.637500</td>\n",
" <td>5.665000</td>\n",
" <td>4.396299</td>\n",
" <td>10545200</td>\n",
" <td>0</td>\n",
" <td>5.700000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1576</th>\n",
" <td>1998-09-22</td>\n",
" <td>1.882813</td>\n",
" <td>1.925781</td>\n",
" <td>1.867188</td>\n",
" <td>1.902344</td>\n",
" <td>1.476306</td>\n",
" <td>42080000</td>\n",
" <td>0</td>\n",
" <td>2.058594</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6595</th>\n",
" <td>2018-08-31</td>\n",
" <td>52.459999</td>\n",
" <td>53.709999</td>\n",
" <td>52.450001</td>\n",
" <td>53.450001</td>\n",
" <td>47.473415</td>\n",
" <td>10892800</td>\n",
" <td>1</td>\n",
" <td>53.529999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7412</th>\n",
" <td>2021-11-30</td>\n",
" <td>109.550003</td>\n",
" <td>111.089996</td>\n",
" <td>109.050003</td>\n",
" <td>109.639999</td>\n",
" <td>103.481560</td>\n",
" <td>9483300</td>\n",
" <td>1</td>\n",
" <td>108.660004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7413</th>\n",
" <td>2021-12-01</td>\n",
" <td>110.959999</td>\n",
" <td>113.349998</td>\n",
" <td>108.550003</td>\n",
" <td>108.660004</td>\n",
" <td>102.556618</td>\n",
" <td>7618500</td>\n",
" <td>1</td>\n",
" <td>111.419998</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5519</th>\n",
" <td>2014-05-27</td>\n",
" <td>36.320000</td>\n",
" <td>36.889999</td>\n",
" <td>36.270000</td>\n",
" <td>36.830002</td>\n",
" <td>30.466820</td>\n",
" <td>10100400</td>\n",
" <td>1</td>\n",
" <td>36.634998</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4531</th>\n",
" <td>2010-06-22</td>\n",
" <td>14.035000</td>\n",
" <td>14.240000</td>\n",
" <td>13.575000</td>\n",
" <td>13.615000</td>\n",
" <td>10.609633</td>\n",
" <td>20533200</td>\n",
" <td>0</td>\n",
" <td>13.660000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>535</th>\n",
" <td>1994-08-09</td>\n",
" <td>0.906250</td>\n",
" <td>0.921875</td>\n",
" <td>0.890625</td>\n",
" <td>0.898438</td>\n",
" <td>0.697229</td>\n",
" <td>7795200</td>\n",
" <td>0</td>\n",
" <td>0.906250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>787</th>\n",
" <td>1995-08-08</td>\n",
" <td>1.183594</td>\n",
" <td>1.199219</td>\n",
" <td>1.175781</td>\n",
" <td>1.183594</td>\n",
" <td>0.918523</td>\n",
" <td>10848000</td>\n",
" <td>0</td>\n",
" <td>1.187500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7987</th>\n",
" <td>2024-03-15</td>\n",
" <td>91.599998</td>\n",
" <td>92.019997</td>\n",
" <td>90.099998</td>\n",
" <td>90.120003</td>\n",
" <td>89.441422</td>\n",
" <td>18133600</td>\n",
" <td>1</td>\n",
" <td>91.010002</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"2484 2002-05-06 5.867500 5.897500 5.637500 5.665000 4.396299 \n",
"1576 1998-09-22 1.882813 1.925781 1.867188 1.902344 1.476306 \n",
"6595 2018-08-31 52.459999 53.709999 52.450001 53.450001 47.473415 \n",
"7412 2021-11-30 109.550003 111.089996 109.050003 109.639999 103.481560 \n",
"7413 2021-12-01 110.959999 113.349998 108.550003 108.660004 102.556618 \n",
"... ... ... ... ... ... ... \n",
"5519 2014-05-27 36.320000 36.889999 36.270000 36.830002 30.466820 \n",
"4531 2010-06-22 14.035000 14.240000 13.575000 13.615000 10.609633 \n",
"535 1994-08-09 0.906250 0.921875 0.890625 0.898438 0.697229 \n",
"787 1995-08-08 1.183594 1.199219 1.175781 1.183594 0.918523 \n",
"7987 2024-03-15 91.599998 92.019997 90.099998 90.120003 89.441422 \n",
"\n",
" Volume above_average_close Close_Next_Day \n",
"2484 10545200 0 5.700000 \n",
"1576 42080000 0 2.058594 \n",
"6595 10892800 1 53.529999 \n",
"7412 9483300 1 108.660004 \n",
"7413 7618500 1 111.419998 \n",
"... ... ... ... \n",
"5519 10100400 1 36.634998 \n",
"4531 20533200 0 13.660000 \n",
"535 7795200 0 0.906250 \n",
"787 10848000 0 1.187500 \n",
"7987 18133600 1 91.010002 \n",
"\n",
"[6428 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_close</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2484</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1576</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6595</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7412</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7413</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5519</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4531</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>535</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>787</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7987</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_close\n",
"2484 0\n",
"1576 0\n",
"6595 1\n",
"7412 1\n",
"7413 1\n",
"... ...\n",
"5519 1\n",
"4531 0\n",
"535 0\n",
"787 0\n",
"7987 1\n",
"\n",
"[6428 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_close</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5022</th>\n",
" <td>2012-06-01</td>\n",
" <td>26.555000</td>\n",
" <td>27.030001</td>\n",
" <td>26.02000</td>\n",
" <td>26.075001</td>\n",
" <td>20.960617</td>\n",
" <td>17456400</td>\n",
" <td>0</td>\n",
" <td>26.950001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3110</th>\n",
" <td>2004-10-28</td>\n",
" <td>12.895000</td>\n",
" <td>13.212500</td>\n",
" <td>12.77750</td>\n",
" <td>13.212500</td>\n",
" <td>10.253506</td>\n",
" <td>12049600</td>\n",
" <td>0</td>\n",
" <td>13.220000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2931</th>\n",
" <td>2004-02-12</td>\n",
" <td>9.317500</td>\n",
" <td>9.325000</td>\n",
" <td>9.20500</td>\n",
" <td>9.245000</td>\n",
" <td>7.174544</td>\n",
" <td>8623600</td>\n",
" <td>0</td>\n",
" <td>9.175000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6863</th>\n",
" <td>2019-09-26</td>\n",
" <td>90.839996</td>\n",
" <td>91.150002</td>\n",
" <td>89.50000</td>\n",
" <td>89.800003</td>\n",
" <td>81.286491</td>\n",
" <td>5026400</td>\n",
" <td>1</td>\n",
" <td>88.370003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5147</th>\n",
" <td>2012-11-30</td>\n",
" <td>25.709999</td>\n",
" <td>26.004999</td>\n",
" <td>25.52000</td>\n",
" <td>25.934999</td>\n",
" <td>21.016182</td>\n",
" <td>11997400</td>\n",
" <td>0</td>\n",
" <td>25.895000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2947</th>\n",
" <td>2004-03-08</td>\n",
" <td>9.477500</td>\n",
" <td>9.585000</td>\n",
" <td>9.34250</td>\n",
" <td>9.365000</td>\n",
" <td>7.267669</td>\n",
" <td>14322400</td>\n",
" <td>0</td>\n",
" <td>9.382500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>784</th>\n",
" <td>1995-08-03</td>\n",
" <td>1.230469</td>\n",
" <td>1.230469</td>\n",
" <td>1.18750</td>\n",
" <td>1.203125</td>\n",
" <td>0.933680</td>\n",
" <td>13270400</td>\n",
" <td>0</td>\n",
" <td>1.195313</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4164</th>\n",
" <td>2009-01-06</td>\n",
" <td>5.025000</td>\n",
" <td>5.180000</td>\n",
" <td>4.97500</td>\n",
" <td>5.110000</td>\n",
" <td>3.965594</td>\n",
" <td>17609800</td>\n",
" <td>0</td>\n",
" <td>4.995000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>455</th>\n",
" <td>1994-04-14</td>\n",
" <td>0.804688</td>\n",
" <td>0.828125</td>\n",
" <td>0.78125</td>\n",
" <td>0.804688</td>\n",
" <td>0.624475</td>\n",
" <td>5990400</td>\n",
" <td>0</td>\n",
" <td>0.785156</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3335</th>\n",
" <td>2005-09-20</td>\n",
" <td>11.625000</td>\n",
" <td>11.775000</td>\n",
" <td>11.50250</td>\n",
" <td>11.540000</td>\n",
" <td>8.955570</td>\n",
" <td>13312000</td>\n",
" <td>0</td>\n",
" <td>11.667500</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1607 rows × 9 columns</p>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"5022 2012-06-01 26.555000 27.030001 26.02000 26.075001 20.960617 \n",
"3110 2004-10-28 12.895000 13.212500 12.77750 13.212500 10.253506 \n",
"2931 2004-02-12 9.317500 9.325000 9.20500 9.245000 7.174544 \n",
"6863 2019-09-26 90.839996 91.150002 89.50000 89.800003 81.286491 \n",
"5147 2012-11-30 25.709999 26.004999 25.52000 25.934999 21.016182 \n",
"... ... ... ... ... ... ... \n",
"2947 2004-03-08 9.477500 9.585000 9.34250 9.365000 7.267669 \n",
"784 1995-08-03 1.230469 1.230469 1.18750 1.203125 0.933680 \n",
"4164 2009-01-06 5.025000 5.180000 4.97500 5.110000 3.965594 \n",
"455 1994-04-14 0.804688 0.828125 0.78125 0.804688 0.624475 \n",
"3335 2005-09-20 11.625000 11.775000 11.50250 11.540000 8.955570 \n",
"\n",
" Volume above_average_close Close_Next_Day \n",
"5022 17456400 0 26.950001 \n",
"3110 12049600 0 13.220000 \n",
"2931 8623600 0 9.175000 \n",
"6863 5026400 1 88.370003 \n",
"5147 11997400 0 25.895000 \n",
"... ... ... ... \n",
"2947 14322400 0 9.382500 \n",
"784 13270400 0 1.195313 \n",
"4164 17609800 0 4.995000 \n",
"455 5990400 0 0.785156 \n",
"3335 13312000 0 11.667500 \n",
"\n",
"[1607 rows x 9 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_close</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5022</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3110</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2931</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6863</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5147</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2947</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>784</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4164</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>455</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3335</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1607 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_close\n",
"5022 0\n",
"3110 0\n",
"2931 0\n",
"6863 1\n",
"5147 0\n",
"... ...\n",
"2947 0\n",
"784 0\n",
"4164 0\n",
"455 0\n",
"3335 0\n",
"\n",
"[1607 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
    "from typing import Optional, Tuple\n",
    "\n",
    "import pandas as pd\n",
    "from pandas import DataFrame\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "\n",
    "def split_stratified_into_train_val_test(\n",
    "    df_input: DataFrame,\n",
    "    stratify_colname: str = \"y\",\n",
    "    frac_train: float = 0.6,\n",
    "    frac_val: float = 0.15,\n",
    "    frac_test: float = 0.25,\n",
    "    random_state: Optional[int] = None,\n",
    ") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
    "    \"\"\"Split df_input into stratified train, validation and test parts.\n",
    "\n",
    "    Returns (df_train, df_val, df_test, y_train, y_val, y_test). The validation\n",
    "    parts are empty DataFrames when frac_val is 0.\n",
    "    \"\"\"\n",
    "    if not (0 < frac_train < 1) or not (0 <= frac_val <= 1) or not (0 <= frac_test <= 1):\n",
    "        raise ValueError(\"Fractions must be between 0 and 1 and the sum must equal 1.\")\n",
    "\n",
    "    # Compare with a tolerance: floating-point sums are not always exactly 1.0.\n",
    "    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
    "        raise ValueError(\n",
    "            \"fractions %f, %f, %f do not add up to 1.0\" % (frac_train, frac_val, frac_test)\n",
    "        )\n",
    "\n",
    "    if stratify_colname not in df_input.columns:\n",
    "        raise ValueError(f\"{stratify_colname} is not a column in the DataFrame.\")\n",
    "\n",
    "    # X keeps every column of df_input, including the stratification column itself.\n",
    "    X = df_input\n",
    "    y = df_input[[stratify_colname]]\n",
    "\n",
    "    # First split: the training part versus everything else.\n",
    "    df_train, df_temp, y_train, y_temp = train_test_split(\n",
    "        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
    "    )\n",
    "\n",
    "    if frac_val == 0:\n",
    "        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
    "\n",
    "    # Second split: divide the remainder into validation and test parts.\n",
    "    relative_frac_test = frac_test / (frac_val + frac_test)\n",
    "\n",
    "    df_val, df_test, y_val, y_test = train_test_split(\n",
    "        df_temp,\n",
    "        y_temp,\n",
    "        stratify=y_temp,\n",
    "        test_size=relative_frac_test,\n",
    "        random_state=random_state,\n",
    "    )\n",
    "\n",
    "    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
    "\n",
    "    return df_train, df_val, df_test, y_train, y_val, y_test\n",
    "\n",
    "\n",
    "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
    "    df,\n",
    "    stratify_colname=\"above_average_close\",\n",
    "    frac_train=0.80,\n",
    "    frac_val=0.0,\n",
    "    frac_test=0.20,\n",
    "    random_state=random_state,\n",
    ")\n",
    "\n",
    "display(\"X_train\", X_train)\n",
    "display(\"y_train\", y_train)\n",
    "display(\"X_test\", X_test)\n",
    "display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "#### Building the pipeline for data classification\n",
    "\n",
    "preprocessing_num -- pipeline for numeric data: imputation of missing values and standardization\n",
    "\n",
    "preprocessing_cat -- pipeline for categorical data: imputation of missing values and one-hot encoding\n",
    "\n",
    "features_preprocessing -- transformer for feature preprocessing\n",
    "\n",
    "features_engineering -- transformer for feature engineering\n",
    "\n",
    "drop_columns -- transformer for dropping columns\n",
    "\n",
    "pipeline_end -- main pipeline for data preprocessing and feature engineering"
]
},
{
"cell_type": "code",
   "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
    "import numpy as np\n",
    "from sklearn.base import BaseEstimator, TransformerMixin\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import OneHotEncoder, StandardScaler\n",
    "\n",
    "\n",
    "class StarbucksFeatures(BaseEstimator, TransformerMixin):\n",
    "    \"\"\"Feature-engineering transformer (defined here but not used in pipeline_end).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        return self\n",
    "\n",
    "    def transform(self, X, y=None):\n",
    "        # The original ratio used columns \"x\" and \"y\", which this dataset lacks;\n",
    "        # High / Low is an assumed replacement built from existing columns.\n",
    "        X = X.copy()  # do not modify the caller's DataFrame\n",
    "        X[\"High_to_Low_Ratio\"] = X[\"High\"] / X[\"Low\"]\n",
    "        return X\n",
    "\n",
    "    def get_feature_names_out(self, features_in):\n",
    "        return np.append(features_in, [\"High_to_Low_Ratio\"], axis=0)\n",
    "\n",
    "\n",
    "columns_to_drop = [\"Date\"]\n",
    "# NOTE: the target column above_average_close is listed among the numeric inputs.\n",
    "num_columns = [\"Close\", \"Open\", \"Adj Close\", \"High\", \"Low\", \"Volume\", \"above_average_close\"]\n",
    "cat_columns = []\n",
    "\n",
    "# Numeric columns: fill gaps with the median, then standardize.\n",
    "num_imputer = SimpleImputer(strategy=\"median\")\n",
    "num_scaler = StandardScaler()\n",
    "preprocessing_num = Pipeline(\n",
    "    [\n",
    "        (\"imputer\", num_imputer),\n",
    "        (\"scaler\", num_scaler),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Categorical columns: fill gaps with a constant, then one-hot encode.\n",
    "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
    "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
    "preprocessing_cat = Pipeline(\n",
    "    [\n",
    "        (\"imputer\", cat_imputer),\n",
    "        (\"encoder\", cat_encoder),\n",
    "    ]\n",
    ")\n",
    "\n",
    "features_preprocessing = ColumnTransformer(\n",
    "    verbose_feature_names_out=False,\n",
    "    transformers=[\n",
    "        (\"preprocessing_num\", preprocessing_num, num_columns),\n",
    "        (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
    "    ],\n",
    "    remainder=\"passthrough\",\n",
    ")\n",
    "\n",
    "\n",
    "drop_columns = ColumnTransformer(\n",
    "    verbose_feature_names_out=False,\n",
    "    transformers=[\n",
    "        (\"drop_columns\", \"drop\", columns_to_drop),\n",
    "    ],\n",
    "    remainder=\"passthrough\",\n",
    ")\n",
    "\n",
    "# NOTE: not used in pipeline_end; \"Cabin_type\" is not a column of this dataset.\n",
    "features_postprocessing = ColumnTransformer(\n",
    "    verbose_feature_names_out=False,\n",
    "    transformers=[\n",
    "        (\"preprocessing_cat\", preprocessing_cat, [\"Cabin_type\"]),\n",
    "    ],\n",
    "    remainder=\"passthrough\",\n",
    ")\n",
    "\n",
    "pipeline_end = Pipeline(\n",
    "    [\n",
    "        (\"features_preprocessing\", features_preprocessing),\n",
    "        (\"drop_columns\", drop_columns),\n",
    "    ]\n",
    ")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "__Pipeline demonstration__"
]
},
{
"cell_type": "code",
   "execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Close</th>\n",
" <th>Open</th>\n",
" <th>Adj Close</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Volume</th>\n",
" <th>above_average_close</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2484</th>\n",
" <td>-0.723400</td>\n",
" <td>-0.717267</td>\n",
" <td>-0.700283</td>\n",
" <td>-0.718936</td>\n",
" <td>-0.721563</td>\n",
" <td>-0.304340</td>\n",
" <td>-0.729840</td>\n",
" <td>5.700000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1576</th>\n",
" <td>-0.835023</td>\n",
" <td>-0.835490</td>\n",
" <td>-0.792049</td>\n",
" <td>-0.835755</td>\n",
" <td>-0.834432</td>\n",
" <td>1.970579</td>\n",
" <td>-0.729840</td>\n",
" <td>2.058594</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6595</th>\n",
" <td>0.694202</td>\n",
" <td>0.665106</td>\n",
" <td>0.653502</td>\n",
" <td>0.687359</td>\n",
" <td>0.679824</td>\n",
" <td>-0.279264</td>\n",
" <td>1.370164</td>\n",
" <td>53.529999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7412</th>\n",
" <td>2.361148</td>\n",
" <td>2.358932</td>\n",
" <td>2.413670</td>\n",
" <td>2.375059</td>\n",
" <td>2.374211</td>\n",
" <td>-0.380946</td>\n",
" <td>1.370164</td>\n",
" <td>108.660004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7413</th>\n",
" <td>2.332076</td>\n",
" <td>2.400766</td>\n",
" <td>2.384602</td>\n",
" <td>2.441531</td>\n",
" <td>2.359243</td>\n",
" <td>-0.515472</td>\n",
" <td>1.370164</td>\n",
" <td>111.419998</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5519</th>\n",
" <td>0.201149</td>\n",
" <td>0.186241</td>\n",
" <td>0.119036</td>\n",
" <td>0.192637</td>\n",
" <td>0.195457</td>\n",
" <td>-0.336428</td>\n",
" <td>1.370164</td>\n",
" <td>36.634998</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4531</th>\n",
" <td>-0.487553</td>\n",
" <td>-0.474942</td>\n",
" <td>-0.505016</td>\n",
" <td>-0.473560</td>\n",
" <td>-0.483945</td>\n",
" <td>0.416194</td>\n",
" <td>-0.729840</td>\n",
" <td>13.660000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>535</th>\n",
" <td>-0.864806</td>\n",
" <td>-0.864464</td>\n",
" <td>-0.816533</td>\n",
" <td>-0.865282</td>\n",
" <td>-0.863666</td>\n",
" <td>-0.502725</td>\n",
" <td>-0.729840</td>\n",
" <td>0.906250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>787</th>\n",
" <td>-0.856346</td>\n",
" <td>-0.856235</td>\n",
" <td>-0.809579</td>\n",
" <td>-0.857125</td>\n",
" <td>-0.855130</td>\n",
" <td>-0.282496</td>\n",
" <td>-0.729840</td>\n",
" <td>1.187500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7987</th>\n",
" <td>1.782063</td>\n",
" <td>1.826366</td>\n",
" <td>1.972431</td>\n",
" <td>1.814159</td>\n",
" <td>1.806921</td>\n",
" <td>0.243087</td>\n",
" <td>1.370164</td>\n",
" <td>91.010002</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Close Open Adj Close High Low Volume \\\n",
"2484 -0.723400 -0.717267 -0.700283 -0.718936 -0.721563 -0.304340 \n",
"1576 -0.835023 -0.835490 -0.792049 -0.835755 -0.834432 1.970579 \n",
"6595 0.694202 0.665106 0.653502 0.687359 0.679824 -0.279264 \n",
"7412 2.361148 2.358932 2.413670 2.375059 2.374211 -0.380946 \n",
"7413 2.332076 2.400766 2.384602 2.441531 2.359243 -0.515472 \n",
"... ... ... ... ... ... ... \n",
"5519 0.201149 0.186241 0.119036 0.192637 0.195457 -0.336428 \n",
"4531 -0.487553 -0.474942 -0.505016 -0.473560 -0.483945 0.416194 \n",
"535 -0.864806 -0.864464 -0.816533 -0.865282 -0.863666 -0.502725 \n",
"787 -0.856346 -0.856235 -0.809579 -0.857125 -0.855130 -0.282496 \n",
"7987 1.782063 1.826366 1.972431 1.814159 1.806921 0.243087 \n",
"\n",
" above_average_close Close_Next_Day \n",
"2484 -0.729840 5.700000 \n",
"1576 -0.729840 2.058594 \n",
"6595 1.370164 53.529999 \n",
"7412 1.370164 108.660004 \n",
"7413 1.370164 111.419998 \n",
"... ... ... \n",
"5519 1.370164 36.634998 \n",
"4531 -0.729840 13.660000 \n",
"535 -0.729840 0.906250 \n",
"787 -0.729840 1.187500 \n",
"7987 1.370164 91.010002 \n",
"\n",
"[6428 rows x 8 columns]"
]
},
"execution_count": 242,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
    "preprocessing_result = pipeline_end.fit_transform(X_train)\n",
    "preprocessed_df = pd.DataFrame(\n",
    "    preprocessing_result,\n",
    "    columns=pipeline_end.get_feature_names_out(),\n",
    ")\n",
    "\n",
    "preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "#### Building the set of classification models\n",
    "\n",
    "logistic -- logistic regression\n",
    "\n",
    "ridge -- L2-regularized (ridge) logistic regression with balanced class weights\n",
    "\n",
    "decision_tree -- decision tree\n",
    "\n",
    "knn -- k-nearest neighbors\n",
    "\n",
    "naive_bayes -- naive Bayes classifier\n",
    "\n",
    "gradient_boosting -- gradient boosting (an ensemble of decision trees)\n",
    "\n",
    "random_forest -- random forest (an ensemble of decision trees)\n",
    "\n",
    "mlp -- multilayer perceptron (neural network)"
]
},
{
"cell_type": "code",
   "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
    "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
    "\n",
    "class_models = {\n",
    "    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
    "    # A duplicate \"ridge\" key (RidgeClassifierCV) was removed: the later entry overrode it,\n",
    "    # and RidgeClassifierCV has no predict_proba, which the evaluation loop relies on.\n",
    "    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
    "    \"decision_tree\": {\n",
    "        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n",
    "    },\n",
    "    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
    "    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
    "    \"gradient_boosting\": {\n",
    "        \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
    "    },\n",
    "    \"random_forest\": {\n",
    "        \"model\": ensemble.RandomForestClassifier(\n",
    "            max_depth=11, class_weight=\"balanced\", random_state=random_state\n",
    "        )\n",
    "    },\n",
    "    \"mlp\": {\n",
    "        \"model\": neural_network.MLPClassifier(\n",
    "            hidden_layer_sizes=(7,),\n",
    "            max_iter=500,\n",
    "            early_stopping=True,\n",
    "            random_state=random_state,\n",
    "        )\n",
    "    },\n",
    "}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "#### Training the models on the training set and evaluating them on the test set"
]
},
{
"cell_type": "code",
   "execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "\n",
    "for model_name in class_models.keys():\n",
    "    print(f\"Model: {model_name}\")\n",
    "    model = class_models[model_name][\"model\"]\n",
    "\n",
    "    # Chain preprocessing and the estimator so both are fitted on the training set only.\n",
    "    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
    "    model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
    "\n",
    "    y_train_predict = model_pipeline.predict(X_train)\n",
    "    # Probability of class 1, thresholded at 0.5 to obtain hard predictions.\n",
    "    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
    "    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
    "\n",
    "    class_models[model_name][\"pipeline\"] = model_pipeline\n",
    "    class_models[model_name][\"probs\"] = y_test_probs\n",
    "    class_models[model_name][\"preds\"] = y_test_predict\n",
    "\n",
    "    class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
    "        y_train, y_train_predict\n",
    "    )\n",
    "    class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
    "        y_test, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
    "        y_train, y_train_predict\n",
    "    )\n",
    "    class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
    "        y_test, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
    "        y_train, y_train_predict\n",
    "    )\n",
    "    class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
    "        y_test, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
    "        y_test, y_test_probs\n",
    "    )\n",
    "    class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
    "    class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
    "    class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
    "        y_test, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
    "        y_test, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
    "        y_test, y_test_predict\n",
    "    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "#### Summary table of quality scores for the classification models used\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "__Confusion matrix__"
]
},
{
"cell_type": "code",
   "execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAQ9CAYAAACSpDaqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU9f8H8NdwI7BcciYi4kleeGTkmZJo5vHVvv40VFDTMry/nt+8NSm7TPNIy6v0a5eZmZnkmUrknQfirXgAKgKCAsvO/P4gtjZgYXVglp3X8/GYRzKfz85+ZpN9+Z7PHIIkSRKIiIiIiIhUzkrpARAREREREZkDFkdERERERERgcURERERERASAxREREREREREAFkdEREREREQAWBwREREREREBYHFEREREREQEgMURERERERERABZHREREREREAFgc0WNau3YtBEHA1atXK2T7V69ehSAIWLt2rSzb27t3LwRBwN69e2XZHhERkaWYPXs2BEEoV19BEDB79uyKHRCRglgckUVZtmyZbAUVEREREamLjdIDICpJYGAgHj16BFtbW5Net2zZMlSvXh3R0dEG69u3b49Hjx7Bzs5OxlESERFVfdOnT8fUqVOVHgaRWWBxRGZJEAQ4ODjItj0rKytZt0dERGQJcnJy4OTkBBsb/pOQCOBpdSSjZcuW4emnn4a9vT38/f0RExODjIyMYv2WLl2K2rVrw9HREc888wx+/fVXdOzYER07dtT3Kemao5SUFAwZMgQ1atSAvb09/Pz80KtXL/11T7Vq1cKZM2ewb98+CIIAQRD02yztmqOEhAS8+OKLcHd3h5OTE5o0aYKPPvpI3g+GiIjIDBRdW3T27Fm88sorcHd3R9u2bUu85igvLw/jx4+Hl5cXXFxc0LNnT9y4caPE7e7duxctW7aEg4MDgoOD8cknn5R6HdMXX3yBFi1awNHRER4eHujfvz+Sk5MrZH+JHgcPE5AsZs+ejTlz5iA8PBwjR45EUlISli9fjsOHD+PgwYP60+OWL1+OUaNGoV27dhg/fjyuXr2K3r17w93dHTVq1DD6Hn379sWZM2cwevRo1KpVC2lpaYiLi8P169dRq1YtLFq0CKNHj4azszPefPNNAICPj0+p24uLi8NLL70EPz8/jB07Fr6+vkhMTMS2bdswduxY+T4cIiIiM/Lvf/8bdevWxYIFCyBJEtLS0or1efXVV/HFF1/glVdewXPPPYfdu3eje/fuxfodP34cXbt2hZ+fH+bMmQOdToe5c+fCy8urWN+33noLM2bMQL9+/fDqq6/izp07WLJkCdq3b4/jx4/Dzc2tInaXyDQS0WNYs2aNBEC6cuWKlJaWJtnZ2UldunSRdDqdvs/HH38sAZBWr14tSZIk5eXlSZ6enlKrVq0krVar77d27VoJgNShQwf9uitXrkgApDVr1kiSJEn379+XAEjvvvuu0XE9/fTTBtspsmfPHgmAtGfPHkmSJKmgoEAKCgqSAgMDpfv37xv0FUWx/B8EERFRFTFr1iwJgDRgwIAS1xc5ceKEBEB64403DPq98sorEgBp1qxZ+nU9evSQqlWrJt28eVO/7sKFC5KNjY3BNq9evSpZW1tLb731lsE2T506JdnY2BRbT6QUnlZHT+yXX35Bfn4+xo0bByurv/5KDR8+HBqNBj/++CMA4MiRI7h37x6GDx9ucG5zZGQk3N3djb6Ho6Mj7OzssHfvXty/f/+Jx3z8+HFcuXIF48aNK3akqry3MyUiIqqKXn/9daPt27dvBwCMGTPGYP24ceMMftbpdPjll1/Qu3dv+Pv769fXqVMH3bp1M+i7efNmiKKIfv364e7du/rF19cXdevWxZ49e55gj4jkw9Pq6Ildu3YNAFC/fn2D9XZ2dqhdu7a+vei/derUMehnY2ODWrVqGX0Pe3t7vPPOO/jPf/4DHx8fPPvss3jppZcwePBg+Pr6mjzmS5cuAQAaNW
pk8muJiIiqsqCgIKPt165dg5WVFYKDgw3W/zPn09LS8OjRo2K5DhTP+gsXLkCSJNStW7fE9zT17rREFYXFEVUZ48aNQ48ePbBlyxb8/PPPmDFjBmJjY7F7926EhoYqPTwiIqIqwdHRsdLfUxRFCIKAn376CdbW1sXanZ2dK31MRCXhaXX0xAIDAwEASUlJBuvz8/Nx5coVfXvRfy9evGjQr6CgQH/HubIEBwfjP//5D3bu3InTp08jPz8f77//vr69vKfEFR0NO336dLn6ExERqUVgYCBEUdSfZVHknznv7e0NBweHYrkOFM/64OBgSJKEoKAghIeHF1ueffZZ+XeE6DGwOKInFh4eDjs7OyxevBiSJOnXf/bZZ8jMzNTf3aZly5bw9PTEqlWrUFBQoO+3YcOGMq8jevjwIXJzcw3WBQcHw8XFBXl5efp1Tk5OJd4+/J+aN2+OoKAgLFq0qFj/v+8DERGR2hRdL7R48WKD9YsWLTL42draGuHh4diyZQtu3bqlX3/x4kX89NNPBn379OkDa2trzJkzp1jOSpKEe/fuybgHRI+Pp9XRE/Py8sK0adMwZ84cdO3aFT179kRSUhKWLVuGVq1aYeDAgQAKr0GaPXs2Ro8ejU6dOqFfv364evUq1q5di+DgYKOzPufPn0fnzp3Rr18/hISEwMbGBt999x1SU1PRv39/fb8WLVpg+fLlmD9/PurUqQNvb2906tSp2PasrKywfPly9OjRA82aNcOQIUPg5+eHc+fO4cyZM/j555/l/6CIiIiqgGbNmmHAgAFYtmwZMjMz8dxzz2HXrl0lzhDNnj0bO3fuRJs2bTBy5EjodDp8/PHHaNSoEU6cOKHvFxwcjPnz52PatGn6x3i4uLjgypUr+O677zBixAhMnDixEveSqGQsjkgWs2fPhpeXFz7++GOMHz8eHh4eGDFiBBYsWGBwkeWoUaMgSRLef/99TJw4EU2bNsXWrVsxZswYODg4lLr9gIAADBgwALt27cLnn38OGxsbNGjQAF999RX69u2r7zdz5kxcu3YNCxcuxIMHD9ChQ4cSiyMAiIiIwJ49ezBnzhy8//77EEURwcHBGD58uHwfDBERURW0evVqeHl5YcOGDdiyZQs6deqEH3/8EQEBAQb9WrRogZ9++gkTJ07EjBkzEBAQgLlz5yIxMRHnzp0z6Dt16lTUq1cPH374IebMmQOgMN+7dOmCnj17Vtq+ERkjSDyHiBQmiiK8vLzQp08frFq1SunhEBER0RPq3bs3zpw5gwsXLig9FCKT8JojqlS5ubnFzjVev3490tPT0bFjR2UGRURERI/t0aNHBj9fuHAB27dvZ65TlcSZI6pUe/fuxfjx4/Hvf/8bnp6eOHbsGD777DM0bNgQR48ehZ2dndJDJCIiIhP4+fkhOjpa/2zD5cuXIy8vD8ePHy/1uUZE5orXHFGlqlWrFgICArB48WKkp6fDw8MDgwcPxttvv83CiIiIqArq2rUr/ve//yElJQX29vYICwvDggULWBhRlcSZIyIiIiIiIvCaIyIiIiIiIgAsjoiIiIiIiADwmqNyEUURt27dgouLi9EHlRJZIkmS8ODBA/j7+8PKSt7jKbm5ucjPzy+zn52dndHnYBGR+jCbSc2YzRWHxVE53Lp1q9hDz4jUJjk5GTVq1JBte7m5uQgKdEZKmq7Mvr6+vrhy5YpFfgkT0eNhNhMxmysCi6NycHFxAQBcO1YLGmeeiaiEf9VrrPQQVKsAWhzAdv3vgVzy8/ORkqbDxSMB0LiU/nuV9UBEnZbJyM/Pt7gvYCJ6fMxm5TGblcNsrjgsjsqhaLpe42xl9C8KVRwbwVbpIajXn/ezrKjTVpxdBDi7lL5tETxdhoiKYzYrj9msIGZzhWFxRESK0ko6aI08UUAriZU4GiIiIlJzNrM4IiJFiZAgovQvYGNtREREJD81ZzPnoYlIUSIk6Iwspn4B79+/Hz169IC/vz8EQcCWLVsM2i
VJwsyZM+Hn5wdHR0eEh4fjwoULBn3S09MRGRkJjUYDNzc3DBs2DNnZ2QZ9/vjjD7Rr1w4ODg4ICAjAwoULH2v/iYi
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
    "from sklearn.metrics import ConfusionMatrixDisplay\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Two plots per row; round up so an odd number of models still gets enough axes.\n",
    "n_rows = (len(class_models) + 1) // 2\n",
    "_, ax = plt.subplots(n_rows, 2, figsize=(12, 10), sharex=False, sharey=False)\n",
    "for index, key in enumerate(class_models.keys()):\n",
    "    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
    "    disp = ConfusionMatrixDisplay(\n",
    "        confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
    "    ).plot(ax=ax.flat[index])\n",
    "    disp.ax_.set_title(key)\n",
    "\n",
    "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
    "plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "The value 1049 in the yellow square is the number of \"Less\" objects that the model classified correctly.\n",
    "The value 558 in the green square is the number of correctly classified \"More\" objects. Together they account for all 1607 rows of the test set, so a matrix with these two values contains no errors. The count for \"More\" is lower only because the test set contains fewer \"More\" objects, not because that class is recognized less reliably.\n"
]
},
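  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The counts discussed above can also be read directly from the stored matrices instead of from the plot. The cell below is a minimal sketch: it assumes `class_models` has already been filled by the training loop and that the labels are 0 (\"Less\") and 1 (\"More\"). The printed layout is only illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: unpack each 2x2 confusion matrix into its four counts.\n",
    "# For binary labels [0, 1], ravel() yields TN, FP, FN, TP in that order.\n",
    "for name, result in class_models.items():\n",
    "    tn, fp, fn, tp = result[\"Confusion_matrix\"].ravel()\n",
    "    print(f\"{name}: TN={tn} FP={fp} FN={fn} TP={tp}\")"
   ]
  },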
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "__Precision, recall, accuracy, F-measure__"
]
},
{
"cell_type": "code",
   "execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_f119e_row0_col0, #T_f119e_row0_col1, #T_f119e_row0_col2, #T_f119e_row1_col0, #T_f119e_row1_col1, #T_f119e_row1_col2, #T_f119e_row2_col0, #T_f119e_row2_col1, #T_f119e_row2_col2, #T_f119e_row3_col0, #T_f119e_row3_col1, #T_f119e_row3_col2, #T_f119e_row4_col0, #T_f119e_row4_col1, #T_f119e_row4_col2, #T_f119e_row5_col0, #T_f119e_row5_col1, #T_f119e_row5_col2, #T_f119e_row6_col0, #T_f119e_row6_col1, #T_f119e_row6_col2, #T_f119e_row7_col0, #T_f119e_row7_col1, #T_f119e_row7_col2 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_f119e_row0_col3, #T_f119e_row1_col3, #T_f119e_row2_col3, #T_f119e_row3_col3, #T_f119e_row4_col3, #T_f119e_row5_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_f119e_row0_col4, #T_f119e_row0_col6, #T_f119e_row1_col4, #T_f119e_row1_col6, #T_f119e_row2_col4, #T_f119e_row2_col6, #T_f119e_row3_col4, #T_f119e_row3_col6, #T_f119e_row4_col4, #T_f119e_row4_col6, #T_f119e_row5_col4, #T_f119e_row5_col6, #T_f119e_row6_col4, #T_f119e_row6_col6, #T_f119e_row7_col4, #T_f119e_row7_col6 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_f119e_row0_col5, #T_f119e_row0_col7, #T_f119e_row1_col5, #T_f119e_row1_col7, #T_f119e_row2_col5, #T_f119e_row2_col7, #T_f119e_row3_col5, #T_f119e_row3_col7, #T_f119e_row4_col5, #T_f119e_row4_col7, #T_f119e_row5_col5, #T_f119e_row5_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_f119e_row6_col3, #T_f119e_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_f119e_row6_col5, #T_f119e_row6_col7, #T_f119e_row7_col5, #T_f119e_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_f119e\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_f119e_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_f119e_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_f119e_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_f119e_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_f119e_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_f119e_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_f119e_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_f119e_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_f119e_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_f119e_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
" <td id=\"T_f119e_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_f119e_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row4\" class=\"row_heading level0 row4\" >mlp</th>\n",
" <td id=\"T_f119e_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row5\" class=\"row_heading level0 row5\" >random_forest</th>\n",
" <td id=\"T_f119e_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col3\" class=\"data row5 col3\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col4\" class=\"data row5 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col5\" class=\"data row5 col5\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col6\" class=\"data row5 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row5_col7\" class=\"data row5 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row6\" class=\"row_heading level0 row6\" >decision_tree</th>\n",
" <td id=\"T_f119e_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row6_col3\" class=\"data row6 col3\" >0.998208</td>\n",
" <td id=\"T_f119e_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row6_col5\" class=\"data row6 col5\" >0.999378</td>\n",
" <td id=\"T_f119e_row6_col6\" class=\"data row6 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row6_col7\" class=\"data row6 col7\" >0.999103</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_f119e_level0_row7\" class=\"row_heading level0 row7\" >gradient_boosting</th>\n",
" <td id=\"T_f119e_row7_col0\" class=\"data row7 col0\" >1.000000</td>\n",
" <td id=\"T_f119e_row7_col1\" class=\"data row7 col1\" >1.000000</td>\n",
" <td id=\"T_f119e_row7_col2\" class=\"data row7 col2\" >1.000000</td>\n",
" <td id=\"T_f119e_row7_col3\" class=\"data row7 col3\" >0.998208</td>\n",
" <td id=\"T_f119e_row7_col4\" class=\"data row7 col4\" >1.000000</td>\n",
" <td id=\"T_f119e_row7_col5\" class=\"data row7 col5\" >0.999378</td>\n",
" <td id=\"T_f119e_row7_col6\" class=\"data row7 col6\" >1.000000</td>\n",
" <td id=\"T_f119e_row7_col7\" class=\"data row7 col7\" >0.999103</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x282db6e30b0>"
]
},
"execution_count": 247,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
    "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
    "    [\n",
    "        \"Precision_train\",\n",
    "        \"Precision_test\",\n",
    "        \"Recall_train\",\n",
    "        \"Recall_test\",\n",
    "        \"Accuracy_train\",\n",
    "        \"Accuracy_test\",\n",
    "        \"F1_train\",\n",
    "        \"F1_test\",\n",
    "    ]\n",
    "]\n",
    "class_metrics.sort_values(\n",
    "    by=\"Accuracy_test\", ascending=False\n",
    ").style.background_gradient(\n",
    "    cmap=\"plasma\",\n",
    "    low=0.3,\n",
    "    high=1,\n",
    "    subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
    ").background_gradient(\n",
    "    cmap=\"viridis\",\n",
    "    low=1,\n",
    "    high=0.3,\n",
    "    subset=[\n",
    "        \"Precision_train\",\n",
    "        \"Precision_test\",\n",
    "        \"Recall_train\",\n",
    "        \"Recall_test\",\n",
    "    ],\n",
    ")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All models (logistic regression, ridge classifier, KNN, naive Bayes, multilayer perceptron, random forest, decision tree and gradient boosting) score 1.000000 on every training-set metric.\n",
"A perfect fit to the training data can be a sign of overfitting. Here, however, the test-set metrics are almost as high (the lowest value in the table is 0.998208, the test recall of the decision tree and gradient boosting), so a likelier explanation is that the features contain the target: `above_average_close` and `Close` both appear among the model inputs shown for `X_test`.\n"
]
},
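{
"cell_type": "markdown",
"metadata": {},
"source": [
"Scores this close to perfect on both splits are worth checking for target leakage. The following is a minimal sketch, not part of the original analysis: it assumes the feature columns are the ones shown for `X_test` in this notebook and that the label is `above_average_close`; the names `feature_columns`, `target_columns`, `leaked_columns` and `safe_columns` are illustrative. If the label is built as in the data-loading cell further down (`Close` compared with its mean), `Close` alone also determines it, so dropping only the label column would probably not be enough."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Column names as they appear in the X_test output of this notebook.\n",
"feature_columns = [\n",
"    \"Date\", \"Open\", \"High\", \"Low\", \"Close\", \"Adj Close\", \"Volume\",\n",
"    \"above_average_close\", \"Close_Next_Day\",\n",
"]\n",
"# The classification label used as y_test.\n",
"target_columns = [\"above_average_close\"]\n",
"\n",
"# Any column that is both a feature and the label leaks the answer to the model.\n",
"leaked_columns = [col for col in feature_columns if col in target_columns]\n",
"print(leaked_columns)\n",
"\n",
"# Feature list with the leaked columns removed.\n",
"safe_columns = [col for col in feature_columns if col not in leaked_columns]\n",
"print(safe_columns)"
]
},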
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__ROC curve, Cohen's kappa, Matthews correlation coefficient__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_0a1ef_row0_col0, #T_0a1ef_row0_col1, #T_0a1ef_row1_col0, #T_0a1ef_row1_col1, #T_0a1ef_row2_col0, #T_0a1ef_row2_col1, #T_0a1ef_row3_col0, #T_0a1ef_row3_col1, #T_0a1ef_row4_col0, #T_0a1ef_row4_col1, #T_0a1ef_row6_col0, #T_0a1ef_row6_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_0a1ef_row0_col2, #T_0a1ef_row0_col3, #T_0a1ef_row0_col4, #T_0a1ef_row1_col2, #T_0a1ef_row1_col3, #T_0a1ef_row1_col4, #T_0a1ef_row2_col2, #T_0a1ef_row2_col3, #T_0a1ef_row2_col4, #T_0a1ef_row3_col2, #T_0a1ef_row3_col3, #T_0a1ef_row3_col4, #T_0a1ef_row4_col2, #T_0a1ef_row4_col3, #T_0a1ef_row4_col4, #T_0a1ef_row5_col2, #T_0a1ef_row6_col2, #T_0a1ef_row6_col3, #T_0a1ef_row6_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0a1ef_row5_col0, #T_0a1ef_row5_col1, #T_0a1ef_row7_col0, #T_0a1ef_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_0a1ef_row5_col3, #T_0a1ef_row5_col4, #T_0a1ef_row7_col2, #T_0a1ef_row7_col3, #T_0a1ef_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_0a1ef\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_0a1ef_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_0a1ef_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_0a1ef_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_0a1ef_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_0a1ef_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_0a1ef_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_0a1ef_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
" <td id=\"T_0a1ef_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row3\" class=\"row_heading level0 row3\" >naive_bayes</th>\n",
" <td id=\"T_0a1ef_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_0a1ef_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row5\" class=\"row_heading level0 row5\" >gradient_boosting</th>\n",
" <td id=\"T_0a1ef_row5_col0\" class=\"data row5 col0\" >0.999378</td>\n",
" <td id=\"T_0a1ef_row5_col1\" class=\"data row5 col1\" >0.999103</td>\n",
" <td id=\"T_0a1ef_row5_col2\" class=\"data row5 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row5_col3\" class=\"data row5 col3\" >0.998627</td>\n",
" <td id=\"T_0a1ef_row5_col4\" class=\"data row5 col4\" >0.998628</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_0a1ef_row6_col0\" class=\"data row6 col0\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row6_col1\" class=\"data row6 col1\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row6_col2\" class=\"data row6 col2\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row6_col3\" class=\"data row6 col3\" >1.000000</td>\n",
" <td id=\"T_0a1ef_row6_col4\" class=\"data row6 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_0a1ef_level0_row7\" class=\"row_heading level0 row7\" >decision_tree</th>\n",
" <td id=\"T_0a1ef_row7_col0\" class=\"data row7 col0\" >0.999378</td>\n",
" <td id=\"T_0a1ef_row7_col1\" class=\"data row7 col1\" >0.999103</td>\n",
" <td id=\"T_0a1ef_row7_col2\" class=\"data row7 col2\" >0.999104</td>\n",
" <td id=\"T_0a1ef_row7_col3\" class=\"data row7 col3\" >0.998627</td>\n",
" <td id=\"T_0a1ef_row7_col4\" class=\"data row7 col4\" >0.998628</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x282d6314410>"
]
},
"execution_count": 248,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Seven of the eight models (logistic regression, ridge classifier, KNN, naive Bayes, random forest, gradient boosting and the multilayer perceptron) reach a test ROC AUC of 1.000000, meaning they rank every positive example above every negative one.\n",
"Only the decision tree falls slightly short, with a ROC AUC of 0.999104. Gradient boosting and the decision tree also share a marginally lower test accuracy (0.999378), Cohen's kappa (0.998627) and MCC (0.998628), which is consistent with a single misclassified test example."
]
},
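{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, Cohen's kappa and MCC can both be computed from the four cells of a binary confusion matrix. The sketch below uses made-up counts (not results from this notebook) and plain Python rather than `sklearn.metrics`, purely to make the formulas visible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Hypothetical confusion counts for a binary classifier (not from the notebook).\n",
"tn, fp, fn, tp = 3, 1, 0, 4\n",
"total = tn + fp + fn + tp\n",
"\n",
"# Cohen's kappa: observed agreement corrected for agreement expected by chance.\n",
"p_observed = (tp + tn) / total\n",
"p_chance = ((tn + fp) * (tn + fn) + (fn + tp) * (fp + tp)) / total**2\n",
"kappa = (p_observed - p_chance) / (1 - p_chance)\n",
"\n",
"# Matthews correlation coefficient: correlation between labels and predictions.\n",
"mcc = (tp * tn - fp * fn) / math.sqrt(\n",
"    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)\n",
")\n",
"\n",
"print(f\"kappa={kappa:.4f}, mcc={mcc:.4f}\")"
]
},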
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Displaying the misclassified rows for review"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Predicted</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_close</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Date, Predicted, Open, High, Low, Close, Adj Close, Volume, above_average_close, Close_Next_Day]\n",
"Index: []"
]
},
"execution_count": 250,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"above_average_close\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Example of using the trained model (pipeline) for prediction"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>above_average_close</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6863</th>\n",
" <td>2019-09-26</td>\n",
" <td>90.839996</td>\n",
" <td>91.150002</td>\n",
" <td>89.5</td>\n",
" <td>89.800003</td>\n",
" <td>81.286491</td>\n",
" <td>5026400</td>\n",
" <td>1</td>\n",
" <td>88.370003</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close Volume \\\n",
"6863 2019-09-26 90.839996 91.150002 89.5 89.800003 81.286491 5026400 \n",
"\n",
" above_average_close Close_Next_Day \n",
"6863 1 88.370003 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Close</th>\n",
" <th>Open</th>\n",
" <th>Adj Close</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Volume</th>\n",
" <th>above_average_close</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6863</th>\n",
" <td>1.77257</td>\n",
" <td>1.803818</td>\n",
" <td>1.716146</td>\n",
" <td>1.78857</td>\n",
" <td>1.788959</td>\n",
" <td>-0.702466</td>\n",
" <td>1.370164</td>\n",
" <td>88.370003</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Close Open Adj Close High Low Volume \\\n",
"6863 1.77257 1.803818 1.716146 1.78857 1.788959 -0.702466 \n",
"\n",
" above_average_close Close_Next_Day \n",
"6863 1.370164 88.370003 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [0. 1.])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 6863\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Hyperparameter tuning with grid search"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'log2',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 252,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optimizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optimizer.fit(X_train, y_train.values.ravel())\n",
"gs_optimizer.best_params_"
]
},
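{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for the cost of this search, the number of fits can be counted from the grid itself. A rough sketch, assuming `GridSearchCV` runs with its default 5-fold cross-validation (no `cv` argument is passed above) and ignoring the final refit on the full training set; `n_candidates`, `cv_folds` and `n_fits` are illustrative names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Same grid as in the search above.\n",
"param_grid = {\n",
"    \"model__n_estimators\": [10, 50, 100],\n",
"    \"model__max_features\": [\"sqrt\", \"log2\"],\n",
"    \"model__max_depth\": [5, 7, 10],\n",
"    \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"# Every combination of values is one candidate model.\n",
"n_candidates = math.prod(len(values) for values in param_grid.values())\n",
"\n",
"# Assumed number of folds: the GridSearchCV default when cv is not passed.\n",
"cv_folds = 5\n",
"n_fits = n_candidates * cv_folds\n",
"print(n_candidates, n_fits)"
]
},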
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Training the model with the new hyperparameters__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"log2\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Building the data for comparing the old and new versions of the model__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Evaluating the metrics of the old and new models__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_a2d75_row0_col0, #T_a2d75_row0_col1, #T_a2d75_row0_col2, #T_a2d75_row0_col3, #T_a2d75_row1_col0, #T_a2d75_row1_col1, #T_a2d75_row1_col2, #T_a2d75_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a2d75_row0_col4, #T_a2d75_row0_col5, #T_a2d75_row0_col6, #T_a2d75_row0_col7, #T_a2d75_row1_col4, #T_a2d75_row1_col5, #T_a2d75_row1_col6, #T_a2d75_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_a2d75\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_a2d75_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_a2d75_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_a2d75_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_a2d75_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_a2d75_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_a2d75_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_a2d75_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_a2d75_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" <th class=\"blank col5\" > </th>\n",
" <th class=\"blank col6\" > </th>\n",
" <th class=\"blank col7\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_a2d75_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_a2d75_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_a2d75_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a2d75_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_a2d75_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_a2d75_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x282f16d85c0>"
]
},
"execution_count": 255,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On both the training set (Precision_train) and the test set (Precision_test), both models reach a perfect 1.000000. A precision of 1 means that no negative example was labelled positive; recall is also 1.000000, so no positive example was missed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_74bae_row0_col0, #T_74bae_row0_col1, #T_74bae_row1_col0, #T_74bae_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_74bae_row0_col2, #T_74bae_row0_col3, #T_74bae_row0_col4, #T_74bae_row1_col2, #T_74bae_row1_col3, #T_74bae_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_74bae\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" > </th>\n",
" <th id=\"T_74bae_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_74bae_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_74bae_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_74bae_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_74bae_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" > </th>\n",
" <th class=\"blank col1\" > </th>\n",
" <th class=\"blank col2\" > </th>\n",
" <th class=\"blank col3\" > </th>\n",
" <th class=\"blank col4\" > </th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_74bae_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_74bae_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_74bae_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_74bae_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_74bae_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_74bae_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_74bae_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_74bae_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_74bae_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_74bae_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_74bae_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_74bae_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x282d6ce6570>"
]
},
"execution_count": 256,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both versions of the model classify the test set flawlessly: accuracy, F1, ROC AUC, Cohen's kappa and MCC are all 1.000000, so not a single test example was misclassified."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The yellow square shows 1049, the number of objects correctly assigned to the \"Less\" class. The model identifies this class reliably, keeping false positives to a minimum.\n",
"\n",
"The green square shows 558, the number of objects correctly assigned to the \"More\" class, which likewise indicates that the model recognizes this class accurately."
]
},
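{
"cell_type": "markdown",
"metadata": {},
"source": [
"The counts above translate directly into the metrics reported earlier. A small sketch, assuming the off-diagonal cells are zero (which is what a test accuracy of 1.000000 implies) and scikit-learn's layout with rows as true classes and columns as predicted classes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Diagonal counts are quoted in the text; the zero off-diagonal cells are an\n",
"# assumption that follows from the reported test accuracy of 1.0.\n",
"c_matrix = [[1049, 0], [0, 558]]  # rows: true Less/More, columns: predicted Less/More\n",
"(tn, fp), (fn, tp) = c_matrix\n",
"\n",
"total = tn + fp + fn + tp\n",
"accuracy = (tp + tn) / total\n",
"precision = tp / (tp + fp)  # share of predicted \"More\" that is truly \"More\"\n",
"recall = tp / (tp + fn)  # share of true \"More\" that was found\n",
"print(total, accuracy, precision, recall)"
]
},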
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Determining the achievable level of model quality for the second task (regression)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__2. Predicting the closing price of the stock:__\n",
"\n",
"\n",
"Description: Estimate the closing price of Starbucks stock for the next day, or several days ahead, from historical data.\n",
"Target variable: closing price (Close). (mean value)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Loading the data and creating the target variable"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля 'Close': 30.058856538825285\n",
" Date Open High Low Close Adj Close Volume \\\n",
"0 1992-06-26 0.328125 0.347656 0.320313 0.335938 0.260703 224358400 \n",
"1 1992-06-29 0.339844 0.367188 0.332031 0.359375 0.278891 58732800 \n",
"2 1992-06-30 0.367188 0.371094 0.343750 0.347656 0.269797 34777600 \n",
"3 1992-07-01 0.351563 0.359375 0.339844 0.355469 0.275860 18316800 \n",
"4 1992-07-02 0.359375 0.359375 0.347656 0.355469 0.275860 13996800 \n",
"\n",
" above_average_close Close_Next_Day \n",
"0 0 0.359375 \n",
"1 0 0.347656 \n",
"2 0 0.355469 \n",
"3 0 0.355469 \n",
"4 0 0.355469 \n",
"Статистическое описание DataFrame:\n",
" Open High Low Close Adj Close \\\n",
"count 8035.000000 8035.000000 8035.000000 8035.000000 8035.000000 \n",
"mean 30.048051 30.345221 29.745172 30.052733 26.667480 \n",
"std 33.613031 33.904070 33.312079 33.613521 31.724640 \n",
"min 0.328125 0.347656 0.320313 0.335938 0.260703 \n",
"25% 4.391563 4.531250 4.304844 4.399219 3.413997 \n",
"50% 13.325000 13.485000 13.150000 13.330000 10.352452 \n",
"75% 55.250000 55.715000 54.829999 55.254999 47.461098 \n",
"max 126.080002 126.320000 124.809998 126.059998 118.010414 \n",
"\n",
" Volume above_average_close Close_Next_Day \n",
"count 8.035000e+03 8035.000000 8035.000000 \n",
"mean 1.470584e+07 0.347480 30.062556 \n",
"std 1.340058e+07 0.476199 33.616368 \n",
"min 1.504000e+06 0.000000 0.347656 \n",
"25% 7.818550e+06 0.000000 4.403125 \n",
"50% 1.170240e+07 0.000000 13.330000 \n",
"75% 1.778850e+07 1.000000 55.274999 \n",
"max 5.855088e+08 1.000000 126.059998 \n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Загрузка данных о ценах акций Starbucks из CSV файла\n",
"df = pd.read_csv(\".//static//csv//Starbucks Dataset.csv\")\n",
"\n",
"# Опция для настройки генерации случайных чисел (если это нужно для других частей кода)\n",
"random_state = 42\n",
"\n",
"# Вычисление среднего значения поля \"Close\"\n",
"average_close = df['Close'].mean()\n",
"print(f\"Среднее значение поля 'Close': {average_close}\")\n",
"\n",
"# Создание новой колонки, указывающей, выше или ниже среднего значение цена закрытия\n",
"df['above_average_close'] = (df['Close'] > average_close).astype(int)\n",
"\n",
"# Создание целевой переменной для прогнозирования (цена закрытия на следующий день)\n",
"df['Close_Next_Day'] = df['Close'].shift(-1)\n",
"\n",
"# Удаление последней строки, где нет значения для следующего дня\n",
"df.dropna(inplace=True)\n",
"\n",
"# Вывод DataFrame с новой колонкой\n",
"print(df.head())\n",
"\n",
"# Примерный анализ данных\n",
"print(\"Статистическое описание DataFrame:\")\n",
"print(df.describe())"
]
},
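  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how `shift(-1)` builds the next-day target, using a hypothetical four-row frame (`toy_prices` and its values are made up for illustration and are not taken from the Starbucks data). Unlike the cell above, the mean here is computed after the last row is dropped; the mechanics of the shift are the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy prices (not taken from the Starbucks dataset)\n",
    "toy_prices = pd.DataFrame({\"Close\": [10.0, 11.0, 9.5, 12.0]})\n",
    "\n",
    "# shift(-1) moves every value one row up, so each row sees the next day's close\n",
    "toy_prices[\"Close_Next_Day\"] = toy_prices[\"Close\"].shift(-1)\n",
    "\n",
    "# The last row has no next day: its target is NaN and the row is dropped\n",
    "toy_prices = toy_prices.dropna(subset=[\"Close_Next_Day\"])\n",
    "\n",
    "# Binary flag: is the current close above the mean close of the remaining rows?\n",
    "toy_mean = toy_prices[\"Close\"].mean()\n",
    "toy_prices[\"above_average_close\"] = (toy_prices[\"Close\"] > toy_mean).astype(int)\n",
    "\n",
    "print(toy_prices)"
   ]
  },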
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи регрессии\n",
"\n",
"Целевой признак -- above_average_close"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5552</th>\n",
" <td>2014-07-14</td>\n",
" <td>39.490002</td>\n",
" <td>39.490002</td>\n",
" <td>39.209999</td>\n",
" <td>39.279999</td>\n",
" <td>32.493519</td>\n",
" <td>4562000</td>\n",
" <td>39.445000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3422</th>\n",
" <td>2006-01-25</td>\n",
" <td>15.340000</td>\n",
" <td>15.380000</td>\n",
" <td>15.095000</td>\n",
" <td>15.180000</td>\n",
" <td>11.780375</td>\n",
" <td>7276600</td>\n",
" <td>15.745000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6214</th>\n",
" <td>2017-02-28</td>\n",
" <td>56.709999</td>\n",
" <td>57.060001</td>\n",
" <td>56.549999</td>\n",
" <td>56.869999</td>\n",
" <td>48.946602</td>\n",
" <td>8750700</td>\n",
" <td>57.139999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3501</th>\n",
" <td>2006-05-18</td>\n",
" <td>18.225000</td>\n",
" <td>18.250000</td>\n",
" <td>17.965000</td>\n",
" <td>17.990000</td>\n",
" <td>13.961062</td>\n",
" <td>13366000</td>\n",
" <td>18.165001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2688</th>\n",
" <td>2003-02-26</td>\n",
" <td>5.657500</td>\n",
" <td>5.682500</td>\n",
" <td>5.520000</td>\n",
" <td>5.550000</td>\n",
" <td>4.307055</td>\n",
" <td>16738400</td>\n",
" <td>5.772500</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5226</th>\n",
" <td>2013-03-27</td>\n",
" <td>28.430000</td>\n",
" <td>28.475000</td>\n",
" <td>28.105000</td>\n",
" <td>28.455000</td>\n",
" <td>23.144903</td>\n",
" <td>7457000</td>\n",
" <td>28.475000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5390</th>\n",
" <td>2013-11-18</td>\n",
" <td>40.509998</td>\n",
" <td>40.669998</td>\n",
" <td>40.105000</td>\n",
" <td>40.270000</td>\n",
" <td>33.065239</td>\n",
" <td>8316400</td>\n",
" <td>39.959999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>1995-11-20</td>\n",
" <td>1.355469</td>\n",
" <td>1.367188</td>\n",
" <td>1.328125</td>\n",
" <td>1.332031</td>\n",
" <td>1.033717</td>\n",
" <td>30998400</td>\n",
" <td>1.343750</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7603</th>\n",
" <td>2022-09-02</td>\n",
" <td>85.470001</td>\n",
" <td>85.769997</td>\n",
" <td>82.550003</td>\n",
" <td>82.940002</td>\n",
" <td>79.683807</td>\n",
" <td>10336800</td>\n",
" <td>84.519997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7270</th>\n",
" <td>2021-05-10</td>\n",
" <td>114.570000</td>\n",
" <td>116.089996</td>\n",
" <td>114.209999</td>\n",
" <td>114.300003</td>\n",
" <td>106.577309</td>\n",
" <td>5759500</td>\n",
" <td>113.550003</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"5552 2014-07-14 39.490002 39.490002 39.209999 39.279999 32.493519 \n",
"3422 2006-01-25 15.340000 15.380000 15.095000 15.180000 11.780375 \n",
"6214 2017-02-28 56.709999 57.060001 56.549999 56.869999 48.946602 \n",
"3501 2006-05-18 18.225000 18.250000 17.965000 17.990000 13.961062 \n",
"2688 2003-02-26 5.657500 5.682500 5.520000 5.550000 4.307055 \n",
"... ... ... ... ... ... ... \n",
"5226 2013-03-27 28.430000 28.475000 28.105000 28.455000 23.144903 \n",
"5390 2013-11-18 40.509998 40.669998 40.105000 40.270000 33.065239 \n",
"860 1995-11-20 1.355469 1.367188 1.328125 1.332031 1.033717 \n",
"7603 2022-09-02 85.470001 85.769997 82.550003 82.940002 79.683807 \n",
"7270 2021-05-10 114.570000 116.089996 114.209999 114.300003 106.577309 \n",
"\n",
" Volume Close_Next_Day \n",
"5552 4562000 39.445000 \n",
"3422 7276600 15.745000 \n",
"6214 8750700 57.139999 \n",
"3501 13366000 18.165001 \n",
"2688 16738400 5.772500 \n",
"... ... ... \n",
"5226 7457000 28.475000 \n",
"5390 8316400 39.959999 \n",
"860 30998400 1.343750 \n",
"7603 10336800 84.519997 \n",
"7270 5759500 113.550003 \n",
"\n",
"[6428 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_close</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5552</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3422</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6214</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3501</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2688</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5226</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5390</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7603</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7270</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>6428 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_close\n",
"5552 1\n",
"3422 0\n",
"6214 1\n",
"3501 0\n",
"2688 0\n",
"... ...\n",
"5226 0\n",
"5390 1\n",
"860 0\n",
"7603 1\n",
"7270 1\n",
"\n",
"[6428 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Date</th>\n",
" <th>Open</th>\n",
" <th>High</th>\n",
" <th>Low</th>\n",
" <th>Close</th>\n",
" <th>Adj Close</th>\n",
" <th>Volume</th>\n",
" <th>Close_Next_Day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6637</th>\n",
" <td>2018-10-31</td>\n",
" <td>58.980000</td>\n",
" <td>59.119999</td>\n",
" <td>58.209999</td>\n",
" <td>58.270000</td>\n",
" <td>51.754456</td>\n",
" <td>11560400</td>\n",
" <td>58.630001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6632</th>\n",
" <td>2018-10-24</td>\n",
" <td>58.570000</td>\n",
" <td>59.279999</td>\n",
" <td>57.950001</td>\n",
" <td>58.060001</td>\n",
" <td>51.567940</td>\n",
" <td>12189700</td>\n",
" <td>58.959999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7327</th>\n",
" <td>2021-07-30</td>\n",
" <td>122.190002</td>\n",
" <td>122.980003</td>\n",
" <td>121.099998</td>\n",
" <td>121.430000</td>\n",
" <td>113.676071</td>\n",
" <td>5712300</td>\n",
" <td>120.370003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>730</th>\n",
" <td>1995-05-17</td>\n",
" <td>0.937500</td>\n",
" <td>0.941406</td>\n",
" <td>0.902344</td>\n",
" <td>0.910156</td>\n",
" <td>0.706323</td>\n",
" <td>25811200</td>\n",
" <td>0.912109</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1515</th>\n",
" <td>1998-06-25</td>\n",
" <td>3.226563</td>\n",
" <td>3.328125</td>\n",
" <td>3.218750</td>\n",
" <td>3.285156</td>\n",
" <td>2.549432</td>\n",
" <td>34699200</td>\n",
" <td>3.382813</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5777</th>\n",
" <td>2015-06-04</td>\n",
" <td>51.869999</td>\n",
" <td>52.180000</td>\n",
" <td>51.570000</td>\n",
" <td>51.720001</td>\n",
" <td>43.400497</td>\n",
" <td>6230800</td>\n",
" <td>52.189999</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7719</th>\n",
" <td>2023-02-21</td>\n",
" <td>105.500000</td>\n",
" <td>105.949997</td>\n",
" <td>104.709999</td>\n",
" <td>104.779999</td>\n",
" <td>101.752243</td>\n",
" <td>5438000</td>\n",
" <td>104.769997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1677</th>\n",
" <td>1999-02-17</td>\n",
" <td>2.972656</td>\n",
" <td>3.023438</td>\n",
" <td>2.906250</td>\n",
" <td>2.910156</td>\n",
" <td>2.258415</td>\n",
" <td>17776000</td>\n",
" <td>2.933594</td>\n",
" </tr>\n",
" <tr>\n",
" <th>921</th>\n",
" <td>1996-02-16</td>\n",
" <td>1.031250</td>\n",
" <td>1.054688</td>\n",
" <td>1.015625</td>\n",
" <td>1.031250</td>\n",
" <td>0.800297</td>\n",
" <td>7809600</td>\n",
" <td>1.031250</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>1993-10-05</td>\n",
" <td>0.835938</td>\n",
" <td>0.835938</td>\n",
" <td>0.804688</td>\n",
" <td>0.820313</td>\n",
" <td>0.636600</td>\n",
" <td>9113600</td>\n",
" <td>0.812500</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1607 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Date Open High Low Close Adj Close \\\n",
"6637 2018-10-31 58.980000 59.119999 58.209999 58.270000 51.754456 \n",
"6632 2018-10-24 58.570000 59.279999 57.950001 58.060001 51.567940 \n",
"7327 2021-07-30 122.190002 122.980003 121.099998 121.430000 113.676071 \n",
"730 1995-05-17 0.937500 0.941406 0.902344 0.910156 0.706323 \n",
"1515 1998-06-25 3.226563 3.328125 3.218750 3.285156 2.549432 \n",
"... ... ... ... ... ... ... \n",
"5777 2015-06-04 51.869999 52.180000 51.570000 51.720001 43.400497 \n",
"7719 2023-02-21 105.500000 105.949997 104.709999 104.779999 101.752243 \n",
"1677 1999-02-17 2.972656 3.023438 2.906250 2.910156 2.258415 \n",
"921 1996-02-16 1.031250 1.054688 1.015625 1.031250 0.800297 \n",
"322 1993-10-05 0.835938 0.835938 0.804688 0.820313 0.636600 \n",
"\n",
" Volume Close_Next_Day \n",
"6637 11560400 58.630001 \n",
"6632 12189700 58.959999 \n",
"7327 5712300 120.370003 \n",
"730 25811200 0.912109 \n",
"1515 34699200 3.382813 \n",
"... ... ... \n",
"5777 6230800 52.189999 \n",
"7719 5438000 104.769997 \n",
"1677 17776000 2.933594 \n",
"921 7809600 1.031250 \n",
"322 9113600 0.812500 \n",
"\n",
"[1607 rows x 8 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_average_close</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6637</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6632</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7327</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>730</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1515</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5777</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7719</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1677</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>921</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>322</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1607 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_average_close\n",
"6637 1\n",
"6632 1\n",
"7327 1\n",
"730 0\n",
"1515 0\n",
"... ...\n",
"5777 1\n",
"7719 1\n",
"1677 0\n",
"921 0\n",
"322 0\n",
"\n",
"[1607 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"above_average_close\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"above_average_close\", \n",
" frac_train=0.8, \n",
" random_state=42 # Убедитесь, что вы задали нужное значение random_state\n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
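  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split above shuffles the rows. For daily prices this means that some training rows are later in time than some test rows. A chronological hold-out is a common alternative; the sketch below assumes the rows are already sorted by date. `chronological_split` and the `toy_*` names are hypothetical and are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def chronological_split(frame, target_colname, frac_train=0.8):\n",
    "    \"\"\"Hypothetical helper: first frac_train of the rows for training, the rest for testing.\"\"\"\n",
    "    if not (0 < frac_train < 1):\n",
    "        raise ValueError(\"frac_train must be between 0 and 1.\")\n",
    "    cut = int(len(frame) * frac_train)  # index of the first test row\n",
    "    X = frame.drop(columns=[target_colname])\n",
    "    y = frame[[target_colname]]\n",
    "    return X.iloc[:cut], X.iloc[cut:], y.iloc[:cut], y.iloc[cut:]\n",
    "\n",
    "# Made-up frame with ten rows in time order\n",
    "toy_frame = pd.DataFrame({\"Close\": range(10), \"target\": range(1, 11)})\n",
    "toy_X_tr, toy_X_te, toy_y_tr, toy_y_te = chronological_split(toy_frame, \"target\")\n",
    "\n",
    "# Every training row precedes every test row\n",
    "print(len(toy_X_tr), len(toy_X_te))"
   ]
  },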
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование конвейера для решения задачи регрессии"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.ensemble import RandomForestRegressor # Пример регрессионной модели\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"class StarbucksFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" \n",
" def fit(self, X, y=None):\n",
" return self\n",
"\n",
" def transform(self, X, y=None):\n",
" X[\"Length_to_Width_Ratio\"] = X[\"x\"] / X[\"y\"]\n",
" return X\n",
"\n",
" def get_feature_names_out(self, features_in):\n",
" return np.append(features_in, [\"Length_to_Width_Ratio\"], axis=0)\n",
"\n",
"# Указываем столбцы, которые нужно удалить и обрабатывать\n",
"columns_to_drop = [\"Date\"]\n",
"num_columns = [\"Close\", \"Open\", \"Adj Close\", \"High\", \"Low\", \"Volume\"]\n",
"cat_columns = [] \n",
"\n",
"# Определяем предобработку для численных данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Определяем предобработку для категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Подготовка признаков с использованием ColumnTransformer\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Удаление нежелательных столбцов\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Постобработка признаков\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_cat\", preprocessing_cat, [\"Cabin_type\"]), \n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Создание окончательного конвейера\n",
"pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"model\", RandomForestRegressor()) # Выбор модели для обучения\n",
" ]\n",
")\n",
"\n",
"# Использование конвейера\n",
"def train_pipeline(X, y):\n",
" pipeline.fit(X, y)\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование набора моделей для регрессии"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Random Forest: Mean Score = 0.9746978079010529, Standard Deviation = 0.012793762025792637\n",
"Linear Regression: Mean Score = 0.9868838982543027, Standard Deviation = 0.0041016418339485\n",
"Gradient Boosting: Mean Score = 0.9790461912830413, Standard Deviation = 0.008537795226791314\n",
"Support Vector Regression: Mean Score = -0.10833533729231568, Standard Deviation = 0.29324311707552003\n"
]
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"def train_multiple_models(X, y, models):\n",
" results = {}\n",
" for model_name, model in models.items():\n",
" # Создаем конвейер для каждой модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"model\", model) # Используем текущую модель\n",
" ]\n",
" )\n",
" \n",
" # Обучаем модель и вычисляем кросс-валидацию\n",
" scores = cross_val_score(model_pipeline, X, y, cv=5) # 5-кратная кросс-валидация\n",
" results[model_name] = {\n",
" \"mean_score\": scores.mean(),\n",
" \"std_dev\": scores.std()\n",
" }\n",
" \n",
" return results\n",
"\n",
"models = {\n",
" \"Random Forest\": RandomForestRegressor(),\n",
" \"Linear Regression\": LinearRegression(),\n",
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
" \"Support Vector Regression\": SVR()\n",
"}\n",
"\n",
"results = train_multiple_models(X_train, y_train, models)\n",
"\n",
"# Вывод результатов\n",
"for model_name, scores in results.items():\n",
" print(f\"{model_name}: Mean Score = {scores['mean_score']}, Standard Deviation = {scores['std_dev']}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лидирующие модели: Линейная регрессия проявила наилучшие результаты, за ней следует градиентный бустинг и Random Forest. Они продемонстрировали высокую эффективность в предсказании закрытия акций.\n",
"Проблемы SVR: Резкое отличие в результатах SVR выявляет необходимость более тщательной настройки или выбора других подходов к решению задачи, поскольку текущие параметры не обеспечили адекватного уровня прогноза."
]
},
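  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A negative R² means the predictions are worse than always predicting the mean of the target. The sketch below illustrates this with made-up numbers (the `toy_*` values are hypothetical and unrelated to the Starbucks data): a constant mean prediction gives R² = 0, and a prediction that reverses the order of the targets gives a negative score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import r2_score\n",
    "\n",
    "# Made-up targets and two made-up sets of predictions\n",
    "toy_true = [1.0, 2.0, 3.0, 4.0]\n",
    "toy_mean_pred = [2.5, 2.5, 2.5, 2.5]  # always predict the mean of toy_true\n",
    "toy_bad_pred = [4.0, 3.0, 2.0, 1.0]  # predictions in reversed order\n",
    "\n",
    "# R2 = 1 - SS_res / SS_tot\n",
    "toy_r2_mean = r2_score(toy_true, toy_mean_pred)\n",
    "toy_r2_bad = r2_score(toy_true, toy_bad_pred)\n",
    "\n",
    "print(toy_r2_mean, toy_r2_bad)"
   ]
  },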
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение моделей на обучающем наборе данных и оценка на тестовом для регрессии"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: ridge\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: decision_tree\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: knn\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: naive_bayes\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: gradient_boosting\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: random_forest\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: mlp\n",
"MSE (train): 0.0020224019912881146\n",
"MSE (test): 0.0018656716417910447\n",
"MAE (train): 0.0020224019912881146\n",
"MAE (test): 0.0018656716417910447\n",
"R2 (train): 0.9911106856018297\n",
"R2 (test): 0.9918005898000289\n",
"STD (train): 0.044925626111093304\n",
"STD (test): 0.04315311009783723\n",
"----------------------------------------\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# Проверка наличия необходимых переменных\n",
"if 'class_models' not in locals():\n",
" raise ValueError(\"class_models is not defined\")\n",
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
" raise ValueError(\"Train/test data is not defined\")\n",
"\n",
"\n",
"y_train = np.ravel(y_train) \n",
"y_test = np.ravel(y_test) \n",
"\n",
"# Инициализация списка для хранения результатов\n",
"results = []\n",
"\n",
"# Проход по моделям и оценка их качества\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" # Извлечение модели из словаря\n",
" model = class_models[model_name][\"model\"]\n",
" \n",
" # Создание пайплайна\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Сохранение пайплайна и предсказаний\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Вычисление метрик для регрессии\n",
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
"\n",
" # Дополнительные метрики\n",
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
"\n",
" # Вывод результатов для текущей модели\n",
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
" print(\"-\" * 40) # Разделитель для разных моделей"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования обученной модели (конвейера регрессии) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: RandomForest\n",
"MSE (train): 0.0001403391412570006\n",
"MSE (test): 0.0006576851275668948\n",
"MAE (train): 0.0005491599253266957\n",
"MAE (test): 0.0011761045426260113\n",
"R2 (train): 0.9993811021756365\n",
"R2 (test): 0.9971008099591692\n",
"----------------------------------------\n",
"Прогноз: Цена закроется ниже среднего значения завтрашнего дня.\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestRegressor # пример модели\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# 1. Загрузка данных\n",
"data = pd.read_csv(\".//static//csv//Starbucks Dataset.csv\") \n",
"data['Date'] = pd.to_datetime(data['Date'])\n",
"data.set_index('Date', inplace=True)\n",
"\n",
"# 2. Подготовка данных для прогноза\n",
"data['Close_shifted'] = data['Close'].shift(-1) # Смещение на 1 день для предсказания\n",
"data.dropna(inplace=True) # Удаление NaN, возникших из-за смещения\n",
"\n",
"# Вычисляем среднее значение закрытия\n",
"average_close = data['Close'].mean()\n",
"data['above_average_close'] = (data['Close_shifted'] > average_close).astype(int) # 1, если выше среднего, иначе 0\n",
"\n",
"# Предикторы и целевая переменная\n",
"X = data[['Open', 'High', 'Low', 'Close', 'Volume']]\n",
"y = data['above_average_close']\n",
"\n",
"\n",
"# 3. Инициализация модели и пайплайна\n",
"class_models = {\n",
" \"RandomForest\": {\n",
" \"model\": RandomForestRegressor(n_estimators=100, random_state=42),\n",
" }\n",
"}\n",
"\n",
"pipeline_end = StandardScaler() \n",
"results = []\n",
"\n",
"# 4. Обучение модели и оценка\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" model = class_models[model_name][\"model\"]\n",
" model_pipeline = Pipeline([(\"scaler\", pipeline_end), (\"model\", model)])\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Сохранение результатов\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Вычисление метрик\n",
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
"\n",
" # Вывод результатов\n",
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
" print(\"-\" * 40)\n",
"\n",
"# Прогнозирование выше среднего для следующего дня\n",
"latest_data = X_test.iloc[-1:].copy()\n",
"predicted_above_average = model_pipeline.predict(latest_data)\n",
"predicted_above_average = 1 if predicted_above_average[0] > 0.5 else 0 # Преобразуем в бинарный выход\n",
"\n",
"if predicted_above_average == 1:\n",
" print(\"Прогноз: Цена закроется выше среднего значения завтрашнего дня.\")\n",
"else:\n",
" print(\"Прогноз: Цена закроется ниже среднего значения завтрашнего дня.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n",
"Лучшие параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
"Лучший результат (MSE): 0.6848872116583115\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor # Используем р е г р е с с о р \n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"# 1. Подготовка данных для прогноза\n",
"data['above_average_close'] = data['Close'].shift(-1) # Смещение на 1 день для предсказания\n",
"data.dropna(inplace=True) # Удаление NaN, возникших из-за смещения\n",
"\n",
"# Предикторы и целевая переменная\n",
"X = data[['Open', 'High', 'Low', 'Close', 'Volume']]\n",
"y = data['above_average_close'] # Целевая переменная для регрессии\n",
"\n",
"# Делим данные на обучающую и тестовую выборки\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# 2. Создание и настройка модели случайного леса\n",
"model = RandomForestRegressor() # Изменяем на р е г р е с с о р \n",
"\n",
"# Установка параметров для поиска по сетке\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# 3. Подбор гиперпараметров с помощью Grid Search\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"# 4. Результаты подбора гиперпараметров\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_) # Меняем знак, так как берем отрицательное значение среднеквадратичной ошибки\n",
"\n"
]
},
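  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split in the previous cell shuffles rows at random, which for daily prices lets the model train on days that come after some of the test days. One possible alternative, shown here as a sketch on synthetic data rather than as this notebook's method, is scikit-learn's `TimeSeriesSplit`, which always validates on rows that follow the training rows. The `prices` series, the single feature, and the small parameter grid are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import GridSearchCV, TimeSeriesSplit\n",
    "\n",
    "# Synthetic random-walk prices (an assumption, not the real dataset)\n",
    "rng = np.random.default_rng(0)\n",
    "prices = 50 + np.cumsum(rng.normal(size=300))\n",
    "\n",
    "# Feature: today's price; target: tomorrow's price\n",
    "X = prices[:-1].reshape(-1, 1)\n",
    "y = prices[1:]\n",
    "\n",
    "# Each fold trains on earlier rows and validates on later rows\n",
    "cv = TimeSeriesSplit(n_splits=3)\n",
    "search = GridSearchCV(\n",
    "    RandomForestRegressor(n_estimators=20, random_state=0),\n",
    "    param_grid={'max_depth': [3, 5]},\n",
    "    scoring='neg_mean_squared_error',\n",
    "    cv=cv,\n",
    ")\n",
    "search.fit(X, y)\n",
    "\n",
    "# The score is the negated MSE, so flip the sign to report MSE\n",
    "best_mse = -search.best_score_\n",
    "print(search.best_params_, best_mse)\n",
    "```\n",
    "\n",
    "With an ordered split the reported error is usually higher than with a shuffled one, but it is closer to what the model would face when forecasting days it has not yet seen."
   ]
  },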
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n",
"Старые параметры: {'max_depth': 10, 'min_samples_split': 5, 'n_estimators': 200}\n",
"Лучший результат (MSE) на старых параметрах: 0.688662233031193\n",
"\n",
"Новые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
"Лучший результат (MSE) на новых параметрах: 0.6794717145705662\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.5876131198171756\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.7665592735184772\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHWCAYAAABACtmGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABqPUlEQVR4nO3deVwWVf//8fcFyr4qgmvgvu+m4pLeuWCZuZRr5pqVmkuoqWXg8jU011LTtFKzzKXMSkvrRs19X3PLfQd3UFRQmN8f/rhuLwEvLgRBfD0fj+shc+bMmc+MA8OHc+aMyTAMQwAAAACAFNlldgAAAAAAkNWROAEAAACAFSROAAAAAGAFiRMAAAAAWEHiBAAAAABWkDgBAAAAgBUkTgAAAABgBYkTAAAAAFhB4gQAAAAAVpA4AQAAAIAVJE7AU+rYsWN65513VKRIETk5OcnDw0O1a9fWZ599ptu3b2d2eM+MNWvWyGQyyWQy6bvvvku2Tu3atWUymVSuXDmL8ri4OH322WeqXLmyPDw85OXlpbJly+rtt9/WoUOHzPXmzJlj3kdyn82bN2foMQIAAClHZgcAwHbLly9X69at5ejoqE6dOqlcuXKKi4vT+vXrNWjQIO3fv18zZ87M7DCfKU5OTpo/f746duxoUX7y5Elt3LhRTk5OSbZ57bXX9Mcff6h9+/bq0aOH7t69q0OHDmnZsmWqVauWSpUqZVF/5MiRKly4cJJ2ihUrlr4HAwAAkiBxAp4yJ06cULt27eTv769Vq1YpX7585nW9e/fW0aNHtXz58kyM8Nn08ssv69dff9Xly5fl4+NjLp8/f778/PxUvHhxXbt2zVy+bds2LVu2TKNHj9aHH35o0dbUqVN1/fr1JPt46aWXVK1atQw7BgAAkDKG6gFPmU8//VQ3b97U119/bZE0JSpWrJj69etnXjaZTHrvvff0/fffq2TJknJyclLVqlW1du1ai+1OnTqlXr16qWTJknJ2dlbu3LnVunVrnTx50qLew8PGXFxcVL58eX311VcW9bp06SI3N7ck8f34448ymUxas2aNRfmWLVvUpEkTeXp6ysXFRfXq1dOGDRss6gwfPlwmk0mXL1+2KN++fbtMJpPmzJljsf+AgACLemfOnJGzs7NMJlOS4/rjjz9Ut25dubq6yt3dXU2bNtX+/fuTxJ+S5s2by9HRUYsXL7Yonz9/vtq0aSN7e3uL8mPHjkm6P4zvYfb29sqdO3eq950aJ0+eTHGo38PnQpLq16+fbN0Hz7EkTZ8+XeXKlZOLi4tFvR9//NFqTOfOnVP37t2VP39+OTo6qnDhwurZs6fi4uKsDk98MJa9e/eqS5cu5mGrefPmVbdu3XTlyhWL/SVeP4cOHVKbNm3k4eGh3Llzq1+/frpz545F3cTvm5Qkxpd47latWiU7OzuFhIRY1Js/f75MJpOmT5/+yHNRv3591a9f36Js27Zt5mO1pn79+kmGgkrS+PHjk/0//uKLL1S2bFk5Ojoqf/786t27d5Jk/eFrwMfHR02bNtU///xjUS8zztWjrosHj/WXX35R06ZNzddY0aJFNWrUKMXHxydps1y5ctqxY4dq1aolZ2dnFS5cWDNmzLCoFxcXp5CQEFWtWlWenp5ydXVV3bp1tXr1aot6D36/LV261GLdnTt35O3tLZPJpPHjx1usO3funLp16yY/Pz85OjqqbNmy+uabb8zrHxwanNJn+PDhkmy73u/du6dRo0apaNGicnR0VEBAgD788EPFxsZa1AsICDDvx87OTnnz5lXbtm11+vTpR/6fAdkFPU7AU+a3335TkSJFVKtWrVRv8/fff2vhwoXq27evHB0d9cUXX6hJkybaunWr+Zetbdu2aePGjWrXrp0KFiyokydPavr06apfv74OHDggFxcXizYnTZokHx8fRUdH65tvvlGPHj0UEBCghg0b2nxMq1at0ksvvaSqVasqNDRUdnZ2mj17tl
588UWtW7dO1atXt7nN5ISEhCT5hUGS5s2bp86dOysoKEhjx47VrVu3NH36dNWpU0e7du1KkoAlx8XFRc2bN9cPP/ygnj17SpL27Nmj/fv366uvvtLevXst6vv7+0uSvv/+e9WuXVs5clj/cRwVFZUkaTSZTDYlWe3bt9fLL78sSfr999/1ww8/pFi3VKlS+uijjyRJly9f1vvvv2+xfuHCherVq5fq16+vPn36yNXVVQcPHtQnn3xiNY7z58+revXqun79ut5++22VKlVK586d048//qhbt27phRde0Lx588z1R48eLUnmeCSZvwf++usvHT9+XF27dlXevHnNQ1X379+vzZs3J0k82rRpo4CAAIWFhWnz5s36/PPPde3aNX377bdW407Jiy++qF69eiksLEwtWrRQlSpVdOHCBfXp00cNGzbUu+++a3ObgwcPTnM8jzJ8+HCNGDFCDRs2VM+ePXX48GFNnz5d27Zt04YNG5QzZ05z3cRrwDAMHTt2TBMnTtTLL7/8WL8op8e5KliwoMLCwizKkrue58yZIzc3NwUHB8vNzU2rVq1SSEiIoqOjNW7cOIu6165d08svv6w2bdqoffv2WrRokXr27CkHBwd169ZNkhQdHa2vvvrKPLz2xo0b+vrrrxUUFKStW7eqUqVKFm06OTlp9uzZatGihblsyZIlyf4cioyMVM2aNc2JaJ48efTHH3+oe/fuio6OVv/+/VW6dGmL74uZM2fq4MGDmjRpkrmsQoUKFu2m5np/6623NHfuXL3++usaMGCAtmzZorCwMB08eFA///yzRXt169bV22+/rYSEBP3zzz+aPHmyzp8/r3Xr1iU5JiDbMQA8NaKiogxJRvPmzVO9jSRDkrF9+3Zz2alTpwwnJyejZcuW5rJbt24l2XbTpk2GJOPbb781l82ePduQZJw4ccJc9u+//xqSjE8//dRc1rlzZ8PV1TVJm4sXLzYkGatXrzYMwzASEhKM4sWLG0FBQUZCQoJFPIULFzYaNWpkLgsNDTUkGZcuXbJoc9u2bYYkY/bs2Rb79/f3Ny//888/hp2dnfHSSy9ZxH/jxg3Dy8vL6NGjh0WbERERhqenZ5Lyh61evdqQZCxevNhYtmyZYTKZjNOnTxuGYRiDBg0yihQpYhiGYdSrV88oW7asebuEhASjXr16hiTDz8/PaN++vTFt2jTj1KlTSfaReM6T+zg6Oj4yvkSJ/0fjx483l40bNy7J/2Wi2rVrG//5z3/MyydOnEhyjtu3b294eXkZt2/fTvZ8PEqnTp0MOzs7Y9u2bUnWPXgdJKpXr55Rr169ZNtK7tr94YcfDEnG2rVrzWWJ18+rr75qUbdXr16GJGPPnj3mMklG7969U4w/ue+DmJgYo1ixYkbZsmWNO3fuGE2bNjU8PDyS/T+1dny///67Iclo0qSJkZpb9cPXV6KH/48vXrxoODg4GI0bNzbi4+PN9aZOnWpIMr755psUYzIMw/jwww8NScbFixfNZZlxrlJzrIaR/LXxzjvvGC4uLsadO3cs2pRkTJgwwVwWGxtrVKpUyfD19TXi4uIMwzCMe/fuGbGxsRbtXbt2zfDz8zO6detmLkv8fmnfvr2RI0cOIyIiwryuQYMGRocOHQxJxrhx48zl3bt3N/Lly2dcvnzZov127doZnp6eyR7Lwz/nHpTa63337t2GJOOtt96yqDdw4EBDkrFq1Spzmb+/v9G5c2eLeh06dDBcXFySjQHIbhiqBzxFoqOjJUnu7u42bRcYGKiqVaual5977jk1b95cK1euNA9ZcXZ2Nq+/e/eurly5omLFisnLy0s7d+5M0ua1a9d0+fJlHT9+XJMmTZK9vb3q1auXpN7ly5ctPjdu3LBYv3v3bh05ckQdOnTQlStXzPViYmLUoEEDrV27VgkJCRbbXL161aLNqKgoq+dg6NChqlKlilq3bm1R/tdff+n69etq3769RZv29vaqUaNGkiE4j9K4cWPlyp
VLCxYskGEYWrBggdq3b59sXZPJpJUrV+r//u//5O3trR9++EG9e/eWv7+/2rZtm+wzTtOmTdNff/1l8fnjjz9SFVv
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# 1. Настройка параметров для старых значений\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 5, 10] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью Grid Search для старых параметров\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=old_param_grid,\n",
" scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"# 2. Результаты подбора для старых параметров\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"# 3. Настройка параметров для новых значений\n",
"new_param_grid = {\n",
" 'n_estimators': [200],\n",
" 'max_depth': [10],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью Grid Search для новых параметров\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"# 4. Результаты подбора для новых параметров\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"# 5. Обучение модели с лучшими параметрами для новых значений\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке\n",
"y_pred = model_best.predict(X_test)\n",
"\n",
"# Оценка производительности модели\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"# Вывод результатов\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nН о вые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# Визуализация ошибок\n",
"plt.figure(figsize=(10, 5))\n",
"plt.bar(['Старые параметры', 'Новые параметры'], [old_best_mse, new_best_mse], color=['blue', 'orange'])\n",
"plt.xlabel('Подбор параметров')\n",
"plt.ylabel('Среднеквадратическая ошибка (MSE)')\n",
"plt.title('Сравнение MSE для старых и новых параметров')\n",
"plt.show()\n",
"# надеюсь, все..."
]
},
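  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the two searches in the previous cell use different numbers of folds (`cv=5` and `cv=2`) and unseeded forests, part of the gap between the two MSE values may come from the evaluation setup rather than from the parameters. A hedged sketch of a like-for-like comparison, using synthetic data and a fixed `KFold`, might look as follows. The data, the reduced `n_estimators`, and the `cv_mse` helper are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "from sklearn.model_selection import KFold, cross_val_score\n",
    "\n",
    "# Synthetic regression data (an assumption, not the real dataset)\n",
    "rng = np.random.default_rng(42)\n",
    "X = rng.normal(size=(200, 3))\n",
    "y = X[:, 0] * 2.0 + rng.normal(scale=0.5, size=200)\n",
    "\n",
    "# Same folds and same seed for both candidates\n",
    "folds = KFold(n_splits=5, shuffle=True, random_state=42)\n",
    "\n",
    "def cv_mse(params):\n",
    "    # Hypothetical helper: mean cross-validated MSE for one parameter set\n",
    "    model = RandomForestRegressor(random_state=42, **params)\n",
    "    scores = cross_val_score(model, X, y, cv=folds,\n",
    "                             scoring='neg_mean_squared_error')\n",
    "    return -scores.mean()\n",
    "\n",
    "old_params = {'n_estimators': 50, 'max_depth': 10, 'min_samples_split': 5}\n",
    "new_params = {'n_estimators': 50, 'max_depth': 10, 'min_samples_split': 10}\n",
    "\n",
    "old_mse = cv_mse(old_params)\n",
    "new_mse = cv_mse(new_params)\n",
    "print(old_mse, new_mse)\n",
    "```\n",
    "\n",
    "Fixing the folds and the random seed means that any remaining difference between `old_mse` and `new_mse` comes from the parameters themselves, which makes a small gap easier to interpret."
   ]
  },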
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сравнив результаты с использованием старых и новых параметров, наблюдается, что новые параметры модели позволили добиться меньшей среднеквадратической ошибки, что указывает на более эффективное предсказание по сравнению с о старыми настройками. Значение RMSE на тестовых данных также подтверждает улучшение качества модели, так как оно стало меньше и указывает на более точные прогнозы по сравнению с предыдущими настройками."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "aimenv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}