4370 lines
513 KiB
Plaintext
Raw Normal View History

2024-11-15 22:35:48 +04:00
{
"cells": [
{
"cell_type": "code",
2024-11-29 00:30:06 +04:00
"execution_count": null,
2024-11-15 22:35:48 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Rank ', 'Name', 'Networth', 'Age', 'Country', 'Source', 'Industry'], dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n",
"print(df.columns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Определим бизнес-цели:\n",
2024-11-29 00:30:06 +04:00
"## 1. Прогнозирование возрастной категории миллиардера (классификация)\n",
"## 2. Прогнозирование состояния миллиардера (регрессия)"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-15 22:37:33 +04:00
"# Подготовим данные: проверим пропущенные значения и категоризируем колонку Age"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-15 22:37:33 +04:00
"Rank 0\n",
"Name 0\n",
"Networth 0\n",
"Age 0\n",
"Country 0\n",
"Source 0\n",
"Industry 0\n",
"dtype: int64\n",
"\n",
"Rank False\n",
"Name False\n",
"Networth False\n",
"Age False\n",
"Country False\n",
"Source False\n",
"Industry False\n",
"dtype: bool\n",
"\n"
2024-11-15 22:35:48 +04:00
]
}
],
"source": [
2024-11-15 22:37:33 +04:00
"print(df.isnull().sum())\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
"print()\n",
"\n",
"# Есть ли пустые значения признаков\n",
"print(df.isnull().any())\n",
"\n",
"print()\n",
"\n",
"# Процент пустых значений признаков\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "code",
2024-11-15 22:37:33 +04:00
"execution_count": 2,
2024-11-15 22:35:48 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-15 22:37:33 +04:00
" Rank Name Networth Country \\\n",
"0 1 Elon Musk 219.0 United States \n",
"1 2 Jeff Bezos 171.0 United States \n",
"2 3 Bernard Arnault & family 158.0 France \n",
"3 4 Bill Gates 129.0 United States \n",
"4 5 Warren Buffett 118.0 United States \n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
" Source Industry Age_category \n",
"0 Tesla, SpaceX Automotive 50-60 \n",
"1 Amazon Technology 50-60 \n",
"2 LVMH Fashion & Retail 70-80 \n",
"3 Microsoft Technology 60-70 \n",
"4 Berkshire Hathaway Finance & Investments 80+ \n"
2024-11-15 22:35:48 +04:00
]
}
],
"source": [
"\n",
"\n",
2024-11-15 22:37:33 +04:00
"bins = [0, 30, 40, 50, 60, 70, 80, 101] # границы для возрастных категорий\n",
2024-11-15 22:35:48 +04:00
"labels = ['Under 30', '30-40', '40-50', '50-60', '60-70', '70-80', '80+'] # метки для категорий\n",
"\n",
2024-11-15 22:37:33 +04:00
"df[\"Age_category\"] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)\n",
"# Удаляем исходную колонку 'Age': её заменяет 'Age_category'\n",
"df.drop(columns=['Age'], inplace=True)\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
"# Просмотр результата\n",
"print(df.head())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rank</th>\n",
" <th>Name</th>\n",
" <th>Networth</th>\n",
" <th>Country</th>\n",
" <th>Source</th>\n",
" <th>Industry</th>\n",
" <th>Age_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1909</th>\n",
" <td>1818</td>\n",
" <td>Tran Ba Duong &amp; family</td>\n",
" <td>1.6</td>\n",
" <td>Vietnam</td>\n",
" <td>automotive</td>\n",
" <td>Automotive</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2099</th>\n",
" <td>2076</td>\n",
" <td>Mark Dixon</td>\n",
" <td>1.4</td>\n",
" <td>United Kingdom</td>\n",
" <td>office real estate</td>\n",
" <td>Real Estate</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1392</th>\n",
" <td>1341</td>\n",
" <td>Yingzhuo Xu</td>\n",
" <td>2.3</td>\n",
" <td>China</td>\n",
" <td>agribusiness</td>\n",
" <td>Food &amp; Beverage</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>627</th>\n",
" <td>622</td>\n",
" <td>Bruce Flatt</td>\n",
" <td>4.6</td>\n",
" <td>Canada</td>\n",
" <td>money management</td>\n",
" <td>Finance &amp; Investments</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>527</th>\n",
" <td>523</td>\n",
" <td>Li Liangbin</td>\n",
" <td>5.2</td>\n",
" <td>China</td>\n",
" <td>lithium</td>\n",
" <td>Manufacturing</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>85</td>\n",
" <td>Theo Albrecht, Jr. &amp; family</td>\n",
" <td>18.7</td>\n",
" <td>Germany</td>\n",
" <td>Aldi, Trader Joe's</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>633</th>\n",
" <td>622</td>\n",
" <td>Tony Tamer</td>\n",
" <td>4.6</td>\n",
" <td>United States</td>\n",
" <td>private equity</td>\n",
" <td>Finance &amp; Investments</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>922</th>\n",
" <td>913</td>\n",
" <td>Bob Gaglardi</td>\n",
" <td>3.3</td>\n",
" <td>Canada</td>\n",
" <td>hotels</td>\n",
" <td>Real Estate</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2178</th>\n",
" <td>2076</td>\n",
" <td>Eugene Wu</td>\n",
" <td>1.4</td>\n",
" <td>Taiwan</td>\n",
" <td>finance</td>\n",
" <td>Finance &amp; Investments</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>415</th>\n",
" <td>411</td>\n",
" <td>Leonard Stern</td>\n",
" <td>6.2</td>\n",
" <td>United States</td>\n",
" <td>real estate</td>\n",
" <td>Real Estate</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2080 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" Rank Name Networth Country \\\n",
"1909 1818 Tran Ba Duong & family 1.6 Vietnam \n",
"2099 2076 Mark Dixon 1.4 United Kingdom \n",
"1392 1341 Yingzhuo Xu 2.3 China \n",
"627 622 Bruce Flatt 4.6 Canada \n",
"527 523 Li Liangbin 5.2 China \n",
"... ... ... ... ... \n",
"84 85 Theo Albrecht, Jr. & family 18.7 Germany \n",
"633 622 Tony Tamer 4.6 United States \n",
"922 913 Bob Gaglardi 3.3 Canada \n",
"2178 2076 Eugene Wu 1.4 Taiwan \n",
"415 411 Leonard Stern 6.2 United States \n",
"\n",
" Source Industry Age_category \n",
"1909 automotive Automotive 60-70 \n",
"2099 office real estate Real Estate 60-70 \n",
"1392 agribusiness Food & Beverage 50-60 \n",
"627 money management Finance & Investments 50-60 \n",
"527 lithium Manufacturing 50-60 \n",
"... ... ... ... \n",
"84 Aldi, Trader Joe's Fashion & Retail 70-80 \n",
"633 private equity Finance & Investments 60-70 \n",
"922 hotels Real Estate 80+ \n",
"2178 finance Finance & Investments 70-80 \n",
"415 real estate Real Estate 80+ \n",
"\n",
"[2080 rows x 7 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1909</th>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2099</th>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1392</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>627</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>527</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>633</th>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>922</th>\n",
" <td>80+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2178</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>415</th>\n",
" <td>80+</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2080 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Age_category\n",
"1909 60-70\n",
"2099 60-70\n",
"1392 50-60\n",
"627 50-60\n",
"527 50-60\n",
"... ...\n",
"84 70-80\n",
"633 60-70\n",
"922 80+\n",
"2178 70-80\n",
"415 80+\n",
"\n",
"[2080 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rank</th>\n",
" <th>Name</th>\n",
" <th>Networth</th>\n",
" <th>Country</th>\n",
" <th>Source</th>\n",
" <th>Industry</th>\n",
" <th>Age_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2075</th>\n",
" <td>2076</td>\n",
" <td>Radhe Shyam Agarwal</td>\n",
" <td>1.4</td>\n",
" <td>India</td>\n",
" <td>consumer goods</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1529</th>\n",
" <td>1513</td>\n",
" <td>Robert Duggan</td>\n",
" <td>2.0</td>\n",
" <td>United States</td>\n",
" <td>pharmaceuticals</td>\n",
" <td>Healthcare</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1803</th>\n",
" <td>1729</td>\n",
" <td>Yao Kuizhang</td>\n",
" <td>1.7</td>\n",
" <td>China</td>\n",
" <td>beverages</td>\n",
" <td>Food &amp; Beverage</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>425</th>\n",
" <td>424</td>\n",
" <td>Alexei Kuzmichev</td>\n",
" <td>6.0</td>\n",
" <td>Russia</td>\n",
" <td>oil, banking, telecom</td>\n",
" <td>Energy</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2597</th>\n",
" <td>2578</td>\n",
" <td>Ramesh Genomal</td>\n",
" <td>1.0</td>\n",
" <td>Philippines</td>\n",
" <td>apparel</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>935</th>\n",
" <td>913</td>\n",
" <td>Alfred Oetker</td>\n",
" <td>3.3</td>\n",
" <td>Germany</td>\n",
" <td>consumer goods</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1541</th>\n",
" <td>1513</td>\n",
" <td>Thomas Lee</td>\n",
" <td>2.0</td>\n",
" <td>United States</td>\n",
" <td>private equity</td>\n",
" <td>Finance &amp; Investments</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1646</th>\n",
" <td>1645</td>\n",
" <td>Roberto Angelini Rossi</td>\n",
" <td>1.8</td>\n",
" <td>Chile</td>\n",
" <td>forestry, mining</td>\n",
" <td>diversified</td>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>376</th>\n",
" <td>375</td>\n",
" <td>Patrick Drahi</td>\n",
" <td>6.6</td>\n",
" <td>France</td>\n",
" <td>telecom</td>\n",
" <td>Telecom</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1894</th>\n",
" <td>1818</td>\n",
" <td>Gerald Schwartz</td>\n",
" <td>1.6</td>\n",
" <td>Canada</td>\n",
" <td>finance</td>\n",
" <td>Finance &amp; Investments</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>520 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" Rank Name Networth Country \\\n",
"2075 2076 Radhe Shyam Agarwal 1.4 India \n",
"1529 1513 Robert Duggan 2.0 United States \n",
"1803 1729 Yao Kuizhang 1.7 China \n",
"425 424 Alexei Kuzmichev 6.0 Russia \n",
"2597 2578 Ramesh Genomal 1.0 Philippines \n",
"... ... ... ... ... \n",
"935 913 Alfred Oetker 3.3 Germany \n",
"1541 1513 Thomas Lee 2.0 United States \n",
"1646 1645 Roberto Angelini Rossi 1.8 Chile \n",
"376 375 Patrick Drahi 6.6 France \n",
"1894 1818 Gerald Schwartz 1.6 Canada \n",
"\n",
" Source Industry Age_category \n",
"2075 consumer goods Fashion & Retail 70-80 \n",
"1529 pharmaceuticals Healthcare 70-80 \n",
"1803 beverages Food & Beverage 50-60 \n",
"425 oil, banking, telecom Energy 50-60 \n",
"2597 apparel Fashion & Retail 70-80 \n",
"... ... ... ... \n",
"935 consumer goods Fashion & Retail 50-60 \n",
"1541 private equity Finance & Investments 70-80 \n",
"1646 forestry, mining diversified 70-80 \n",
"376 telecom Telecom 50-60 \n",
"1894 finance Finance & Investments 80+ \n",
"\n",
"[520 rows x 7 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2075</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1529</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1803</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>425</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2597</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>935</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1541</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1646</th>\n",
" <td>70-80</td>\n",
" </tr>\n",
" <tr>\n",
" <th>376</th>\n",
" <td>50-60</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1894</th>\n",
" <td>80+</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>520 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Age_category\n",
"2075 70-80\n",
"1529 70-80\n",
"1803 50-60\n",
"425 50-60\n",
"2597 70-80\n",
"... ...\n",
"935 50-60\n",
"1541 70-80\n",
"1646 70-80\n",
"376 50-60\n",
"1894 80+\n",
"\n",
"[520 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils import split_stratified_into_train_val_test\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Age_category\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-15 22:37:33 +04:00
"# Формирование конвейера для классификации данных\n",
"## preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"## preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"## features_preprocessing -- трансформер для предобработки признаков\n",
"## features_engineering -- трансформер для конструирования признаков\n",
"## drop_columns -- трансформер для удаления колонок\n",
"## pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "code",
2024-11-15 23:33:34 +04:00
"execution_count": 37,
2024-11-15 22:37:33 +04:00
"metadata": {},
2024-11-15 23:33:34 +04:00
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prepocessing_num__Networth</th>\n",
" <th>prepocessing_cat__Country_Argentina</th>\n",
" <th>prepocessing_cat__Country_Australia</th>\n",
" <th>prepocessing_cat__Country_Austria</th>\n",
" <th>prepocessing_cat__Country_Barbados</th>\n",
" <th>prepocessing_cat__Country_Belgium</th>\n",
" <th>prepocessing_cat__Country_Belize</th>\n",
" <th>prepocessing_cat__Country_Brazil</th>\n",
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
" <th>prepocessing_cat__Country_Canada</th>\n",
" <th>...</th>\n",
" <th>prepocessing_cat__Industry_Logistics</th>\n",
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
" <th>prepocessing_cat__Industry_Media &amp; Entertainment</th>\n",
" <th>prepocessing_cat__Industry_Metals &amp; Mining</th>\n",
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
" <th>prepocessing_cat__Industry_Service</th>\n",
" <th>prepocessing_cat__Industry_Sports</th>\n",
" <th>prepocessing_cat__Industry_Technology</th>\n",
" <th>prepocessing_cat__Industry_Telecom</th>\n",
" <th>prepocessing_cat__Industry_diversified</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1909</th>\n",
" <td>-0.309917</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2099</th>\n",
" <td>-0.329245</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1392</th>\n",
" <td>-0.242268</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>627</th>\n",
" <td>-0.019995</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>527</th>\n",
" <td>0.037990</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>1.342637</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>633</th>\n",
" <td>-0.019995</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>922</th>\n",
" <td>-0.145628</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2178</th>\n",
" <td>-0.329245</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>415</th>\n",
" <td>0.134630</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2080 rows × 860 columns</p>\n",
"</div>"
],
"text/plain": [
" prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n",
"1909 -0.309917 0.0 \n",
"2099 -0.329245 0.0 \n",
"1392 -0.242268 0.0 \n",
"627 -0.019995 0.0 \n",
"527 0.037990 0.0 \n",
"... ... ... \n",
"84 1.342637 0.0 \n",
"633 -0.019995 0.0 \n",
"922 -0.145628 0.0 \n",
"2178 -0.329245 0.0 \n",
"415 0.134630 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 1.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 1.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"1909 ... 0.0 \n",
"2099 ... 0.0 \n",
"1392 ... 0.0 \n",
"627 ... 0.0 \n",
"527 ... 0.0 \n",
"... ... ... \n",
"84 ... 0.0 \n",
"633 ... 0.0 \n",
"922 ... 0.0 \n",
"2178 ... 0.0 \n",
"415 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 1.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"1909 0.0 \n",
"2099 1.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 1.0 \n",
"2178 0.0 \n",
"415 1.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
"[2080 rows x 860 columns]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
2024-11-15 22:37:33 +04:00
"source": [
"\n",
"from sklearn.compose import ColumnTransformer\n",
2024-11-15 23:33:34 +04:00
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
2024-11-15 22:37:33 +04:00
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
2024-11-15 23:33:34 +04:00
"import pandas as pd\n",
2024-11-15 22:37:33 +04:00
"\n",
2024-11-15 23:33:34 +04:00
"# Колонки, которые не используются как признаки: целевая переменная, ранг и имя\n",
"columns_to_drop = [\"Age_category\", \"Rank \", \"Name\"]\n",
2024-11-15 22:37:33 +04:00
"\n",
"num_columns = [\n",
" column\n",
2024-11-15 23:33:34 +04:00
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
2024-11-15 22:37:33 +04:00
"]\n",
"cat_columns = [\n",
" column\n",
2024-11-15 23:33:34 +04:00
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
2024-11-15 22:37:33 +04:00
"]\n",
"\n",
2024-11-15 23:33:34 +04:00
"# Предобработка числовых данных\n",
2024-11-15 22:37:33 +04:00
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
2024-11-15 23:33:34 +04:00
"# Предобработка категориальных данных\n",
2024-11-15 22:37:33 +04:00
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
2024-11-15 23:33:34 +04:00
"# Общая предобработка признаков\n",
2024-11-15 22:37:33 +04:00
"features_preprocessing = ColumnTransformer(\n",
2024-11-15 23:33:34 +04:00
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
2024-11-15 22:37:33 +04:00
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
2024-11-15 23:33:34 +04:00
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
2024-11-15 22:37:33 +04:00
")\n",
"\n",
2024-11-15 23:33:34 +04:00
"# Итоговый конвейер\n",
2024-11-15 22:37:33 +04:00
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
2024-11-15 23:33:34 +04:00
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Формирование набора моделей для классификации\n",
"## logistic -- логистическая регрессия\n",
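"## ridge -- логистическая регрессия с L2-регуляризацией и балансировкой классов (используется вместо гребневого классификатора RidgeClassifierCV)\n",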
"## decision_tree -- дерево решений\n",
"## knn -- k-ближайших соседей\n",
"## naive_bayes -- наивный Байесовский классификатор\n",
"## gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"## random_forest -- метод случайного леса (набор деревьев решений)\n",
"## mlp -- многослойный персептрон (нейронная сеть)"
2024-11-15 22:37:33 +04:00
]
},
{
"cell_type": "code",
2024-11-15 23:33:34 +04:00
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
2024-11-16 01:14:29 +04:00
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"y_train['Age_category'] = y_train['Age_category'].cat.codes\n",
"y_test['Age_category'] = y_test['Age_category'].cat.codes"
]
},
{
"cell_type": "code",
"execution_count": 44,
2024-11-15 22:35:48 +04:00
"metadata": {},
"outputs": [
2024-11-15 23:33:34 +04:00
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-11-16 01:14:29 +04:00
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n",
2024-11-15 23:33:34 +04:00
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
2024-11-16 01:14:29 +04:00
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
2024-11-15 23:33:34 +04:00
]
},
2024-11-15 22:35:48 +04:00
{
2024-11-16 01:14:29 +04:00
"name": "stdout",
"output_type": "stream",
"text": [
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\neighbors\\_classification.py:238: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return self._fit(X, y)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: decision_tree\n",
"Model: knn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: naive_bayes\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_label.py:114: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: gradient_boosting\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: random_forest\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:1105: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
2024-11-15 22:35:48 +04:00
]
}
],
"source": [
2024-11-15 23:33:34 +04:00
"import numpy as np\n",
"from sklearn import metrics\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 23:33:34 +04:00
"# Fit every candidate model inside the shared preprocessing pipeline and store\n",
"# the fitted pipeline, test predictions and quality metrics back in `class_models`.\n",
"# scikit-learn expects 1-D targets; flattening the single-column frames once here\n",
"# avoids a DataConversionWarning on every fit.\n",
"y_train_flat = np.ravel(y_train)\n",
"y_test_flat = np.ravel(y_test)\n",
"\n",
"for model_name, model_entry in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = model_entry[\"model\"]\n",
"\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
"    model_pipeline = model_pipeline.fit(X_train, y_train_flat)\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)\n",
"    # Map the arg-max column back through classes_ so the label stays correct even\n",
"    # if the fitted classes are not exactly 0..k-1.\n",
"    y_test_predict = model_pipeline.classes_[np.argmax(y_test_probs, axis=1)]\n",
"\n",
"    model_entry[\"pipeline\"] = model_pipeline\n",
"    model_entry[\"probs\"] = y_test_probs\n",
"    model_entry[\"preds\"] = y_test_predict\n",
"\n",
"    # Metrics\n",
"    # zero_division=0 yields the same 0.0 the default does for classes that are\n",
"    # never predicted, but without emitting UndefinedMetricWarning.\n",
"    model_entry[\"Precision_train\"] = metrics.precision_score(\n",
"        y_train_flat, y_train_predict, average=\"macro\", zero_division=0\n",
"    )\n",
"    model_entry[\"Precision_test\"] = metrics.precision_score(\n",
"        y_test_flat, y_test_predict, average=\"macro\", zero_division=0\n",
"    )\n",
"    model_entry[\"Recall_train\"] = metrics.recall_score(\n",
"        y_train_flat, y_train_predict, average=\"macro\"\n",
"    )\n",
"    model_entry[\"Recall_test\"] = metrics.recall_score(\n",
"        y_test_flat, y_test_predict, average=\"macro\"\n",
"    )\n",
"    model_entry[\"Accuracy_train\"] = metrics.accuracy_score(\n",
"        y_train_flat, y_train_predict\n",
"    )\n",
"    model_entry[\"Accuracy_test\"] = metrics.accuracy_score(\n",
"        y_test_flat, y_test_predict\n",
"    )\n",
"    # One-vs-rest AUC computed from the full per-class probability matrix.\n",
"    model_entry[\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
"        y_test_flat, y_test_probs, multi_class=\"ovr\"\n",
"    )\n",
"    model_entry[\"F1_train\"] = metrics.f1_score(\n",
"        y_train_flat, y_train_predict, average=\"macro\"\n",
"    )\n",
"    model_entry[\"F1_test\"] = metrics.f1_score(\n",
"        y_test_flat, y_test_predict, average=\"macro\"\n",
"    )\n",
"    model_entry[\"MCC_test\"] = metrics.matthews_corrcoef(\n",
"        y_test_flat, y_test_predict\n",
"    )\n",
"    model_entry[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
"        y_test_flat, y_test_predict\n",
"    )\n",
"    model_entry[\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
"        y_test_flat, y_test_predict\n",
"    )"
2024-11-15 22:35:48 +04:00
]
2024-11-16 01:14:29 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABL4AAAb5CAYAAABKIMnxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hT1R8G8DdN26Qz3YtuaCkFKlAZZQnIlCkVFPGHbMGCUMWBsmUoiiAKqIggKiI4EGRvWUUoe5VVaKGL7t2mSX5/VAKxpTRp6c1t38/z3EdzV9+EtPnek3POlWg0Gg2IiIiIiIiIiIhqGROhAxARERERERERET0JbPgiIiIiIiIiIqJaiQ1fRERERERERERUK7Hhi4iIiIiIiIiIaiU2fBERERERERERUa3Ehi8iIiIiIiIiIqqV2PBFRERERERERES1Ehu+iIiIiIiIiIioVmLDFxERERERERER1Ups+CISqTVr1kAikeDWrVtP5Py3bt2CRCLBmjVrquV8Bw4cgEQiwYEDB6rlfERERERUt8yaNQsSiaRS+0okEsyaNevJBiIiUWDDFxFVq+XLl1dbYxkRERERERFRVZgKHYCIjJOPjw8KCgpgZmam13HLly+Hk5MThg8frrO+Y8eOKCgogLm5eTWmJCIiIqK6Ytq0aXjvvfeEjkFEIsOGLyIql0QigVwur7bzmZiYVOv5iIiIiKjuyMvLg5WVFUxNeQlLRPrhUEeiWmT58uVo3LgxZDIZPDw8EBERgczMzDL7LVu2DP7+/rCwsECrVq1w6NAhdOrUCZ06ddLuU94cX0lJSRgxYgQ8PT0hk8ng7u6O/v37a+cZ8/X1xcWLF3Hw4EFIJBJIJBLtOR81x9fx48fx3HPPwd7eHlZWVggJCcHnn39evS8MEREREYnG/bm8Ll26hJdffhn29vZo3759uXN8FRUVITIyEs7OzrCxsUG/fv1w586dcs974MABPP3005DL5ahfvz6+/vrrR84b9uOPPyI0NBQWFhZwcHDASy+9hPj4+CfyfInoyWJzOVEtMWvWLMyePRtdu3bF+PHjERMTgxUrVuDEiRM4cuSIdsjiihUrMGHCBHTo0AGRkZG4desWBgwYAHt7e3h6elb4M8LDw3Hx4kVMnDgRvr6+SElJwe7duxEXFwdfX18sWbIEEydOhLW1NT744AMAgKur6yPPt3v3bvTp0wfu7u6YNGkS3NzccPnyZfz111+YNGlS9b04RERERCQ6gwYNQkBAAObPnw+NRoOUlJQy+4wePRo//vgjXn75ZbRt2xb79u1D7969y+x3+vRp9OzZE+7u7pg9ezZUKhXmzJkDZ2fnMvvOmzcP06dPx+DBgzF69Gjcu3cPX3zxBTp27IjTp0/Dzs7uSTxdInpC2PBFVAvcu3cPCxYsQPfu3bF9+3aYmJR25gwKCsKECRPw448/YsSIESguLsb06dPRsmVL7Nu3T9tVPCQkBMOHD6+w4SszMxNHjx7FJ598gilTpmjXT506Vfv/AwYMwLRp0+Dk5IRXXnmlwswqlQqvvfYa3N3dcebMGZ0CQqPRGPIyEBEREVEt8tRTT2HdunXax/+9S+PZs2fx448/4vXXX8eyZcsAABERERg6dCjOnTuns+/MmTMhlUpx5MgReHh4AAAGDx6MRo0a6ex3+/ZtzJw5E3PnzsX777+vXT9w4EA0b94cy5cv11lPRMaPQx2JaoE9e/aguLgYkydP1jZ6AcCYMWNga2uLrVu3AgBOnjyJtLQ0jBkzRmd+hKFDh8Le3r7Cn2FhYQFzc3McOHAAGRkZVc58+vRpxMbGYvLkyWW+NavsbaqJiIiIqPYaN25chdu3bdsGAHjjjTd01k+ePFnnsUqlwp49ezBgwABtoxcANGjQAL169dLZ9/fff4darcbgwYORmpqqXdzc3BAQEID9+/dX4RkRkRDY44uoFrh9+zYAoGHDhjrrzc3N4e/vr91+/78NGj
TQ2c/U1BS+vr4V/gyZTIaPP/4Yb731FlxdXdGmTRv06dMHw4YNg5ubm96Zb9y4AQBo0qSJ3scSERERUe3n5+dX4fbbt2/DxMQE9evX11n/35o4JSUFBQUFZWpgoGxdfO3aNWg0GgQEBJT7M/W94zkRCY8NX0RUaZMnT0bfvn2xadMm7Ny5E9OnT8eCBQuwb98+NG/eXOh4RERERFSLWFhY1PjPVKvVkEgk2L59O6RSaZnt1tbWNZ6JiKqGQx2JagEfHx8AQExMjM764uJixMbGarff/+/169d19ispKdHemfFx6tevj7feegu7du3ChQsXUFxcjEWLFmm3V3aY4v1v5i5cuFCp/YmIiIiIHubj4wO1Wq0dSXDff2tiFxcXyOXyMjUwULYurl+/PjQaDfz8/NC1a9cyS5s2bar/iRDRE8WGL6JaoGvXrjA3N8fSpUt1JoZftWoVsrKytHe2efrpp+Ho6IiVK1eipKREu99PP/302Hm78vPzUVhYqLOufv36sLGxQVFRkXadlZUVMjMzH5u5RYsW8PPzw5IlS8rsz8ntiYiIiOhx7s/PtXTpUp31S5Ys0XkslUrRtWtXbNq0CQkJCdr1169fx/bt23X2HThwIKRSKWbPnl2mJtVoNEhLS6vGZ0BENYFDHYlqAWdnZ0ydOhWzZ89Gz5490a9fP8TExGD58uVo2bKl9g6L5ubmmDVrFiZOnIguXbpg8ODBuHXrFtasWYP69etX2Fvr6tWrePbZZzF48GAEBwfD1NQUf/zxB5KTk/HSSy9p9wsNDcWKFSswd+5cNGjQAC4uLujSpUuZ85mYmGDFihXo27cvmjVrhhEjRsDd3R1XrlzBxYsXsXPnzup/oYiIiIio1mjWrBmGDBmC5cuXIysrC23btsXevXvL7dk1a9Ys7Nq1C+3atcP48eOhUqnw5ZdfokmTJjhz5ox2v/r162Pu3LmYOnUqbt26hQEDBsDGxgaxsbH4448/MHbsWJ07nBOR8WPDF1EtMWvWLDg7O+PLL79EZGQkHBwcMHbsWMyfP19nEs4JEyZAo9Fg0aJFmDJlCp566ils3rwZb7zxBuRy+SPP7+XlhSFDhmDv3r344YcfYGpqiqCgIGzYsAHh4eHa/WbMmIHbt29j4cKFyMnJwTPPPFNuwxcA9OjRA/v378fs2bOxaNEiqNVq1K9fH2PGjKm+F4aIiIiIaq3vvvsOzs7O+Omnn7Bp0yZ06dIFW7duhZeXl85+oaGh2L59O6ZMmYLp06fDy8sLc+bMweXLl3HlyhWdfd977z0EBgZi8eLFmD17NoDSWrh79+7o169fjT03IqoeEg3HFBHVeWq1Gs7Ozhg4cCBWrlwpdBwiIiIiohoxYMAAXLx4EdeuXRM6ChE9IZzji6iOKSwsLDNfwdq1a5Geno5OnToJE4qIiIiI6AkrKCjQeXzt2jVs27aNNTBRLcceX0R1zIEDBxAZGYlBgwbB0dERp06dwqpVq9CoUSNER0fD3Nxc6IhERERERNXO3d0dw4cPh7+/P27fvo0VK1agqKgIp0+fRkBAgNDxiOgJ4RxfRHWMr68vvLy8sHTpUqSnp8PBwQHDhg3DRx99xEYvIiIiIqq1evbsiZ9//hlJSUmQyWQICwvD/Pnz2ehFVMuxxxcREREREREREdVKnOOLiIiIiIiIiIhqJTZ8ERERERERERFRrcQ5voyUWq1GQkICbGxsIJFIhI5DRGQUNBoNcnJy4OHhAROTmv/uprCwEMXFxVU6h7m5OeRyeTUlIiIyPqxjiYjKErqOBapey4q1jmXDl5FKSEiAl5eX0DGIiIxSfHw8PD09a/RnFhYWws/HGkkpqiqdx83NDbGxsaIsGoiIKoN1LBHRowlRxwLVU8uKtY5lw5eRsrGxAQC0x3MwhZnAaUgMJGbivSOjRlm1HjSkP7G+X0o0Shwq2aT9G1mTiouLkZSiQmy0D2xtDPuWLjtHDb/Q2yguLhZdwUBEVFmsY0lfYq
1LANaxQhDr+0XIOhaoei0r5jqWDV9G6n63cFOYwVTCgoEeTyLi94lGwpvL1jQxv18ACDp0xsq6dDGEim91IqoDWMe
"text/plain": [
"<Figure size 1700x1700 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# One confusion-matrix panel per model, laid out two per row.\n",
"# Ceiling division gives enough rows for an odd number of models too, and\n",
"# squeeze=False keeps `ax` a 2-D array even when there is a single row.\n",
"n_cols = 2\n",
"n_rows = -(-len(class_models) // n_cols)\n",
"_, ax = plt.subplots(n_rows, n_cols, figsize=(17, 17), sharex=False, sharey=False, squeeze=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Under 30\", \"30-40\", \"40-50\", \"50-60\", \"60-70\", \"70-80\", \"80+\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide any panel left without a model (only happens for an odd model count).\n",
"for unused_ax in ax.flat[len(class_models):]:\n",
"    unused_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_cdcb0_row0_col0 {\n",
" background-color: #50c46a;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row0_col1 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row0_col2 {\n",
" background-color: #32b67a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row0_col3 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row0_col4, #T_cdcb0_row3_col5 {\n",
" background-color: #ca457a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row0_col5, #T_cdcb0_row1_col4, #T_cdcb0_row1_col6, #T_cdcb0_row1_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row0_col6, #T_cdcb0_row7_col6 {\n",
" background-color: #ab2494;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row0_col7 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row1_col0, #T_cdcb0_row1_col1, #T_cdcb0_row4_col2, #T_cdcb0_row5_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row1_col2 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row1_col3, #T_cdcb0_row7_col2 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row1_col5 {\n",
" background-color: #d5536f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row2_col0 {\n",
" background-color: #38b977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row2_col1 {\n",
" background-color: #60ca60;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row2_col2 {\n",
" background-color: #35b779;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row2_col3, #T_cdcb0_row4_col0 {\n",
" background-color: #54c568;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row2_col4, #T_cdcb0_row7_col4 {\n",
" background-color: #a82296;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row2_col5, #T_cdcb0_row4_col4 {\n",
" background-color: #cc4778;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row2_col6 {\n",
" background-color: #aa2395;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row2_col7 {\n",
" background-color: #cb4679;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row3_col0 {\n",
" background-color: #5ec962;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row3_col1 {\n",
" background-color: #31b57b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row3_col2 {\n",
" background-color: #20928c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row3_col3 {\n",
" background-color: #24aa83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row3_col4 {\n",
" background-color: #8405a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row3_col6 {\n",
" background-color: #7801a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row3_col7 {\n",
" background-color: #9e199d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row4_col1, #T_cdcb0_row5_col1 {\n",
" background-color: #70cf57;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row4_col3 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row4_col5 {\n",
" background-color: #b22b8f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row4_col6 {\n",
" background-color: #c33d80;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row4_col7 {\n",
" background-color: #d5546e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row5_col0 {\n",
" background-color: #42be71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row5_col2 {\n",
" background-color: #95d840;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row5_col4 {\n",
" background-color: #ba3388;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row5_col5 {\n",
" background-color: #b02991;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row5_col6 {\n",
" background-color: #bb3488;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row5_col7 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row6_col0, #T_cdcb0_row6_col1, #T_cdcb0_row6_col2, #T_cdcb0_row6_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row6_col4, #T_cdcb0_row6_col6, #T_cdcb0_row6_col7, #T_cdcb0_row7_col5 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row6_col5 {\n",
" background-color: #9410a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row7_col0 {\n",
" background-color: #44bf70;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row7_col1 {\n",
" background-color: #73d056;\n",
" color: #000000;\n",
"}\n",
"#T_cdcb0_row7_col3 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_cdcb0_row7_col7 {\n",
" background-color: #a21d9a;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_cdcb0\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_cdcb0_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_cdcb0_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_cdcb0_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_cdcb0_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_cdcb0_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_cdcb0_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_cdcb0_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_cdcb0_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_cdcb0_row0_col0\" class=\"data row0 col0\" >0.567471</td>\n",
" <td id=\"T_cdcb0_row0_col1\" class=\"data row0 col1\" >0.278074</td>\n",
" <td id=\"T_cdcb0_row0_col2\" class=\"data row0 col2\" >0.444883</td>\n",
" <td id=\"T_cdcb0_row0_col3\" class=\"data row0 col3\" >0.232769</td>\n",
" <td id=\"T_cdcb0_row0_col4\" class=\"data row0 col4\" >0.618269</td>\n",
" <td id=\"T_cdcb0_row0_col5\" class=\"data row0 col5\" >0.353846</td>\n",
" <td id=\"T_cdcb0_row0_col6\" class=\"data row0 col6\" >0.465219</td>\n",
" <td id=\"T_cdcb0_row0_col7\" class=\"data row0 col7\" >0.237759</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_cdcb0_row1_col0\" class=\"data row1 col0\" >0.836061</td>\n",
" <td id=\"T_cdcb0_row1_col1\" class=\"data row1 col1\" >0.287405</td>\n",
" <td id=\"T_cdcb0_row1_col2\" class=\"data row1 col2\" >0.725411</td>\n",
" <td id=\"T_cdcb0_row1_col3\" class=\"data row1 col3\" >0.235795</td>\n",
" <td id=\"T_cdcb0_row1_col4\" class=\"data row1 col4\" >0.689904</td>\n",
" <td id=\"T_cdcb0_row1_col5\" class=\"data row1 col5\" >0.344231</td>\n",
" <td id=\"T_cdcb0_row1_col6\" class=\"data row1 col6\" >0.760847</td>\n",
" <td id=\"T_cdcb0_row1_col7\" class=\"data row1 col7\" >0.240251</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row2\" class=\"row_heading level0 row2\" >knn</th>\n",
" <td id=\"T_cdcb0_row2_col0\" class=\"data row2 col0\" >0.477783</td>\n",
" <td id=\"T_cdcb0_row2_col1\" class=\"data row2 col1\" >0.221788</td>\n",
" <td id=\"T_cdcb0_row2_col2\" class=\"data row2 col2\" >0.460090</td>\n",
" <td id=\"T_cdcb0_row2_col3\" class=\"data row2 col3\" >0.214239</td>\n",
" <td id=\"T_cdcb0_row2_col4\" class=\"data row2 col4\" >0.497115</td>\n",
" <td id=\"T_cdcb0_row2_col5\" class=\"data row2 col5\" >0.328846</td>\n",
" <td id=\"T_cdcb0_row2_col6\" class=\"data row2 col6\" >0.456182</td>\n",
" <td id=\"T_cdcb0_row2_col7\" class=\"data row2 col7\" >0.211556</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row3\" class=\"row_heading level0 row3\" >decision_tree</th>\n",
" <td id=\"T_cdcb0_row3_col0\" class=\"data row3 col0\" >0.618281</td>\n",
" <td id=\"T_cdcb0_row3_col1\" class=\"data row3 col1\" >0.163157</td>\n",
" <td id=\"T_cdcb0_row3_col2\" class=\"data row3 col2\" >0.244223</td>\n",
" <td id=\"T_cdcb0_row3_col3\" class=\"data row3 col3\" >0.184231</td>\n",
" <td id=\"T_cdcb0_row3_col4\" class=\"data row3 col4\" >0.387981</td>\n",
" <td id=\"T_cdcb0_row3_col5\" class=\"data row3 col5\" >0.325000</td>\n",
" <td id=\"T_cdcb0_row3_col6\" class=\"data row3 col6\" >0.227570</td>\n",
" <td id=\"T_cdcb0_row3_col7\" class=\"data row3 col7\" >0.146479</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_cdcb0_row4_col0\" class=\"data row4 col0\" >0.581578</td>\n",
" <td id=\"T_cdcb0_row4_col1\" class=\"data row4 col1\" >0.236539</td>\n",
" <td id=\"T_cdcb0_row4_col2\" class=\"data row4 col2\" >0.735419</td>\n",
" <td id=\"T_cdcb0_row4_col3\" class=\"data row4 col3\" >0.246556</td>\n",
" <td id=\"T_cdcb0_row4_col4\" class=\"data row4 col4\" >0.627404</td>\n",
" <td id=\"T_cdcb0_row4_col5\" class=\"data row4 col5\" >0.288462</td>\n",
" <td id=\"T_cdcb0_row4_col6\" class=\"data row4 col6\" >0.599765</td>\n",
" <td id=\"T_cdcb0_row4_col7\" class=\"data row4 col7\" >0.231541</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_cdcb0_row5_col0\" class=\"data row5 col0\" >0.518033</td>\n",
" <td id=\"T_cdcb0_row5_col1\" class=\"data row5 col1\" >0.238462</td>\n",
" <td id=\"T_cdcb0_row5_col2\" class=\"data row5 col2\" >0.695673</td>\n",
" <td id=\"T_cdcb0_row5_col3\" class=\"data row5 col3\" >0.247678</td>\n",
" <td id=\"T_cdcb0_row5_col4\" class=\"data row5 col4\" >0.556250</td>\n",
" <td id=\"T_cdcb0_row5_col5\" class=\"data row5 col5\" >0.284615</td>\n",
" <td id=\"T_cdcb0_row5_col6\" class=\"data row5 col6\" >0.553233</td>\n",
" <td id=\"T_cdcb0_row5_col7\" class=\"data row5 col7\" >0.226955</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row6\" class=\"row_heading level0 row6\" >mlp</th>\n",
" <td id=\"T_cdcb0_row6_col0\" class=\"data row6 col0\" >0.035714</td>\n",
" <td id=\"T_cdcb0_row6_col1\" class=\"data row6 col1\" >0.035714</td>\n",
" <td id=\"T_cdcb0_row6_col2\" class=\"data row6 col2\" >0.142857</td>\n",
" <td id=\"T_cdcb0_row6_col3\" class=\"data row6 col3\" >0.142857</td>\n",
" <td id=\"T_cdcb0_row6_col4\" class=\"data row6 col4\" >0.250000</td>\n",
" <td id=\"T_cdcb0_row6_col5\" class=\"data row6 col5\" >0.250000</td>\n",
" <td id=\"T_cdcb0_row6_col6\" class=\"data row6 col6\" >0.057143</td>\n",
" <td id=\"T_cdcb0_row6_col7\" class=\"data row6 col7\" >0.057143</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_cdcb0_level0_row7\" class=\"row_heading level0 row7\" >naive_bayes</th>\n",
" <td id=\"T_cdcb0_row7_col0\" class=\"data row7 col0\" >0.524162</td>\n",
" <td id=\"T_cdcb0_row7_col1\" class=\"data row7 col1\" >0.239277</td>\n",
" <td id=\"T_cdcb0_row7_col2\" class=\"data row7 col2\" >0.664585</td>\n",
" <td id=\"T_cdcb0_row7_col3\" class=\"data row7 col3\" >0.202308</td>\n",
" <td id=\"T_cdcb0_row7_col4\" class=\"data row7 col4\" >0.494231</td>\n",
" <td id=\"T_cdcb0_row7_col5\" class=\"data row7 col5\" >0.176923</td>\n",
" <td id=\"T_cdcb0_row7_col6\" class=\"data row7 col6\" >0.465319</td>\n",
" <td id=\"T_cdcb0_row7_col7\" class=\"data row7 col7\" >0.151713</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dc070ca660>"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train/test scores for every model, best test accuracy first.\n",
"# The two column groups are coloured with different colormaps.\n",
"precision_recall_cols = [\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"]\n",
"accuracy_f1_cols = [\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"]\n",
"\n",
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
"    precision_recall_cols + accuracy_f1_cols\n",
"]\n",
"\n",
"metrics_styler = class_metrics.sort_values(by=\"Accuracy_test\", ascending=False).style\n",
"metrics_styler = metrics_styler.background_gradient(\n",
"    cmap=\"plasma\", low=0.3, high=1, subset=accuracy_f1_cols\n",
")\n",
"metrics_styler = metrics_styler.background_gradient(\n",
"    cmap=\"viridis\", low=1, high=0.3, subset=precision_recall_cols\n",
")\n",
"metrics_styler"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Значения далеки от идеала, датасет так себе..."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting Jinja2\n",
" Downloading jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)\n",
"Collecting MarkupSafe>=2.0 (from Jinja2)\n",
" Downloading MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl.metadata (4.1 kB)\n",
"Downloading jinja2-3.1.4-py3-none-any.whl (133 kB)\n",
"Downloading MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl (15 kB)\n",
"Installing collected packages: MarkupSafe, Jinja2\n",
"Successfully installed Jinja2-3.1.4 MarkupSafe-3.0.2\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"[notice] A new release of pip is available: 24.2 -> 24.3.1\n",
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
]
}
],
"source": [
"# Jinja2 is needed by pandas' DataFrame.style. %pip installs into the kernel's own\n",
"# environment, and the version is pinned so the notebook is reproducible.\n",
"%pip install -q Jinja2==3.1.4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_d129d_row0_col0 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row0_col1, #T_d129d_row1_col0 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row0_col2, #T_d129d_row1_col3, #T_d129d_row1_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row0_col3 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row0_col4 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row1_col1 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row1_col2 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row2_col0 {\n",
" background-color: #44bf70;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row2_col1 {\n",
" background-color: #93d741;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row2_col2 {\n",
" background-color: #ae2892;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row2_col3, #T_d129d_row3_col4, #T_d129d_row4_col4 {\n",
" background-color: #c7427c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row2_col4 {\n",
" background-color: #c9447a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row3_col0 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row3_col1 {\n",
" background-color: #7cd250;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row3_col2, #T_d129d_row5_col3, #T_d129d_row6_col4 {\n",
" background-color: #a01a9c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row3_col3 {\n",
" background-color: #c6417d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row4_col0 {\n",
" background-color: #48c16e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row4_col1 {\n",
" background-color: #9bd93c;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row4_col2 {\n",
" background-color: #9c179e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row4_col3 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row5_col0 {\n",
" background-color: #7ad151;\n",
" color: #000000;\n",
"}\n",
"#T_d129d_row5_col1 {\n",
" background-color: #2eb37c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row5_col2 {\n",
" background-color: #7e03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row5_col4 {\n",
" background-color: #b02991;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row6_col0, #T_d129d_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row6_col1 {\n",
" background-color: #32b67a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row6_col2 {\n",
" background-color: #5b01a5;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row6_col3 {\n",
" background-color: #9814a0;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row7_col0 {\n",
" background-color: #25ac82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_d129d_row7_col2, #T_d129d_row7_col3, #T_d129d_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_d129d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_d129d_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_d129d_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_d129d_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_d129d_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_d129d_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row0\" class=\"row_heading level0 row0\" >gradient_boosting</th>\n",
" <td id=\"T_d129d_row0_col0\" class=\"data row0 col0\" >0.344231</td>\n",
" <td id=\"T_d129d_row0_col1\" class=\"data row0 col1\" >0.240251</td>\n",
" <td id=\"T_d129d_row0_col2\" class=\"data row0 col2\" >0.649816</td>\n",
" <td id=\"T_d129d_row0_col3\" class=\"data row0 col3\" >0.131708</td>\n",
" <td id=\"T_d129d_row0_col4\" class=\"data row0 col4\" >0.138628</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row1\" class=\"row_heading level0 row1\" >logistic</th>\n",
" <td id=\"T_d129d_row1_col0\" class=\"data row1 col0\" >0.353846</td>\n",
" <td id=\"T_d129d_row1_col1\" class=\"data row1 col1\" >0.237759</td>\n",
" <td id=\"T_d129d_row1_col2\" class=\"data row1 col2\" >0.615478</td>\n",
" <td id=\"T_d129d_row1_col3\" class=\"data row1 col3\" >0.160238</td>\n",
" <td id=\"T_d129d_row1_col4\" class=\"data row1 col4\" >0.161282</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row2\" class=\"row_heading level0 row2\" >ridge</th>\n",
" <td id=\"T_d129d_row2_col0\" class=\"data row2 col0\" >0.284615</td>\n",
" <td id=\"T_d129d_row2_col1\" class=\"data row2 col1\" >0.226955</td>\n",
" <td id=\"T_d129d_row2_col2\" class=\"data row2 col2\" >0.612260</td>\n",
" <td id=\"T_d129d_row2_col3\" class=\"data row2 col3\" >0.129672</td>\n",
" <td id=\"T_d129d_row2_col4\" class=\"data row2 col4\" >0.133551</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_d129d_row3_col0\" class=\"data row3 col0\" >0.328846</td>\n",
" <td id=\"T_d129d_row3_col1\" class=\"data row3 col1\" >0.211556</td>\n",
" <td id=\"T_d129d_row3_col2\" class=\"data row3 col2\" >0.602333</td>\n",
" <td id=\"T_d129d_row3_col3\" class=\"data row3 col3\" >0.128794</td>\n",
" <td id=\"T_d129d_row3_col4\" class=\"data row3 col4\" >0.130205</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_d129d_row4_col0\" class=\"data row4 col0\" >0.288462</td>\n",
" <td id=\"T_d129d_row4_col1\" class=\"data row4 col1\" >0.231541</td>\n",
" <td id=\"T_d129d_row4_col2\" class=\"data row4 col2\" >0.599541</td>\n",
" <td id=\"T_d129d_row4_col3\" class=\"data row4 col3\" >0.126828</td>\n",
" <td id=\"T_d129d_row4_col4\" class=\"data row4 col4\" >0.129917</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row5\" class=\"row_heading level0 row5\" >decision_tree</th>\n",
" <td id=\"T_d129d_row5_col0\" class=\"data row5 col0\" >0.325000</td>\n",
" <td id=\"T_d129d_row5_col1\" class=\"data row5 col1\" >0.146479</td>\n",
" <td id=\"T_d129d_row5_col2\" class=\"data row5 col2\" >0.581718</td>\n",
" <td id=\"T_d129d_row5_col3\" class=\"data row5 col3\" >0.078698</td>\n",
" <td id=\"T_d129d_row5_col4\" class=\"data row5 col4\" >0.098279</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_d129d_row6_col0\" class=\"data row6 col0\" >0.176923</td>\n",
" <td id=\"T_d129d_row6_col1\" class=\"data row6 col1\" >0.151713</td>\n",
" <td id=\"T_d129d_row6_col2\" class=\"data row6 col2\" >0.562024</td>\n",
" <td id=\"T_d129d_row6_col3\" class=\"data row6 col3\" >0.071080</td>\n",
" <td id=\"T_d129d_row6_col4\" class=\"data row6 col4\" >0.079232</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_d129d_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_d129d_row7_col0\" class=\"data row7 col0\" >0.250000</td>\n",
" <td id=\"T_d129d_row7_col1\" class=\"data row7 col1\" >0.057143</td>\n",
" <td id=\"T_d129d_row7_col2\" class=\"data row7 col2\" >0.554978</td>\n",
" <td id=\"T_d129d_row7_col3\" class=\"data row7 col3\" >0.000000</td>\n",
" <td id=\"T_d129d_row7_col4\" class=\"data row7 col4\" >0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dc06c64740>"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"'Error items count: 336'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rank</th>\n",
" <th>Predicted</th>\n",
" <th>Name</th>\n",
" <th>Networth</th>\n",
" <th>Country</th>\n",
" <th>Source</th>\n",
" <th>Industry</th>\n",
" <th>Age_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>7</td>\n",
" <td>4</td>\n",
" <td>Sergey Brin</td>\n",
" <td>107.0</td>\n",
" <td>United States</td>\n",
" <td>Google</td>\n",
" <td>Technology</td>\n",
" <td>40-50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>9</td>\n",
" <td>3</td>\n",
" <td>Steve Ballmer</td>\n",
" <td>91.4</td>\n",
" <td>United States</td>\n",
" <td>Microsoft</td>\n",
" <td>Technology</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>13</td>\n",
" <td>3</td>\n",
" <td>Carlos Slim Helu &amp; family</td>\n",
" <td>81.2</td>\n",
" <td>Mexico</td>\n",
" <td>telecom</td>\n",
" <td>Telecom</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>15</td>\n",
" <td>3</td>\n",
" <td>Mark Zuckerberg</td>\n",
" <td>67.3</td>\n",
" <td>United States</td>\n",
" <td>Facebook</td>\n",
" <td>Technology</td>\n",
" <td>30-40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>23</td>\n",
" <td>5</td>\n",
" <td>Amancio Ortega</td>\n",
" <td>59.6</td>\n",
" <td>Spain</td>\n",
" <td>Zara</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2586</th>\n",
" <td>2578</td>\n",
" <td>3</td>\n",
" <td>Roy Chi Ping Chung</td>\n",
" <td>1.0</td>\n",
" <td>Hong Kong</td>\n",
" <td>manufacturing</td>\n",
" <td>Manufacturing</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2588</th>\n",
" <td>2578</td>\n",
" <td>3</td>\n",
" <td>Ronald Clarke</td>\n",
" <td>1.0</td>\n",
" <td>United States</td>\n",
" <td>payments technology</td>\n",
" <td>Technology</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2591</th>\n",
" <td>2578</td>\n",
" <td>5</td>\n",
" <td>Sefik Yilmaz Dizdar</td>\n",
" <td>1.0</td>\n",
" <td>Turkey</td>\n",
" <td>fashion retail</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2593</th>\n",
" <td>2578</td>\n",
" <td>6</td>\n",
" <td>Larry Fink</td>\n",
" <td>1.0</td>\n",
" <td>United States</td>\n",
" <td>money management</td>\n",
" <td>Finance &amp; Investments</td>\n",
" <td>60-70</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2596</th>\n",
" <td>2578</td>\n",
" <td>5</td>\n",
" <td>Nari Genomal</td>\n",
" <td>1.0</td>\n",
" <td>Philippines</td>\n",
" <td>apparel</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" <td>80+</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>336 rows × 8 columns</p>\n",
"</div>"
],
"text/plain": [
" Rank Predicted Name Networth Country \\\n",
"6 7 4 Sergey Brin 107.0 United States \n",
"8 9 3 Steve Ballmer 91.4 United States \n",
"12 13 3 Carlos Slim Helu & family 81.2 Mexico \n",
"14 15 3 Mark Zuckerberg 67.3 United States \n",
"22 23 5 Amancio Ortega 59.6 Spain \n",
"... ... ... ... ... ... \n",
"2586 2578 3 Roy Chi Ping Chung 1.0 Hong Kong \n",
"2588 2578 3 Ronald Clarke 1.0 United States \n",
"2591 2578 5 Sefik Yilmaz Dizdar 1.0 Turkey \n",
"2593 2578 6 Larry Fink 1.0 United States \n",
"2596 2578 5 Nari Genomal 1.0 Philippines \n",
"\n",
" Source Industry Age_category \n",
"6 Google Technology 40-50 \n",
"8 Microsoft Technology 60-70 \n",
"12 telecom Telecom 80+ \n",
"14 Facebook Technology 30-40 \n",
"22 Zara Fashion & Retail 80+ \n",
"... ... ... ... \n",
"2586 manufacturing Manufacturing 60-70 \n",
"2588 payments technology Technology 60-70 \n",
"2591 fashion retail Fashion & Retail 80+ \n",
"2593 money management Finance & Investments 60-70 \n",
"2596 apparel Fashion & Retail 80+ \n",
"\n",
"[336 rows x 8 columns]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Transform the test features and keep X_test's row labels on the result, so that\n",
"# label-based lookups (preprocessed_df.loc[row_label]) address the same row as in X_test.\n",
"# Without index=..., the frame gets a 0..n-1 RangeIndex and .loc picks an unrelated row.\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
"    index=X_test.index,\n",
")\n",
"\n",
"# Test-set predictions stored for the model chosen as best_model.\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"# Row labels of test objects whose true Age_category differs from the prediction.\n",
"error_index = y_test[y_test[\"Age_category\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"# Show the misclassified rows with the predicted class inserted next to the features.\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
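"## Многовато ошибок: лучшая модель (logistic) неверно предсказала возрастную категорию для 336 объектов тестовой выборки (accuracy ≈ 0.35)"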
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rank</th>\n",
" <th>Name</th>\n",
" <th>Networth</th>\n",
" <th>Country</th>\n",
" <th>Source</th>\n",
" <th>Industry</th>\n",
" <th>Age_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>438</td>\n",
" <td>Ruan Liping</td>\n",
" <td>5.8</td>\n",
" <td>Hong Kong</td>\n",
" <td>power strips</td>\n",
" <td>Manufacturing</td>\n",
" <td>50-60</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Rank Name Networth Country Source Industry \\\n",
"450 438 Ruan Liping 5.8 Hong Kong power strips Manufacturing \n",
"\n",
" Age_category \n",
"450 50-60 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prepocessing_num__Networth</th>\n",
" <th>prepocessing_cat__Country_Argentina</th>\n",
" <th>prepocessing_cat__Country_Australia</th>\n",
" <th>prepocessing_cat__Country_Austria</th>\n",
" <th>prepocessing_cat__Country_Barbados</th>\n",
" <th>prepocessing_cat__Country_Belgium</th>\n",
" <th>prepocessing_cat__Country_Belize</th>\n",
" <th>prepocessing_cat__Country_Brazil</th>\n",
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
" <th>prepocessing_cat__Country_Canada</th>\n",
" <th>...</th>\n",
" <th>prepocessing_cat__Industry_Logistics</th>\n",
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
" <th>prepocessing_cat__Industry_Media &amp; Entertainment</th>\n",
" <th>prepocessing_cat__Industry_Metals &amp; Mining</th>\n",
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
" <th>prepocessing_cat__Industry_Service</th>\n",
" <th>prepocessing_cat__Industry_Sports</th>\n",
" <th>prepocessing_cat__Industry_Technology</th>\n",
" <th>prepocessing_cat__Industry_Telecom</th>\n",
" <th>prepocessing_cat__Industry_diversified</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>0.289255</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 860 columns</p>\n",
"</div>"
],
"text/plain": [
" prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n",
"450 0.289255 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"450 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"450 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"450 1.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"450 0.0 \n",
"\n",
"[1 rows x 860 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"'predicted: 3 (proba: [0.00172036 0.04303104 0.02714323 0.36848158 0.19524859 0.2037863\\n 0.1605889 ])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 3'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
2024-11-29 01:37:09 +04:00
"execution_count": null,
2024-11-16 01:14:29 +04:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 10,\n",
" 'model__max_features': 2,\n",
" 'model__n_estimators': 250}"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
2024-11-16 01:24:24 +04:00
"execution_count": 69,
2024-11-16 01:14:29 +04:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=9,\n",
" criterion=\"gini\",\n",
" max_depth=10,\n",
" max_features=2,\n",
" n_estimators=250,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
2024-11-16 01:24:24 +04:00
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)\n",
"# Derive the predictions from THIS pipeline's probabilities. The previous code read\n",
"# `y_test_probs`, a variable left over from an earlier cell, so every metric based on\n",
"# `preds` described a different model than the one trained here.\n",
"# As before, the argmax column index is used directly as the class label.\n",
"result[\"preds\"] = np.argmax(result[\"probs\"], axis=1)\n",
2024-11-16 01:14:29 +04:00
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"],average=\"macro\")\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"], average=\"macro\")\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"], average=\"macro\")\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"], average=\"macro\")\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"], multi_class=\"ovr\")\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"], average=\"macro\")\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"], average=\"macro\")\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
2024-11-16 01:24:24 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_002a7_row0_col0, #T_002a7_row0_col1, #T_002a7_row0_col2, #T_002a7_row0_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_002a7_row0_col4, #T_002a7_row0_col5, #T_002a7_row0_col6, #T_002a7_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_002a7_row1_col0, #T_002a7_row1_col1, #T_002a7_row1_col2, #T_002a7_row1_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_002a7_row1_col4, #T_002a7_row1_col5, #T_002a7_row1_col6, #T_002a7_row1_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_002a7\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_002a7_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_002a7_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_002a7_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_002a7_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_002a7_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_002a7_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_002a7_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_002a7_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_002a7_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_002a7_row0_col0\" class=\"data row0 col0\" >0.581578</td>\n",
" <td id=\"T_002a7_row0_col1\" class=\"data row0 col1\" >0.236539</td>\n",
" <td id=\"T_002a7_row0_col2\" class=\"data row0 col2\" >0.735419</td>\n",
" <td id=\"T_002a7_row0_col3\" class=\"data row0 col3\" >0.246556</td>\n",
" <td id=\"T_002a7_row0_col4\" class=\"data row0 col4\" >0.627404</td>\n",
" <td id=\"T_002a7_row0_col5\" class=\"data row0 col5\" >0.288462</td>\n",
" <td id=\"T_002a7_row0_col6\" class=\"data row0 col6\" >0.599765</td>\n",
" <td id=\"T_002a7_row0_col7\" class=\"data row0 col7\" >0.231541</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_002a7_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_002a7_row1_col0\" class=\"data row1 col0\" >0.181388</td>\n",
" <td id=\"T_002a7_row1_col1\" class=\"data row1 col1\" >0.035714</td>\n",
" <td id=\"T_002a7_row1_col2\" class=\"data row1 col2\" >0.157692</td>\n",
" <td id=\"T_002a7_row1_col3\" class=\"data row1 col3\" >0.142857</td>\n",
" <td id=\"T_002a7_row1_col4\" class=\"data row1 col4\" >0.306250</td>\n",
" <td id=\"T_002a7_row1_col5\" class=\"data row1 col5\" >0.250000</td>\n",
" <td id=\"T_002a7_row1_col6\" class=\"data row1 col6\" >0.090702</td>\n",
" <td id=\"T_002a7_row1_col7\" class=\"data row1 col7\" >0.057143</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dc027746e0>"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_6486e_row0_col0, #T_6486e_row0_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_6486e_row0_col2, #T_6486e_row1_col3, #T_6486e_row1_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6486e_row0_col3, #T_6486e_row0_col4, #T_6486e_row1_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_6486e_row1_col0, #T_6486e_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_6486e\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_6486e_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_6486e_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_6486e_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_6486e_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_6486e_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_6486e_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_6486e_row0_col0\" class=\"data row0 col0\" >0.288462</td>\n",
" <td id=\"T_6486e_row0_col1\" class=\"data row0 col1\" >0.231541</td>\n",
" <td id=\"T_6486e_row0_col2\" class=\"data row0 col2\" >0.599541</td>\n",
" <td id=\"T_6486e_row0_col3\" class=\"data row0 col3\" >0.126828</td>\n",
" <td id=\"T_6486e_row0_col4\" class=\"data row0 col4\" >0.129917</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_6486e_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_6486e_row1_col0\" class=\"data row1 col0\" >0.250000</td>\n",
" <td id=\"T_6486e_row1_col1\" class=\"data row1 col1\" >0.057143</td>\n",
" <td id=\"T_6486e_row1_col2\" class=\"data row1 col2\" >0.605446</td>\n",
" <td id=\"T_6486e_row1_col3\" class=\"data row1 col3\" >0.000000</td>\n",
" <td id=\"T_6486e_row1_col4\" class=\"data row1 col4\" >0.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1dc02777b60>"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA30AAAGsCAYAAABpd84aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACrbElEQVR4nOzdd3gU5drH8e/upmx6rySBUEOVIkIApYgERAQBFURFUDlyAKVZ8FAERWwIooAH5QVREMWCoggiSJOAdJESWiCBNEJIQvq2948cFtaEErLJTnbvz3XNpTszO/lll+y9zzzPPKMymUwmhBBCCCGEEELYJbWtAwghhBBCCCGEqDrS6BNCCCGEEEIIOyaNPiGEEEIIIYSwY9LoE0IIIYQQQgg7Jo0+IYQQQgghhLBj0ugTQgghhBBCCDsmjT4hhBBCCCGEsGPS6BNCCCGEEEIIO+Zk6wBCCCGUpaioiJKSEqsdz8XFBa1Wa7XjCSGEEBUhdU0afUIIIa5RVFREdG1P0jIMVjtmaGgoiYmJNa5ACiGEqPmkrpWSRp8QQgizkpIS0jIMJO6tjbdX5a8AyL1sJLrNWUpKSmpUcRRCCGEfpK6VkkafEEKIMry91FYpjkIIIYQSOHpdk0afEEKIMgwmIwaTdY4jhBBC2Jqj1zVp9AkhhCjDiAkjla+O1jiGEEIIUVmOXtcct49TCCGEEEIIIRyA9PQJIYQow4gRawxgsc5RhBBCiMpx9LomjT4hhBBlGEwmDKbKD2GxxjGEEEKIynL0uibDO4UQQgghhBDCjklPnxBCiDIc/YJ3IYQQ9sXR65o0+oQQQpRhxITBgYujEEII++LodU2GdwohhFCE1157DZVKZbHExMSYtxcVFTFq1CgCAgLw9PRkwIABpKen2zCxEEIIUTNIo08IIUQZV4bBWGOpiKZNm5Kammpetm/fbt42btw41qxZw6pVq9iyZQspKSn079/f2r+6EEIIO2SrurZ161b69OlDeHg4KpWK1atXX3ff5557DpVKxdy5cy3WZ2VlMWTIELy9vfH19eXpp58mLy+vQjmk0SeEEEIxnJycCA0NNS+BgYEA5OTksHjxYt5//326detGmzZtWLJkCTt27GDnzp02Ti2EEEKULz8/nzvuuIP58+ffcL/vv/+enTt3Eh4eXmbbkCFDOHz4MBs2bOCnn35i69atjBgxokI55Jo+IYQQZVh7auvc3FyL9a6urri6upbZ/8SJE4SHh6PVaomNjWXWrFlERUWxd+9edDod3bt3N+8bExNDVFQU8fHxtG/fvtJZhRBC2C9b3bKhV69e9OrV64b7nD9/njFjxrB+/Xp69+5tse3o0aOsW7eO3bt3c+eddwLw4Ycfcv/99/Pee++V20gsj/T0CSGEKMNoxQUgMjISHx8f8zJr1qwyP7Ndu3YsXbqUdevWsXDhQhITE7n77ru5fPkyaWlpuLi44Ovra/GckJAQ0tLSrP3rCyGEsDPWrmu5ubkWS3Fx8e3lMhp54oknePHFF2natGmZ7fHx8fj6+pobfADdu3dHrVaza9euW/450tMnhBCiyiUnJ+Pt7W1+XF4v37VnQlu0aEG7du2oXbs2X3/9NW5ubtWSUwghhLgVkZGRFo+nTZvGa6+9VuHjvP322zg5OfH888+Xuz0tLY3g4GCLdU5OTvj7+1fopKc0+oQQQpRhsNLU1leO4e3tbdHouxW+vr40bNiQkydPct9991FSUkJ2drZFb196ejqhoaGVzimEEMK+Wbuu3crJzJvZu3cvH3zwAfv27UOlUlU6243I8E4hhBBlGEzWW25XXl4ep06dIiwsjDZt2uDs7MzGjRvN2xMSEkhKSiI2NtYKv7EQQgh7Zu26duVk5pXldhp927ZtIyMjg6ioKJycnHBycuLs2bNMmDCBOnXqABAaGkpGRobF8/R6PVlZWRU66Sk9fUIIIRRh4sSJ9OnTh9q1a5OSksK0adPQaDQMHj
wYHx8fnn76acaPH4+/vz/e3t6MGTOG2NhYmcRFCCFEjfTEE09YTFAGEBcXxxNPPMGwYcMAiI2NJTs7m71799KmTRsANm3ahNFopF27drf8s6TRJ4QQooxrL1av7HFu1blz5xg8eDAXL14kKCiITp06sXPnToKCggCYM2cOarWaAQMGUFxcTFxcHAsWLLBCSiGEEPbOFnUNSketnDx50vw4MTGRAwcO4O/vT1RUFAEBARb7Ozs7ExoaSqNGjQBo3LgxPXv25Nlnn+Xjjz9Gp9MxevRoBg0adMszd4I0+oQQQpTDiAoDlb++wFiBY6xcufKG27VaLfPnz7/pvY6EEEKIf7JFXQPYs2cPXbt2NT8eP348AEOHDmXp0qW3dIzly5czevRo7r33XvPJz3nz5lUohzT6hBBCCCGEEKIKdOnSBVMF7u135syZMuv8/f1ZsWJFpXJIo08IIUQZRlPpYo3jCCGEELbm6HVNZu8UQgghhBBCCDsmPX1CCCHKMFjp2gdrHEMIIYSoLEeva9LoE0IIUYajF0chhBD2xdHrmgzvFEIIIYQQQgg7Jj19QgghyjCaVBhNVpja2grHEEIIISrL0euaNPqEEEKU4ejDYIQQQtgXR69rMrxTCCGEEEIIIeyY9PQJIYQow4AagxXOCxqskEUIIYSoLEeva9LoE0IIUYbJStc+mGrotQ9CCCHsi6PXNRneKYQQQgghhBB2THr6hBBClOHoF7wLIYSwL45e16TRJ4QQogyDSY3BZIVrH0xWCCOEEEJUkqPXNRneKYQQQgghhBB2THr6hBBClGFEhdEK5wWN1NBTokIIIeyKo9c16ekTQgghhBBCCDsmPX1CCCHKcPQL3oUQQtgXR69r0ugTQghRhvUueK+Zw2CEEELYF0evazK8UwghhBBCCCHsmPT0CSGEKKP0gvfKD2GxxjGEEEKIynL0uiaNPiGEEGUYUWNw4FnOhBBC2BdHr2syvFMIIYQQQggh7Jj09AkhhCjD0S94F0IIYV8cva5Jo08IIUQZRtQOfRNbIYQQ9sXR65oM7xRCCCGEEEIIOyY9fUIIIcowmFQYTFa4ia0VjiGEEEJUlqPXNenpE0IIIYQQQgg7Jj19QgghyjBYaWprQw299kEIIYR9cfS6Jo0+IYQQZRhNaoxWmOXMWENnORNCCGFfHL2uyfBOIYQQQgghhLBj0tMnhBCiDEcfBiOEEMK+OHpdk0afEEKIMoxYZ4YyY+WjCCGEEJXm6HVNhncKIYQQQgghhB2Tnj4hhBBlGFFjtMJ5QWscQwghhKgsR69r0ugTQghRhsGkxmCFWc6scQwhhBCishy9rtXM1EIIIYQQQgghbon09AkhhCjDiAoj1rjgvfLHEEIIISrL0euaNPqEEEKU4ejDYIQQQtgXR69rNTO1EEIIIYQQQohbIj19QgghyrDeTWzl3KIQQgjbc/S6VjNTCyGEEEIIIYS4JdLTp1BGo5GUlBS8vLxQqWrmBaNCiOplMpm4fPky4eHhqNWVO6dnNKkwmqxwwbsVjiHsg9Q1IURFSV2zHmn0KVRKSgqRkZG2jiGEqIGSk5OJiIio1DGMVhoGU1NvYiusT+qaEOJ2SV2rPGn0KZSXlxcAnbgfJ5xtnEZYg8rZxdYRymXSldg6Qo2j1PdSb9KxTb/a/PkhhJJIXbM/Kidlfo006fW2jiCsRI+O7ayVumYFyvxrFeahL04446SS4mgPVAp9H00qk60j1DhKfS+vsMbQOaNJjdEK01Jb4xjCPkhdsz8qlTK/Rppk+LD9+N9XlJpc17Zu3cq7777L3r17SU1N5fvvv6dfv34A6HQ6Jk+ezNq1azl9+jQ+Pj50796dt956i/DwcPMxsrKyGDNmDGvWrEGtVjNgwAA++OADPD09bzmHVGMhhBBlGFBZbRFCCCFszVZ1LT8/nzvuuIP58+eX2VZQUMC+ffuYMmUK+/bt47vvviMhIYEHH3zQYr8hQ4Zw+PBhNmzYwE8//c
TWrVsZMWJEhXIo8xSNEEIIIYQQQtRwvXr1olevXuVu8/HxYcOGDRbrPvroI+666y6SkpKIiori6NGjrFu3jt27d3P
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Under 30\", \"30-40\", \"40-50\", \"50-60\", \"60-70\", \"70-80\", \"80+\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
2024-11-29 00:30:06 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Задача регрессии"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X = df.drop(columns=['Networth','Rank ', 'Name']) # Признаки\n",
"y = df['Networth'] # Целевая переменная для регрессии\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
2024-11-29 01:37:09 +04:00
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rank</th>\n",
" <th>Name</th>\n",
" <th>Networth</th>\n",
" <th>Age</th>\n",
" <th>Country</th>\n",
" <th>Source</th>\n",
" <th>Industry</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>Elon Musk</td>\n",
" <td>219.0</td>\n",
" <td>50</td>\n",
" <td>United States</td>\n",
" <td>Tesla, SpaceX</td>\n",
" <td>Automotive</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>Jeff Bezos</td>\n",
" <td>171.0</td>\n",
" <td>58</td>\n",
" <td>United States</td>\n",
" <td>Amazon</td>\n",
" <td>Technology</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>Bernard Arnault &amp; family</td>\n",
" <td>158.0</td>\n",
" <td>73</td>\n",
" <td>France</td>\n",
" <td>LVMH</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>Bill Gates</td>\n",
" <td>129.0</td>\n",
" <td>66</td>\n",
" <td>United States</td>\n",
" <td>Microsoft</td>\n",
" <td>Technology</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Warren Buffett</td>\n",
" <td>118.0</td>\n",
" <td>91</td>\n",
" <td>United States</td>\n",
" <td>Berkshire Hathaway</td>\n",
" <td>Finance &amp; Investments</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2595</th>\n",
" <td>2578</td>\n",
" <td>Jorge Gallardo Ballart</td>\n",
" <td>1.0</td>\n",
" <td>80</td>\n",
" <td>Spain</td>\n",
" <td>pharmaceuticals</td>\n",
" <td>Healthcare</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2596</th>\n",
" <td>2578</td>\n",
" <td>Nari Genomal</td>\n",
" <td>1.0</td>\n",
" <td>82</td>\n",
" <td>Philippines</td>\n",
" <td>apparel</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2597</th>\n",
" <td>2578</td>\n",
" <td>Ramesh Genomal</td>\n",
" <td>1.0</td>\n",
" <td>71</td>\n",
" <td>Philippines</td>\n",
" <td>apparel</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2598</th>\n",
" <td>2578</td>\n",
" <td>Sunder Genomal</td>\n",
" <td>1.0</td>\n",
" <td>68</td>\n",
" <td>Philippines</td>\n",
" <td>garments</td>\n",
" <td>Fashion &amp; Retail</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2599</th>\n",
" <td>2578</td>\n",
" <td>Horst-Otto Gerberding</td>\n",
" <td>1.0</td>\n",
" <td>69</td>\n",
" <td>Germany</td>\n",
" <td>flavors and fragrances</td>\n",
" <td>Food &amp; Beverage</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2600 rows × 7 columns</p>\n",
"</div>"
],
"text/plain": [
" Rank Name Networth Age Country \\\n",
"0 1 Elon Musk 219.0 50 United States \n",
"1 2 Jeff Bezos 171.0 58 United States \n",
"2 3 Bernard Arnault & family 158.0 73 France \n",
"3 4 Bill Gates 129.0 66 United States \n",
"4 5 Warren Buffett 118.0 91 United States \n",
"... ... ... ... ... ... \n",
"2595 2578 Jorge Gallardo Ballart 1.0 80 Spain \n",
"2596 2578 Nari Genomal 1.0 82 Philippines \n",
"2597 2578 Ramesh Genomal 1.0 71 Philippines \n",
"2598 2578 Sunder Genomal 1.0 68 Philippines \n",
"2599 2578 Horst-Otto Gerberding 1.0 69 Germany \n",
"\n",
" Source Industry \n",
"0 Tesla, SpaceX Automotive \n",
"1 Amazon Technology \n",
"2 LVMH Fashion & Retail \n",
"3 Microsoft Technology \n",
"4 Berkshire Hathaway Finance & Investments \n",
"... ... ... \n",
"2595 pharmaceuticals Healthcare \n",
"2596 apparel Fashion & Retail \n",
"2597 apparel Fashion & Retail \n",
"2598 garments Fashion & Retail \n",
"2599 flavors and fragrances Food & Beverage \n",
"\n",
"[2600 rows x 7 columns]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
2024-11-29 00:30:06 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prepocessing_num__Age</th>\n",
" <th>prepocessing_cat__Country_Argentina</th>\n",
" <th>prepocessing_cat__Country_Australia</th>\n",
" <th>prepocessing_cat__Country_Austria</th>\n",
" <th>prepocessing_cat__Country_Barbados</th>\n",
" <th>prepocessing_cat__Country_Belgium</th>\n",
" <th>prepocessing_cat__Country_Belize</th>\n",
" <th>prepocessing_cat__Country_Brazil</th>\n",
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
" <th>prepocessing_cat__Country_Canada</th>\n",
" <th>...</th>\n",
" <th>prepocessing_cat__Industry_Logistics</th>\n",
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
" <th>prepocessing_cat__Industry_Media &amp; Entertainment</th>\n",
" <th>prepocessing_cat__Industry_Metals &amp; Mining</th>\n",
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
" <th>prepocessing_cat__Industry_Service</th>\n",
" <th>prepocessing_cat__Industry_Sports</th>\n",
" <th>prepocessing_cat__Industry_Technology</th>\n",
" <th>prepocessing_cat__Industry_Telecom</th>\n",
" <th>prepocessing_cat__Industry_diversified</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>582</th>\n",
" <td>-0.109934</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48</th>\n",
" <td>1.079079</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1772</th>\n",
" <td>1.004766</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>964</th>\n",
" <td>-0.407187</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2213</th>\n",
" <td>1.302019</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1638</th>\n",
" <td>1.227706</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1095</th>\n",
" <td>0.856139</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1130</th>\n",
" <td>0.781826</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1294</th>\n",
" <td>0.335946</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>0.558886</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2080 rows × 855 columns</p>\n",
"</div>"
],
"text/plain": [
" prepocessing_num__Age prepocessing_cat__Country_Argentina \\\n",
"582 -0.109934 0.0 \n",
"48 1.079079 0.0 \n",
"1772 1.004766 0.0 \n",
"964 -0.407187 0.0 \n",
"2213 1.302019 0.0 \n",
"... ... ... \n",
"1638 1.227706 0.0 \n",
"1095 0.856139 0.0 \n",
"1130 0.781826 0.0 \n",
"1294 0.335946 0.0 \n",
"860 0.558886 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 1.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 1.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 1.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"582 ... 0.0 \n",
"48 ... 0.0 \n",
"1772 ... 0.0 \n",
"964 ... 0.0 \n",
"2213 ... 0.0 \n",
"... ... ... \n",
"1638 ... 0.0 \n",
"1095 ... 0.0 \n",
"1130 ... 0.0 \n",
"1294 ... 0.0 \n",
"860 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"582 0.0 \n",
"48 1.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 1.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 1.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"582 1.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 1.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
"[2080 rows x 855 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import pandas as pd\n",
"\n",
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
"columns_to_drop = []\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Предобработка числовых данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Предобработка категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Общая предобработка признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "code",
2024-11-29 01:37:09 +04:00
"execution_count": 13,
2024-11-29 00:30:06 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-29 00:53:22 +04:00
"Training LinearRegression...\n"
2024-11-29 00:30:06 +04:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-11-29 00:53:22 +04:00
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 1 is smaller than n_iter=10. Running 1 iterations. For exhaustive searches, use GridSearchCV.\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training RandomForestRegressor...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
2024-11-29 00:30:06 +04:00
" warnings.warn(\n"
]
},
{
2024-11-29 00:53:22 +04:00
"name": "stdout",
"output_type": "stream",
"text": [
"Training GradientBoostingRegressor...\n",
"\n",
"Model: LinearRegression\n",
"Best Params: {}\n",
"MAE: 18059903.80176681\n",
"RMSE: 411829080.6584508\n",
"R2: -7135788186375614.0\n",
"\n",
"Model: RandomForestRegressor\n",
2024-11-29 01:37:09 +04:00
"Best Params: {'model__n_estimators': 40, 'model__max_depth': 10}\n",
"MAE: 3.454630023161808\n",
"RMSE: 7.755775760541111\n",
"R2: -1.530803448377045\n",
2024-11-29 00:53:22 +04:00
"\n",
"Model: GradientBoostingRegressor\n",
2024-11-29 01:37:09 +04:00
"Best Params: {'model__n_estimators': 100, 'model__max_depth': 4, 'model__learning_rate': 0.4}\n",
"MAE: 3.585784679817764\n",
"RMSE: 10.312249036012052\n",
"R2: -3.474193004771121\n"
2024-11-29 00:53:22 +04:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
2024-11-29 00:30:06 +04:00
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.model_selection import GridSearchCV, ParameterGrid, RandomizedSearchCV\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"from sklearn.pipeline import Pipeline\n",
"import matplotlib.pyplot as plt\n",
"\n",
"random_state = 42\n",
"\n",
"# Candidate models and their hyper-parameter search spaces\n",
"models_regression = {\n",
"    \"LinearRegression\": LinearRegression(),\n",
"    \"RandomForestRegressor\": RandomForestRegressor(random_state=random_state),\n",
"    \"GradientBoostingRegressor\": GradientBoostingRegressor(random_state=random_state)\n",
"}\n",
"\n",
"# Keys carry the 'model__' prefix because the estimator is the 'model' step of the pipeline.\n",
"param_grids_regression = {\n",
"    \"LinearRegression\": {},\n",
"    \"RandomForestRegressor\": {\n",
"        'model__n_estimators': [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        'model__max_depth': [None, 2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
"    },\n",
"    \"GradientBoostingRegressor\": {\n",
"        'model__n_estimators': [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"        'model__learning_rate': [0.01, 0.1, 0.2, 0.3, 0.4, 0.5],\n",
"        'model__max_depth': [2, 3, 4, 5, 6, 7, 8, 9 ,10]\n",
"    }\n",
"}\n",
"\n",
"# Collected results: model name -> best parameters and test-set metrics\n",
"results_regression = {}\n",
"\n",
"# Tune and evaluate every model\n",
"for name, model in models_regression.items():\n",
"    print(f\"Training {name}...\")\n",
"    pipeline = Pipeline(steps=[\n",
"        ('features_preprocessing', features_preprocessing),\n",
"        ('model', model)\n",
"    ])\n",
"    param_grid = param_grids_regression[name]\n",
"    # Never request more candidates than the grid holds (LinearRegression has a single,\n",
"    # empty configuration); 10 is RandomizedSearchCV's default n_iter.\n",
"    n_iter = min(10, len(ParameterGrid(param_grid)))\n",
"    grid_search = RandomizedSearchCV(\n",
"        pipeline,\n",
"        param_grid,\n",
"        n_iter=n_iter,\n",
"        cv=5,\n",
"        scoring='neg_mean_absolute_error',\n",
"        n_jobs=-1,\n",
"        random_state=random_state,  # make the sampled candidates reproducible\n",
"    )\n",
"    grid_search.fit(X_train, y_train)\n",
"\n",
"    # Best pipeline found by the search, evaluated on the held-out test set\n",
"    best_model = grid_search.best_estimator_\n",
"    y_pred = best_model.predict(X_test)\n",
"\n",
"    # Metrics\n",
"    mae = mean_absolute_error(y_test, y_pred)\n",
"    rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
"    r2 = r2_score(y_test, y_pred)\n",
"\n",
"    # Store the results\n",
"    results_regression[name] = {\n",
"        \"Best Params\": grid_search.best_params_,\n",
"        \"MAE\": mae,\n",
"        \"RMSE\": rmse,\n",
"        \"R2\": r2\n",
"    }\n",
"\n",
"# Print the results\n",
"for name, metrics in results_regression.items():\n",
"    print(f\"\\nModel: {name}\")\n",
"    for metric, value in metrics.items():\n",
"        print(f\"{metric}: {value}\")"
]
2024-11-29 00:53:22 +04:00
},
{
"cell_type": "code",
2024-11-29 01:37:09 +04:00
"execution_count": 14,
2024-11-29 00:53:22 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
2024-11-29 01:37:09 +04:00
"#T_5e893_row0_col0, #T_5e893_row0_col1, #T_5e893_row1_col0, #T_5e893_row1_col1 {\n",
2024-11-29 00:53:22 +04:00
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-29 01:37:09 +04:00
"#T_5e893_row0_col2, #T_5e893_row1_col2 {\n",
2024-11-29 00:53:22 +04:00
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-29 01:37:09 +04:00
"#T_5e893_row2_col0, #T_5e893_row2_col1 {\n",
2024-11-29 00:53:22 +04:00
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
2024-11-29 01:37:09 +04:00
"#T_5e893_row2_col2 {\n",
2024-11-29 00:53:22 +04:00
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
2024-11-29 01:37:09 +04:00
"<table id=\"T_5e893\">\n",
2024-11-29 00:53:22 +04:00
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
2024-11-29 01:37:09 +04:00
" <th id=\"T_5e893_level0_col0\" class=\"col_heading level0 col0\" >MAE</th>\n",
" <th id=\"T_5e893_level0_col1\" class=\"col_heading level0 col1\" >RMSE</th>\n",
" <th id=\"T_5e893_level0_col2\" class=\"col_heading level0 col2\" >R2</th>\n",
2024-11-29 00:53:22 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-29 01:37:09 +04:00
" <th id=\"T_5e893_level0_row0\" class=\"row_heading level0 row0\" >RandomForestRegressor</th>\n",
" <td id=\"T_5e893_row0_col0\" class=\"data row0 col0\" >3.454630</td>\n",
" <td id=\"T_5e893_row0_col1\" class=\"data row0 col1\" >7.755776</td>\n",
" <td id=\"T_5e893_row0_col2\" class=\"data row0 col2\" >-1.530803</td>\n",
2024-11-29 00:53:22 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 01:37:09 +04:00
" <th id=\"T_5e893_level0_row1\" class=\"row_heading level0 row1\" >GradientBoostingRegressor</th>\n",
" <td id=\"T_5e893_row1_col0\" class=\"data row1 col0\" >3.585785</td>\n",
" <td id=\"T_5e893_row1_col1\" class=\"data row1 col1\" >10.312249</td>\n",
" <td id=\"T_5e893_row1_col2\" class=\"data row1 col2\" >-3.474193</td>\n",
2024-11-29 00:53:22 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 01:37:09 +04:00
" <th id=\"T_5e893_level0_row2\" class=\"row_heading level0 row2\" >LinearRegression</th>\n",
" <td id=\"T_5e893_row2_col0\" class=\"data row2 col0\" >18059903.801767</td>\n",
" <td id=\"T_5e893_row2_col1\" class=\"data row2 col1\" >411829080.658451</td>\n",
" <td id=\"T_5e893_row2_col2\" class=\"data row2 col2\" >-7135788186375614.000000</td>\n",
2024-11-29 00:53:22 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
2024-11-29 01:37:09 +04:00
"<pandas.io.formats.style.Styler at 0x182a6043ef0>"
2024-11-29 00:53:22 +04:00
]
},
2024-11-29 01:37:09 +04:00
"execution_count": 14,
2024-11-29 00:53:22 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Импортируем pandas для работы с таблицами\n",
"import pandas as pd\n",
"\n",
"# Формируем таблицу метрик\n",
"reg_metrics = pd.DataFrame.from_dict(results_regression, orient=\"index\")[\n",
" [\"MAE\", \"RMSE\", \"R2\"]\n",
"]\n",
"\n",
"# Визуализация результатов с помощью стилизации\n",
"styled_metrics = (\n",
" reg_metrics.sort_values(by=\"RMSE\")\n",
" .style.background_gradient(cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE\", \"MAE\"])\n",
" .background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"R2\"])\n",
")\n",
"\n",
"styled_metrics"
]
2024-11-29 01:37:09 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Шикарный вывод: у всех трёх моделей R² на тестовой выборке отрицательный, то есть они предсказывают хуже, чем простое среднее. По стране, возрасту, сфере деятельности и источнику доходов предсказать состояние миллиардера не удалось. Значит ли это, что кто угодно, где угодно и в чём угодно может добиться успеха?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Классификация"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Rank Name Networth Country \\\n",
"0 1 Elon Musk 219.0 United States \n",
"1 2 Jeff Bezos 171.0 United States \n",
"2 3 Bernard Arnault & family 158.0 France \n",
"3 4 Bill Gates 129.0 United States \n",
"4 5 Warren Buffett 118.0 United States \n",
"\n",
" Source Industry Age_category \n",
"0 Tesla, SpaceX Automotive 50-60 \n",
"1 Amazon Technology 50-60 \n",
"2 LVMH Fashion & Retail 70-80 \n",
"3 Microsoft Technology 60-70 \n",
"4 Berkshire Hathaway Finance & Investments 80+ \n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n",
"\n",
"bins = [0, 30, 40, 50, 60, 70, 80, 101] # границы для возрастных категорий\n",
"labels = ['Under 30', '30-40', '40-50', '50-60', '60-70', '70-80', '80+'] # метки для категорий\n",
"\n",
"df[\"Age_category\"] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)\n",
"# Удаляем оригинальные колонки 'country', 'industry' и 'source' из исходного DataFrame\n",
"df.drop(columns=['Age'], inplace=True)\n",
"\n",
"# Просмотр результата\n",
"print(df.head())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training LogisticRegression...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 3 is smaller than n_iter=10. Running 3 iterations. For exhaustive searches, use GridSearchCV.\n",
" warnings.warn(\n"
]
},
{
"ename": "ValueError",
"evalue": "\nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\pandas\\core\\indexes\\base.py\", line 3805, in get_loc\n return self._engine.get_loc(casted_key)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"index.pyx\", line 167, in pandas._libs.index.IndexEngine.get_loc\n File \"index.pyx\", line 196, in pandas._libs.index.IndexEngine.get_loc\n File \"pandas\\\\_libs\\\\hashtable_class_helper.pxi\", line 7081, in pandas._libs.hashtable.PyObjectHashTable.get_item\n File \"pandas\\\\_libs\\\\hashtable_class_helper.pxi\", line 7089, in pandas._libs.hashtable.PyObjectHashTable.get_item\nKeyError: 'Age'\n\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_indexing.py\", line 361, in _get_column_indices\n col_idx = all_columns.get_loc(col)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\pandas\\core\\indexes\\base.py\", line 3812, in get_loc\n raise KeyError(key) from err\nKeyError: 'Age'\n\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 469, in 
fit\n Xt = self._fit(X, y, routed_params)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 406, in _fit\n X, fitted_transformer = fit_transform_one_cached(\n ^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\joblib\\memory.py\", line 312, in __call__\n return self.func(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 1310, in _fit_transform_one\n res = transformer.fit_transform(X, y, **params.get(\"fit_transform\", {}))\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_set_output.py\", line 316, in wrapped\n data_to_wrap = f(self, X, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\compose\\_column_transformer.py\", line 968, in fit_transform\n self._validate_column_callables(X)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\compose\\_column_transformer.py\", line 536, in _validate_column_callables\n transformer_to_input_indices[name] = _get_column_indices(X, columns)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_indexing.py\", line 369, in _get_column_indices\n raise ValueError(\"A given column is not a column of the dataframe\") from e\nValueError: A given column is not a column of the dataframe\n",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[18], line 48\u001b[0m\n\u001b[0;32m 46\u001b[0m param_grid \u001b[38;5;241m=\u001b[39m param_grids_classification[name]\n\u001b[0;32m 47\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m RandomizedSearchCV(pipeline, param_grid, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf1\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 48\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train_clf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train_clf\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 50\u001b[0m \u001b[38;5;66;03m# Лучшая модель\u001b[39;00m\n\u001b[0;32m 51\u001b[0m best_model \u001b[38;5;241m=\u001b[39m grid_search\u001b[38;5;241m.\u001b[39mbest_estimator_\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[1;34m(self, X, y, **params)\u001b[0m\n\u001b[0;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[0;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[0;32m 1015\u001b[0m )\n\u001b[0;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[1;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[0;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[0;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1960\u001b[0m, in \u001b[0;36mRandomizedSearchCV._run_search\u001b[1;34m(self, evaluate_candidates)\u001b[0m\n\u001b[0;32m 1958\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[0;32m 1959\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search n_iter candidates from param_distributions\"\"\"\u001b[39;00m\n\u001b[1;32m-> 1960\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1961\u001b[0m \u001b[43m \u001b[49m\u001b[43mParameterSampler\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1962\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_distributions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom_state\u001b[49m\n\u001b[0;32m 1963\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1964\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:996\u001b[0m, in \u001b[0;36mBaseSearchCV.fit.<locals>.evaluate_candidates\u001b[1;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[0;32m 989\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m!=\u001b[39m n_candidates \u001b[38;5;241m*\u001b[39m n_splits:\n\u001b[0;32m 990\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 991\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcv.split and cv.get_n_splits returned \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 992\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minconsistent results. Expected \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 993\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msplits, got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(n_splits, \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m n_candidates)\n\u001b[0;32m 994\u001b[0m )\n\u001b[1;32m--> 996\u001b[0m \u001b[43m_warn_or_raise_about_fit_failures\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;66;03m# For callable self.scoring, the return type is only know after\u001b[39;00m\n\u001b[0;32m 999\u001b[0m \u001b[38;5;66;03m# calling. If the return type is a dictionary, the error scores\u001b[39;00m\n\u001b[0;32m 1000\u001b[0m \u001b[38;5;66;03m# can now be inserted with the correct key. 
The type checking\u001b[39;00m\n\u001b[0;32m 1001\u001b[0m \u001b[38;5;66;03m# of out will be done in `_insert_error_scores`.\u001b[39;00m\n\u001b[0;32m 1002\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcallable\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscoring):\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:529\u001b[0m, in \u001b[0;36m_warn_or_raise_about_fit_failures\u001b[1;34m(results, error_score)\u001b[0m\n\u001b[0;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_failed_fits \u001b[38;5;241m==\u001b[39m num_fits:\n\u001b[0;32m 523\u001b[0m all_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 524\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mAll the \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 525\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIt is very likely that your model is misconfigured.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 526\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou can try to debug the error by setting error_score=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 528\u001b[0m )\n\u001b[1;32m--> 529\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(all_fits_failed_message)\n\u001b[0;32m 531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 532\u001b[0m some_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 533\u001b[0m 
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mnum_failed_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed out of a total of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 534\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe score on these train-test partitions for these parameters\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 539\u001b[0m )\n",
"\u001b[1;31mValueError\u001b[0m: \nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\pandas\\core\\indexes\\base.py\", line 3805, in get_loc\n return self._engine.get_loc(casted_key)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"index.pyx\", line 167, in pandas._libs.index.IndexEngine.get_loc\n File \"index.pyx\", line 196, in pandas._libs.index.IndexEngine.get_loc\n File \"pandas\\\\_libs\\\\hashtable_class_helper.pxi\", line 7081, in pandas._libs.hashtable.PyObjectHashTable.get_item\n File \"pandas\\\\_libs\\\\hashtable_class_helper.pxi\", line 7089, in pandas._libs.hashtable.PyObjectHashTable.get_item\nKeyError: 'Age'\n\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_indexing.py\", line 361, in _get_column_indices\n col_idx = all_columns.get_loc(col)\n ^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\pandas\\core\\indexes\\base.py\", line 3812, in get_loc\n raise KeyError(key) from err\nKeyError: 'Age'\n\nThe above exception was the direct cause of the following exception:\n\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File 
\"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 469, in fit\n Xt = self._fit(X, y, routed_params)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 406, in _fit\n X, fitted_transformer = fit_transform_one_cached(\n ^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\joblib\\memory.py\", line 312, in __call__\n return self.func(*args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 1310, in _fit_transform_one\n res = transformer.fit_transform(X, y, **params.get(\"fit_transform\", {}))\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_set_output.py\", line 316, in wrapped\n data_to_wrap = f(self, X, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\compose\\_column_transformer.py\", line 968, in fit_transform\n self._validate_column_callables(X)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\compose\\_column_transformer.py\", line 536, in _validate_column_callables\n transformer_to_input_indices[name] = _get_column_indices(X, columns)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_indexing.py\", line 369, in _get_column_indices\n raise ValueError(\"A given column is not a column of the dataframe\") from e\nValueError: A given column is not a column of the dataframe\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
"\n",
"X = df.drop(columns=['Age_category','Rank ', 'Name']) # Признаки\n",
"# Целевая переменная для классификации\n",
"y_class = df['Age_category'] \n",
"\n",
"# Разделение данных\n",
"X_train_clf, X_test_clf, y_train_clf, y_test_clf = train_test_split(X, y_class, test_size=0.2, random_state=42)\n",
"\n",
"# Модели и параметры\n",
"models_classification = {\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"KNN\": KNeighborsClassifier()\n",
"}\n",
"\n",
"param_grids_classification = {\n",
" \"LogisticRegression\": {\n",
" 'model__C': [0.1, 1, 10]\n",
" },\n",
" \"RandomForestClassifier\": {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10, 20],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
" },\n",
" \"KNN\": {\n",
" 'model__n_neighbors': [3, 5, 7, 9, 11],\n",
" 'model__weights': ['uniform', 'distance']\n",
" }\n",
"}\n",
"\n",
"# Результаты\n",
"results_classification = {}\n",
"\n",
"# Перебор моделей\n",
"for name, model in models_classification.items():\n",
" print(f\"Training {name}...\")\n",
" pipeline = Pipeline(steps=[\n",
" ('features_preprocessing', features_preprocessing),\n",
" ('model', model)\n",
" ])\n",
" param_grid = param_grids_classification[name]\n",
" grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='f1', n_jobs=-1)\n",
" grid_search.fit(X_train_clf, y_train_clf)\n",
"\n",
" # Лучшая модель\n",
" best_model = grid_search.best_estimator_\n",
" y_pred = best_model.predict(X_test_clf)\n",
"\n",
" # Метрики\n",
" acc = accuracy_score(y_test_clf, y_pred)\n",
" f1 = f1_score(y_test_clf, y_pred)\n",
"\n",
" # Вычисление матрицы ошибок\n",
" c_matrix = confusion_matrix(y_test_clf, y_pred)\n",
"\n",
" # Сохранение результатов\n",
" results_classification[name] = {\n",
" \"Best Params\": grid_search.best_params_,\n",
" \"Accuracy\": acc,\n",
" \"F1 Score\": f1,\n",
" \"Confusion_matrix\": c_matrix\n",
" }\n",
"\n",
"# Печать результатов\n",
"for name, metrics in results_classification.items():\n",
" print(f\"\\nModel: {name}\")\n",
" for metric, value in metrics.items():\n",
" print(f\"{metric}: {value}\")"
]
2024-11-15 22:35:48 +04:00
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}