1493 lines
68 KiB
Plaintext
1493 lines
68 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Index(['Rank ', 'Name', 'Networth', 'Age', 'Country', 'Source', 'Industry'], dtype='object')\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"# NOTE(review): absolute local path — consider a path relative to the project or a configurable data directory\n",
|
||
"DATA_PATH = \"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\"\n",
|
||
"\n",
|
||
"df = pd.read_csv(DATA_PATH)\n",
|
||
"\n",
|
||
"# Inspect the column names; note that the 'Rank ' header carries a trailing space\n",
|
||
"print(df.columns)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Определим бизнес-цели:\n",
|
||
"## 1. Прогнозирование размера состояния миллиардера (регрессия)\n",
|
||
"## 2. Прогнозирование возрастной категории миллиардера (классификация)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Подготовим данные: разобьём возраст (колонка Age) на возрастные категории"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Rank 0\n",
|
||
"Name 0\n",
|
||
"Networth 0\n",
|
||
"Age 0\n",
|
||
"Country 0\n",
|
||
"Source 0\n",
|
||
"Industry 0\n",
|
||
"dtype: int64\n",
|
||
"\n",
|
||
"Rank False\n",
|
||
"Name False\n",
|
||
"Networth False\n",
|
||
"Age False\n",
|
||
"Country False\n",
|
||
"Source False\n",
|
||
"Industry False\n",
|
||
"dtype: bool\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Number of missing values in each column\n",
|
||
"print(df.isnull().sum())\n",
|
||
"print()\n",
|
||
"\n",
|
||
"# Whether each column contains at least one missing value\n",
|
||
"print(df.isnull().any())\n",
|
||
"print()\n",
|
||
"\n",
|
||
"# Percentage of missing values, reported only for columns that actually have gaps\n",
|
||
"null_rates = df.isnull().mean() * 100\n",
|
||
"for column, null_rate in null_rates.items():\n",
|
||
"    if null_rate > 0:\n",
|
||
"        print(f\"{column} процент пустых значений: %{null_rate:.2f}\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" Rank Name Networth Country \\\n",
|
||
"0 1 Elon Musk 219.0 United States \n",
|
||
"1 2 Jeff Bezos 171.0 United States \n",
|
||
"2 3 Bernard Arnault & family 158.0 France \n",
|
||
"3 4 Bill Gates 129.0 United States \n",
|
||
"4 5 Warren Buffett 118.0 United States \n",
|
||
"\n",
|
||
" Source Industry Age_category \n",
|
||
"0 Tesla, SpaceX Automotive 50-60 \n",
|
||
"1 Amazon Technology 50-60 \n",
|
||
"2 LVMH Fashion & Retail 70-80 \n",
|
||
"3 Microsoft Technology 60-70 \n",
|
||
"4 Berkshire Hathaway Finance & Investments 80+ \n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Left-closed age bins: [0, 30), [30, 40), ..., [70, 80), [80, inf).\n",
|
||
"# The infinite upper edge keeps '80+' open-ended, so no age silently becomes NaN.\n",
|
||
"bins = [0, 30, 40, 50, 60, 70, 80, float(\"inf\")]\n",
|
||
"labels = ['Under 30', '30-40', '40-50', '50-60', '60-70', '70-80', '80+']  # one label per bin\n",
|
||
"\n",
|
||
"# Replace the raw 'Age' column with its category. The guard keeps the cell idempotent:\n",
|
||
"# re-running it after 'Age' has been dropped is a no-op instead of a KeyError.\n",
|
||
"if 'Age' in df.columns:\n",
|
||
"    df['Age_category'] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)\n",
|
||
"    df = df.drop(columns=['Age'])\n",
|
||
"\n",
|
||
"# Every row must fall into a category; NaN labels would presumably break stratifying on this column.\n",
|
||
"assert df['Age_category'].notna().all(), 'some ages fall outside the age bins'\n",
|
||
"\n",
|
||
"# Preview the result\n",
|
||
"print(df.head())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Rank</th>\n",
|
||
" <th>Name</th>\n",
|
||
" <th>Networth</th>\n",
|
||
" <th>Country</th>\n",
|
||
" <th>Source</th>\n",
|
||
" <th>Industry</th>\n",
|
||
" <th>Age_category</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1909</th>\n",
|
||
" <td>1818</td>\n",
|
||
" <td>Tran Ba Duong & family</td>\n",
|
||
" <td>1.6</td>\n",
|
||
" <td>Vietnam</td>\n",
|
||
" <td>automotive</td>\n",
|
||
" <td>Automotive</td>\n",
|
||
" <td>60-70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2099</th>\n",
|
||
" <td>2076</td>\n",
|
||
" <td>Mark Dixon</td>\n",
|
||
" <td>1.4</td>\n",
|
||
" <td>United Kingdom</td>\n",
|
||
" <td>office real estate</td>\n",
|
||
" <td>Real Estate</td>\n",
|
||
" <td>60-70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1392</th>\n",
|
||
" <td>1341</td>\n",
|
||
" <td>Yingzhuo Xu</td>\n",
|
||
" <td>2.3</td>\n",
|
||
" <td>China</td>\n",
|
||
" <td>agribusiness</td>\n",
|
||
" <td>Food & Beverage</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>627</th>\n",
|
||
" <td>622</td>\n",
|
||
" <td>Bruce Flatt</td>\n",
|
||
" <td>4.6</td>\n",
|
||
" <td>Canada</td>\n",
|
||
" <td>money management</td>\n",
|
||
" <td>Finance & Investments</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>527</th>\n",
|
||
" <td>523</td>\n",
|
||
" <td>Li Liangbin</td>\n",
|
||
" <td>5.2</td>\n",
|
||
" <td>China</td>\n",
|
||
" <td>lithium</td>\n",
|
||
" <td>Manufacturing</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>84</th>\n",
|
||
" <td>85</td>\n",
|
||
" <td>Theo Albrecht, Jr. & family</td>\n",
|
||
" <td>18.7</td>\n",
|
||
" <td>Germany</td>\n",
|
||
" <td>Aldi, Trader Joe's</td>\n",
|
||
" <td>Fashion & Retail</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>633</th>\n",
|
||
" <td>622</td>\n",
|
||
" <td>Tony Tamer</td>\n",
|
||
" <td>4.6</td>\n",
|
||
" <td>United States</td>\n",
|
||
" <td>private equity</td>\n",
|
||
" <td>Finance & Investments</td>\n",
|
||
" <td>60-70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>922</th>\n",
|
||
" <td>913</td>\n",
|
||
" <td>Bob Gaglardi</td>\n",
|
||
" <td>3.3</td>\n",
|
||
" <td>Canada</td>\n",
|
||
" <td>hotels</td>\n",
|
||
" <td>Real Estate</td>\n",
|
||
" <td>80+</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2178</th>\n",
|
||
" <td>2076</td>\n",
|
||
" <td>Eugene Wu</td>\n",
|
||
" <td>1.4</td>\n",
|
||
" <td>Taiwan</td>\n",
|
||
" <td>finance</td>\n",
|
||
" <td>Finance & Investments</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>415</th>\n",
|
||
" <td>411</td>\n",
|
||
" <td>Leonard Stern</td>\n",
|
||
" <td>6.2</td>\n",
|
||
" <td>United States</td>\n",
|
||
" <td>real estate</td>\n",
|
||
" <td>Real Estate</td>\n",
|
||
" <td>80+</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2080 rows × 7 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Rank Name Networth Country \\\n",
|
||
"1909 1818 Tran Ba Duong & family 1.6 Vietnam \n",
|
||
"2099 2076 Mark Dixon 1.4 United Kingdom \n",
|
||
"1392 1341 Yingzhuo Xu 2.3 China \n",
|
||
"627 622 Bruce Flatt 4.6 Canada \n",
|
||
"527 523 Li Liangbin 5.2 China \n",
|
||
"... ... ... ... ... \n",
|
||
"84 85 Theo Albrecht, Jr. & family 18.7 Germany \n",
|
||
"633 622 Tony Tamer 4.6 United States \n",
|
||
"922 913 Bob Gaglardi 3.3 Canada \n",
|
||
"2178 2076 Eugene Wu 1.4 Taiwan \n",
|
||
"415 411 Leonard Stern 6.2 United States \n",
|
||
"\n",
|
||
" Source Industry Age_category \n",
|
||
"1909 automotive Automotive 60-70 \n",
|
||
"2099 office real estate Real Estate 60-70 \n",
|
||
"1392 agribusiness Food & Beverage 50-60 \n",
|
||
"627 money management Finance & Investments 50-60 \n",
|
||
"527 lithium Manufacturing 50-60 \n",
|
||
"... ... ... ... \n",
|
||
"84 Aldi, Trader Joe's Fashion & Retail 70-80 \n",
|
||
"633 private equity Finance & Investments 60-70 \n",
|
||
"922 hotels Real Estate 80+ \n",
|
||
"2178 finance Finance & Investments 70-80 \n",
|
||
"415 real estate Real Estate 80+ \n",
|
||
"\n",
|
||
"[2080 rows x 7 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_train'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Age_category</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1909</th>\n",
|
||
" <td>60-70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2099</th>\n",
|
||
" <td>60-70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1392</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>627</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>527</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>84</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>633</th>\n",
|
||
" <td>60-70</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>922</th>\n",
|
||
" <td>80+</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2178</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>415</th>\n",
|
||
" <td>80+</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2080 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Age_category\n",
|
||
"1909 60-70\n",
|
||
"2099 60-70\n",
|
||
"1392 50-60\n",
|
||
"627 50-60\n",
|
||
"527 50-60\n",
|
||
"... ...\n",
|
||
"84 70-80\n",
|
||
"633 60-70\n",
|
||
"922 80+\n",
|
||
"2178 70-80\n",
|
||
"415 80+\n",
|
||
"\n",
|
||
"[2080 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'X_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Rank</th>\n",
|
||
" <th>Name</th>\n",
|
||
" <th>Networth</th>\n",
|
||
" <th>Country</th>\n",
|
||
" <th>Source</th>\n",
|
||
" <th>Industry</th>\n",
|
||
" <th>Age_category</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2075</th>\n",
|
||
" <td>2076</td>\n",
|
||
" <td>Radhe Shyam Agarwal</td>\n",
|
||
" <td>1.4</td>\n",
|
||
" <td>India</td>\n",
|
||
" <td>consumer goods</td>\n",
|
||
" <td>Fashion & Retail</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1529</th>\n",
|
||
" <td>1513</td>\n",
|
||
" <td>Robert Duggan</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>United States</td>\n",
|
||
" <td>pharmaceuticals</td>\n",
|
||
" <td>Healthcare</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1803</th>\n",
|
||
" <td>1729</td>\n",
|
||
" <td>Yao Kuizhang</td>\n",
|
||
" <td>1.7</td>\n",
|
||
" <td>China</td>\n",
|
||
" <td>beverages</td>\n",
|
||
" <td>Food & Beverage</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>425</th>\n",
|
||
" <td>424</td>\n",
|
||
" <td>Alexei Kuzmichev</td>\n",
|
||
" <td>6.0</td>\n",
|
||
" <td>Russia</td>\n",
|
||
" <td>oil, banking, telecom</td>\n",
|
||
" <td>Energy</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2597</th>\n",
|
||
" <td>2578</td>\n",
|
||
" <td>Ramesh Genomal</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>Philippines</td>\n",
|
||
" <td>apparel</td>\n",
|
||
" <td>Fashion & Retail</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>935</th>\n",
|
||
" <td>913</td>\n",
|
||
" <td>Alfred Oetker</td>\n",
|
||
" <td>3.3</td>\n",
|
||
" <td>Germany</td>\n",
|
||
" <td>consumer goods</td>\n",
|
||
" <td>Fashion & Retail</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1541</th>\n",
|
||
" <td>1513</td>\n",
|
||
" <td>Thomas Lee</td>\n",
|
||
" <td>2.0</td>\n",
|
||
" <td>United States</td>\n",
|
||
" <td>private equity</td>\n",
|
||
" <td>Finance & Investments</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1646</th>\n",
|
||
" <td>1645</td>\n",
|
||
" <td>Roberto Angelini Rossi</td>\n",
|
||
" <td>1.8</td>\n",
|
||
" <td>Chile</td>\n",
|
||
" <td>forestry, mining</td>\n",
|
||
" <td>diversified</td>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>376</th>\n",
|
||
" <td>375</td>\n",
|
||
" <td>Patrick Drahi</td>\n",
|
||
" <td>6.6</td>\n",
|
||
" <td>France</td>\n",
|
||
" <td>telecom</td>\n",
|
||
" <td>Telecom</td>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1894</th>\n",
|
||
" <td>1818</td>\n",
|
||
" <td>Gerald Schwartz</td>\n",
|
||
" <td>1.6</td>\n",
|
||
" <td>Canada</td>\n",
|
||
" <td>finance</td>\n",
|
||
" <td>Finance & Investments</td>\n",
|
||
" <td>80+</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>520 rows × 7 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Rank Name Networth Country \\\n",
|
||
"2075 2076 Radhe Shyam Agarwal 1.4 India \n",
|
||
"1529 1513 Robert Duggan 2.0 United States \n",
|
||
"1803 1729 Yao Kuizhang 1.7 China \n",
|
||
"425 424 Alexei Kuzmichev 6.0 Russia \n",
|
||
"2597 2578 Ramesh Genomal 1.0 Philippines \n",
|
||
"... ... ... ... ... \n",
|
||
"935 913 Alfred Oetker 3.3 Germany \n",
|
||
"1541 1513 Thomas Lee 2.0 United States \n",
|
||
"1646 1645 Roberto Angelini Rossi 1.8 Chile \n",
|
||
"376 375 Patrick Drahi 6.6 France \n",
|
||
"1894 1818 Gerald Schwartz 1.6 Canada \n",
|
||
"\n",
|
||
" Source Industry Age_category \n",
|
||
"2075 consumer goods Fashion & Retail 70-80 \n",
|
||
"1529 pharmaceuticals Healthcare 70-80 \n",
|
||
"1803 beverages Food & Beverage 50-60 \n",
|
||
"425 oil, banking, telecom Energy 50-60 \n",
|
||
"2597 apparel Fashion & Retail 70-80 \n",
|
||
"... ... ... ... \n",
|
||
"935 consumer goods Fashion & Retail 50-60 \n",
|
||
"1541 private equity Finance & Investments 70-80 \n",
|
||
"1646 forestry, mining diversified 70-80 \n",
|
||
"376 telecom Telecom 50-60 \n",
|
||
"1894 finance Finance & Investments 80+ \n",
|
||
"\n",
|
||
"[520 rows x 7 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"'y_test'"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>Age_category</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>2075</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1529</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1803</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>425</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2597</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>935</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1541</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1646</th>\n",
|
||
" <td>70-80</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>376</th>\n",
|
||
" <td>50-60</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1894</th>\n",
|
||
" <td>80+</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>520 rows × 1 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" Age_category\n",
|
||
"2075 70-80\n",
|
||
"1529 70-80\n",
|
||
"1803 50-60\n",
|
||
"425 50-60\n",
|
||
"2597 70-80\n",
|
||
"... ...\n",
|
||
"935 50-60\n",
|
||
"1541 70-80\n",
|
||
"1646 70-80\n",
|
||
"376 50-60\n",
|
||
"1894 80+\n",
|
||
"\n",
|
||
"[520 rows x 1 columns]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"from utils import split_stratified_into_train_val_test\n",
|
||
"\n",
|
||
"# 80/20 train/test split with an empty validation part, stratified on 'Age_category';\n",
|
||
"# random_state presumably fixes the shuffle so the split is reproducible.\n",
|
||
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
|
||
"    df,\n",
|
||
"    stratify_colname=\"Age_category\",\n",
|
||
"    frac_train=0.80,\n",
|
||
"    frac_val=0,\n",
|
||
"    frac_test=0.20,\n",
|
||
"    random_state=9,\n",
|
||
")\n",
|
||
"\n",
|
||
"# NOTE(review): X_train / X_test still include the 'Age_category' target column (see the output) —\n",
|
||
"# make sure it is excluded before fitting a model.\n",
|
||
"# Show each part under a caption: X then y, for train and then for test.\n",
|
||
"for caption, part in [(\"X_train\", X_train), (\"y_train\", y_train), (\"X_test\", X_test), (\"y_test\", y_test)]:\n",
|
||
"    display(caption, part)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Формирование конвейера для классификации данных\n",
|
||
"## preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений медианой и стандартизация\n",
|
||
"## preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных значений константой \"unknown\" и унитарное (one-hot) кодирование\n",
|
||
"## features_preprocessing -- трансформер для предобработки признаков\n",
|
||
"## features_engineering -- трансформер для конструирования признаков\n",
|
||
"## drop_columns -- трансформер для удаления колонок\n",
|
||
"## pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/html": [
|
||
"<div>\n",
|
||
"<style scoped>\n",
|
||
" .dataframe tbody tr th:only-of-type {\n",
|
||
" vertical-align: middle;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe tbody tr th {\n",
|
||
" vertical-align: top;\n",
|
||
" }\n",
|
||
"\n",
|
||
" .dataframe thead th {\n",
|
||
" text-align: right;\n",
|
||
" }\n",
|
||
"</style>\n",
|
||
"<table border=\"1\" class=\"dataframe\">\n",
|
||
" <thead>\n",
|
||
" <tr style=\"text-align: right;\">\n",
|
||
" <th></th>\n",
|
||
" <th>prepocessing_num__Networth</th>\n",
|
||
" <th>prepocessing_cat__Country_Argentina</th>\n",
|
||
" <th>prepocessing_cat__Country_Australia</th>\n",
|
||
" <th>prepocessing_cat__Country_Austria</th>\n",
|
||
" <th>prepocessing_cat__Country_Barbados</th>\n",
|
||
" <th>prepocessing_cat__Country_Belgium</th>\n",
|
||
" <th>prepocessing_cat__Country_Belize</th>\n",
|
||
" <th>prepocessing_cat__Country_Brazil</th>\n",
|
||
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
|
||
" <th>prepocessing_cat__Country_Canada</th>\n",
|
||
" <th>...</th>\n",
|
||
" <th>prepocessing_cat__Industry_Logistics</th>\n",
|
||
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
|
||
" <th>prepocessing_cat__Industry_Media & Entertainment</th>\n",
|
||
" <th>prepocessing_cat__Industry_Metals & Mining</th>\n",
|
||
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
|
||
" <th>prepocessing_cat__Industry_Service</th>\n",
|
||
" <th>prepocessing_cat__Industry_Sports</th>\n",
|
||
" <th>prepocessing_cat__Industry_Technology</th>\n",
|
||
" <th>prepocessing_cat__Industry_Telecom</th>\n",
|
||
" <th>prepocessing_cat__Industry_diversified</th>\n",
|
||
" </tr>\n",
|
||
" </thead>\n",
|
||
" <tbody>\n",
|
||
" <tr>\n",
|
||
" <th>1909</th>\n",
|
||
" <td>-0.309917</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2099</th>\n",
|
||
" <td>-0.329245</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>1392</th>\n",
|
||
" <td>-0.242268</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>627</th>\n",
|
||
" <td>-0.019995</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>527</th>\n",
|
||
" <td>0.037990</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>...</th>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>...</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>84</th>\n",
|
||
" <td>1.342637</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>633</th>\n",
|
||
" <td>-0.019995</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>922</th>\n",
|
||
" <td>-0.145628</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>2178</th>\n",
|
||
" <td>-0.329245</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" <tr>\n",
|
||
" <th>415</th>\n",
|
||
" <td>0.134630</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>...</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>1.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" <td>0.0</td>\n",
|
||
" </tr>\n",
|
||
" </tbody>\n",
|
||
"</table>\n",
|
||
"<p>2080 rows × 860 columns</p>\n",
|
||
"</div>"
|
||
],
|
||
"text/plain": [
|
||
" prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n",
|
||
"1909 -0.309917 0.0 \n",
|
||
"2099 -0.329245 0.0 \n",
|
||
"1392 -0.242268 0.0 \n",
|
||
"627 -0.019995 0.0 \n",
|
||
"527 0.037990 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 1.342637 0.0 \n",
|
||
"633 -0.019995 0.0 \n",
|
||
"922 -0.145628 0.0 \n",
|
||
"2178 -0.329245 0.0 \n",
|
||
"415 0.134630 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
|
||
"1909 0.0 0.0 \n",
|
||
"2099 0.0 0.0 \n",
|
||
"1392 0.0 0.0 \n",
|
||
"627 0.0 0.0 \n",
|
||
"527 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 0.0 0.0 \n",
|
||
"633 0.0 0.0 \n",
|
||
"922 0.0 0.0 \n",
|
||
"2178 0.0 0.0 \n",
|
||
"415 0.0 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
|
||
"1909 0.0 0.0 \n",
|
||
"2099 0.0 0.0 \n",
|
||
"1392 0.0 0.0 \n",
|
||
"627 0.0 0.0 \n",
|
||
"527 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 0.0 0.0 \n",
|
||
"633 0.0 0.0 \n",
|
||
"922 0.0 0.0 \n",
|
||
"2178 0.0 0.0 \n",
|
||
"415 0.0 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
|
||
"1909 0.0 0.0 \n",
|
||
"2099 0.0 0.0 \n",
|
||
"1392 0.0 0.0 \n",
|
||
"627 0.0 0.0 \n",
|
||
"527 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 0.0 0.0 \n",
|
||
"633 0.0 0.0 \n",
|
||
"922 0.0 0.0 \n",
|
||
"2178 0.0 0.0 \n",
|
||
"415 0.0 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
|
||
"1909 0.0 0.0 \n",
|
||
"2099 0.0 0.0 \n",
|
||
"1392 0.0 0.0 \n",
|
||
"627 0.0 1.0 \n",
|
||
"527 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 0.0 0.0 \n",
|
||
"633 0.0 0.0 \n",
|
||
"922 0.0 1.0 \n",
|
||
"2178 0.0 0.0 \n",
|
||
"415 0.0 0.0 \n",
|
||
"\n",
|
||
" ... prepocessing_cat__Industry_Logistics \\\n",
|
||
"1909 ... 0.0 \n",
|
||
"2099 ... 0.0 \n",
|
||
"1392 ... 0.0 \n",
|
||
"627 ... 0.0 \n",
|
||
"527 ... 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 ... 0.0 \n",
|
||
"633 ... 0.0 \n",
|
||
"922 ... 0.0 \n",
|
||
"2178 ... 0.0 \n",
|
||
"415 ... 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Manufacturing \\\n",
|
||
"1909 0.0 \n",
|
||
"2099 0.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 1.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 0.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Media & Entertainment \\\n",
|
||
"1909 0.0 \n",
|
||
"2099 0.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 0.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 0.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Metals & Mining \\\n",
|
||
"1909 0.0 \n",
|
||
"2099 0.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 0.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 0.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Real Estate \\\n",
|
||
"1909 0.0 \n",
|
||
"2099 1.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 0.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 1.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 1.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
|
||
"1909 0.0 0.0 \n",
|
||
"2099 0.0 0.0 \n",
|
||
"1392 0.0 0.0 \n",
|
||
"627 0.0 0.0 \n",
|
||
"527 0.0 0.0 \n",
|
||
"... ... ... \n",
|
||
"84 0.0 0.0 \n",
|
||
"633 0.0 0.0 \n",
|
||
"922 0.0 0.0 \n",
|
||
"2178 0.0 0.0 \n",
|
||
"415 0.0 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Technology \\\n",
|
||
"1909 0.0 \n",
|
||
"2099 0.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 0.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 0.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_Telecom \\\n",
|
||
"1909 0.0 \n",
|
||
"2099 0.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 0.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 0.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 0.0 \n",
|
||
"\n",
|
||
" prepocessing_cat__Industry_diversified \n",
|
||
"1909 0.0 \n",
|
||
"2099 0.0 \n",
|
||
"1392 0.0 \n",
|
||
"627 0.0 \n",
|
||
"527 0.0 \n",
|
||
"... ... \n",
|
||
"84 0.0 \n",
|
||
"633 0.0 \n",
|
||
"922 0.0 \n",
|
||
"2178 0.0 \n",
|
||
"415 0.0 \n",
|
||
"\n",
|
||
"[2080 rows x 860 columns]"
|
||
]
|
||
},
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"import pandas as pd\n",
|
||
"\n",
|
||
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
|
||
"columns_to_drop = [\"Age_category\", \"Rank \", \"Name\"]\n",
|
||
"\n",
|
||
"num_columns = [\n",
|
||
" column\n",
|
||
" for column in X_train.columns\n",
|
||
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
|
||
"]\n",
|
||
"cat_columns = [\n",
|
||
" column\n",
|
||
" for column in X_train.columns\n",
|
||
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
|
||
"]\n",
|
||
"\n",
|
||
"# Предобработка числовых данных\n",
|
||
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
||
"num_scaler = StandardScaler()\n",
|
||
"preprocessing_num = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", num_imputer),\n",
|
||
" (\"scaler\", num_scaler),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"# Предобработка категориальных данных\n",
|
||
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
||
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
||
"preprocessing_cat = Pipeline(\n",
|
||
" [\n",
|
||
" (\"imputer\", cat_imputer),\n",
|
||
" (\"encoder\", cat_encoder),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"# Общая предобработка признаков\n",
|
||
"features_preprocessing = ColumnTransformer(\n",
|
||
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
|
||
" transformers=[\n",
|
||
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
||
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
||
" ],\n",
|
||
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
|
||
")\n",
|
||
"\n",
|
||
"# Итоговый конвейер\n",
|
||
"pipeline_end = Pipeline(\n",
|
||
" [\n",
|
||
" (\"features_preprocessing\", features_preprocessing),\n",
|
||
" ]\n",
|
||
")\n",
|
||
"\n",
|
||
"# Преобразуем данные\n",
|
||
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
||
"\n",
|
||
"# Создаем DataFrame с правильными именами колонок\n",
|
||
"preprocessed_df = pd.DataFrame(\n",
|
||
" preprocessing_result,\n",
|
||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||
" index=X_train.index, # Сохраняем индексы\n",
|
||
")\n",
|
||
"\n",
|
||
"preprocessed_df"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Формирование набора моделей для классификации\n",
|
||
"## logistic -- логистическая регрессия\n",
|
||
"## ridge -- гребневая регрессия\n",
|
||
"## decision_tree -- дерево решений\n",
|
||
"## knn -- k-ближайших соседей\n",
|
||
"## naive_bayes -- наивный Байесовский классификатор\n",
|
||
"## gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
|
||
"## random_forest -- метод случайного леса (набор деревьев решений)\n",
|
||
"## mlp -- многослойный персептрон (нейронная сеть)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
|
||
"\n",
|
||
"class_models = {\n",
|
||
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
|
||
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
|
||
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
|
||
" \"decision_tree\": {\n",
|
||
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
|
||
" },\n",
|
||
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
|
||
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
|
||
" \"gradient_boosting\": {\n",
|
||
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
|
||
" },\n",
|
||
" \"random_forest\": {\n",
|
||
" \"model\": ensemble.RandomForestClassifier(\n",
|
||
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
|
||
" )\n",
|
||
" },\n",
|
||
" \"mlp\": {\n",
|
||
" \"model\": neural_network.MLPClassifier(\n",
|
||
" hidden_layer_sizes=(7,),\n",
|
||
" max_iter=500,\n",
|
||
" early_stopping=True,\n",
|
||
" random_state=9,\n",
|
||
" )\n",
|
||
" },\n",
|
||
"}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Обучение моделей на обучающем наборе данных и оценка на тестовом"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: logistic\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
|
||
" warnings.warn(\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "ValueError",
|
||
"evalue": "Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[1;32mIn[40], line 19\u001b[0m\n\u001b[0;32m 16\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprobs\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m y_test_probs\n\u001b[0;32m 17\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpreds\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m y_test_predict\n\u001b[1;32m---> 19\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision_train\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mmetrics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_score\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train_predict\u001b[49m\n\u001b[0;32m 21\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 22\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision_test\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m metrics\u001b[38;5;241m.\u001b[39mprecision_score(\n\u001b[0;32m 23\u001b[0m y_test, y_test_predict\n\u001b[0;32m 24\u001b[0m )\n\u001b[0;32m 25\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRecall_train\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m metrics\u001b[38;5;241m.\u001b[39mrecall_score(\n\u001b[0;32m 26\u001b[0m y_train, y_train_predict\n\u001b[0;32m 27\u001b[0m )\n",
|
||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 208\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 209\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 210\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 211\u001b[0m )\n\u001b[0;32m 212\u001b[0m ):\n\u001b[1;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m 215\u001b[0m \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[0;32m 216\u001b[0m \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[0;32m 217\u001b[0m \u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[0;32m 219\u001b[0m msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[0;32m 220\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 221\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 222\u001b[0m \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[0;32m 223\u001b[0m )\n",
|
||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:2204\u001b[0m, in \u001b[0;36mprecision_score\u001b[1;34m(y_true, y_pred, labels, pos_label, average, sample_weight, zero_division)\u001b[0m\n\u001b[0;32m 2037\u001b[0m \u001b[38;5;129m@validate_params\u001b[39m(\n\u001b[0;32m 2038\u001b[0m {\n\u001b[0;32m 2039\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_true\u001b[39m\u001b[38;5;124m\"\u001b[39m: [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124marray-like\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msparse matrix\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2064\u001b[0m zero_division\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwarn\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 2065\u001b[0m ):\n\u001b[0;32m 2066\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute the precision.\u001b[39;00m\n\u001b[0;32m 2067\u001b[0m \n\u001b[0;32m 2068\u001b[0m \u001b[38;5;124;03m The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2202\u001b[0m \u001b[38;5;124;03m array([0.5, 1. , 1. 
])\u001b[39;00m\n\u001b[0;32m 2203\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m-> 2204\u001b[0m p, _, _, _ \u001b[38;5;241m=\u001b[39m \u001b[43mprecision_recall_fscore_support\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 2205\u001b[0m \u001b[43m \u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2206\u001b[0m \u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2207\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2208\u001b[0m \u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpos_label\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2209\u001b[0m \u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2210\u001b[0m \u001b[43m \u001b[49m\u001b[43mwarn_for\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mprecision\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2211\u001b[0m \u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2212\u001b[0m \u001b[43m \u001b[49m\u001b[43mzero_division\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mzero_division\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2213\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2214\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m p\n",
|
||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py:186\u001b[0m, in \u001b[0;36mvalidate_params.<locals>.decorator.<locals>.wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 184\u001b[0m global_skip_validation \u001b[38;5;241m=\u001b[39m get_config()[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mskip_parameter_validation\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m global_skip_validation:\n\u001b[1;32m--> 186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 188\u001b[0m func_sig \u001b[38;5;241m=\u001b[39m signature(func)\n\u001b[0;32m 190\u001b[0m \u001b[38;5;66;03m# Map *args/**kwargs to the function signature\u001b[39;00m\n",
|
||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1789\u001b[0m, in \u001b[0;36mprecision_recall_fscore_support\u001b[1;34m(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight, zero_division)\u001b[0m\n\u001b[0;32m 1626\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Compute precision, recall, F-measure and support for each class.\u001b[39;00m\n\u001b[0;32m 1627\u001b[0m \n\u001b[0;32m 1628\u001b[0m \u001b[38;5;124;03mThe precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1786\u001b[0m \u001b[38;5;124;03m array([2, 2, 2]))\u001b[39;00m\n\u001b[0;32m 1787\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 1788\u001b[0m _check_zero_division(zero_division)\n\u001b[1;32m-> 1789\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[43m_check_set_wise_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1791\u001b[0m \u001b[38;5;66;03m# Calculate tp_sum, pred_sum, true_sum ###\u001b[39;00m\n\u001b[0;32m 1792\u001b[0m samplewise \u001b[38;5;241m=\u001b[39m average \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
|
||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1578\u001b[0m, in \u001b[0;36m_check_set_wise_labels\u001b[1;34m(y_true, y_pred, average, labels, pos_label)\u001b[0m\n\u001b[0;32m 1576\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m y_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmulticlass\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m 1577\u001b[0m average_options\u001b[38;5;241m.\u001b[39mremove(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m-> 1578\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1579\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTarget is \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m but average=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. Please \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1580\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mchoose another average setting, one of \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (y_type, average_options)\n\u001b[0;32m 1581\u001b[0m )\n\u001b[0;32m 1582\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m pos_label \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m1\u001b[39m):\n\u001b[0;32m 1583\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[0;32m 1584\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNote that pos_label (set to \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) is ignored when \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1585\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maverage != \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m 
(got \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m). You may use \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1588\u001b[0m \u001b[38;5;167;01mUserWarning\u001b[39;00m,\n\u001b[0;32m 1589\u001b[0m )\n",
|
||
"\u001b[1;31mValueError\u001b[0m: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted']."
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from sklearn import metrics\n",
|
||
"\n",
|
||
"for model_name in class_models.keys():\n",
|
||
" print(f\"Model: {model_name}\")\n",
|
||
" model = class_models[model_name][\"model\"]\n",
|
||
"\n",
|
||
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
|
||
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
|
||
"\n",
|
||
" y_train_predict = model_pipeline.predict(X_train)\n",
|
||
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
|
||
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
|
||
"\n",
|
||
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
|
||
" class_models[model_name][\"probs\"] = y_test_probs\n",
|
||
" class_models[model_name][\"preds\"] = y_test_predict\n",
|
||
"\n",
|
||
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
|
||
" y_train, y_train_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
|
||
" y_test, y_test_probs\n",
|
||
" )\n",
|
||
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
|
||
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
|
||
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )\n",
|
||
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
|
||
" y_test, y_test_predict\n",
|
||
" )"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": ".venv",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.12.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2
|
||
}
|