{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Rank ', 'Name', 'Networth', 'Age', 'Country', 'Source', 'Industry'], dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n",
"print(df.columns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Определим бизнес цели:\n",
"## 1- Прогнозирование возраста миллиардера(классификация)\n",
"## 2- Прогнозирование состояния миллиардера(регрессия)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Подготовим данные: категоризируем колонку age"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Rank 0\n",
"Name 0\n",
"Networth 0\n",
"Age 0\n",
"Country 0\n",
"Source 0\n",
"Industry 0\n",
"dtype: int64\n",
"\n",
"Rank False\n",
"Name False\n",
"Networth False\n",
"Age False\n",
"Country False\n",
"Source False\n",
"Industry False\n",
"dtype: bool\n",
"\n"
]
}
],
"source": [
"print(df.isnull().sum())\n",
"\n",
"print()\n",
"\n",
"# Есть ли пустые значения признаков\n",
"print(df.isnull().any())\n",
"\n",
"print()\n",
"\n",
"# Процент пустых значений признаков\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Rank Name Networth Country \\\n",
"0 1 Elon Musk 219.0 United States \n",
"1 2 Jeff Bezos 171.0 United States \n",
"2 3 Bernard Arnault & family 158.0 France \n",
"3 4 Bill Gates 129.0 United States \n",
"4 5 Warren Buffett 118.0 United States \n",
"\n",
" Source Industry Age_category \n",
"0 Tesla, SpaceX Automotive 50-60 \n",
"1 Amazon Technology 50-60 \n",
"2 LVMH Fashion & Retail 70-80 \n",
"3 Microsoft Technology 60-70 \n",
"4 Berkshire Hathaway Finance & Investments 80+ \n"
]
}
],
"source": [
"\n",
"\n",
"bins = [0, 30, 40, 50, 60, 70, 80, 101] # границы для возрастных категорий\n",
"labels = ['Under 30', '30-40', '40-50', '50-60', '60-70', '70-80', '80+'] # метки для категорий\n",
"\n",
"df[\"Age_category\"] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)\n",
"# Удаляем оригинальные колонки 'country', 'industry' и 'source' из исходного DataFrame\n",
"df.drop(columns=['Age'], inplace=True)\n",
"\n",
"# Просмотр результата\n",
"print(df.head())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Rank | \n",
" Name | \n",
" Networth | \n",
" Country | \n",
" Source | \n",
" Industry | \n",
" Age_category | \n",
"
\n",
" \n",
" \n",
" \n",
" 1909 | \n",
" 1818 | \n",
" Tran Ba Duong & family | \n",
" 1.6 | \n",
" Vietnam | \n",
" automotive | \n",
" Automotive | \n",
" 60-70 | \n",
"
\n",
" \n",
" 2099 | \n",
" 2076 | \n",
" Mark Dixon | \n",
" 1.4 | \n",
" United Kingdom | \n",
" office real estate | \n",
" Real Estate | \n",
" 60-70 | \n",
"
\n",
" \n",
" 1392 | \n",
" 1341 | \n",
" Yingzhuo Xu | \n",
" 2.3 | \n",
" China | \n",
" agribusiness | \n",
" Food & Beverage | \n",
" 50-60 | \n",
"
\n",
" \n",
" 627 | \n",
" 622 | \n",
" Bruce Flatt | \n",
" 4.6 | \n",
" Canada | \n",
" money management | \n",
" Finance & Investments | \n",
" 50-60 | \n",
"
\n",
" \n",
" 527 | \n",
" 523 | \n",
" Li Liangbin | \n",
" 5.2 | \n",
" China | \n",
" lithium | \n",
" Manufacturing | \n",
" 50-60 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 84 | \n",
" 85 | \n",
" Theo Albrecht, Jr. & family | \n",
" 18.7 | \n",
" Germany | \n",
" Aldi, Trader Joe's | \n",
" Fashion & Retail | \n",
" 70-80 | \n",
"
\n",
" \n",
" 633 | \n",
" 622 | \n",
" Tony Tamer | \n",
" 4.6 | \n",
" United States | \n",
" private equity | \n",
" Finance & Investments | \n",
" 60-70 | \n",
"
\n",
" \n",
" 922 | \n",
" 913 | \n",
" Bob Gaglardi | \n",
" 3.3 | \n",
" Canada | \n",
" hotels | \n",
" Real Estate | \n",
" 80+ | \n",
"
\n",
" \n",
" 2178 | \n",
" 2076 | \n",
" Eugene Wu | \n",
" 1.4 | \n",
" Taiwan | \n",
" finance | \n",
" Finance & Investments | \n",
" 70-80 | \n",
"
\n",
" \n",
" 415 | \n",
" 411 | \n",
" Leonard Stern | \n",
" 6.2 | \n",
" United States | \n",
" real estate | \n",
" Real Estate | \n",
" 80+ | \n",
"
\n",
" \n",
"
\n",
"
2080 rows × 7 columns
\n",
"
"
],
"text/plain": [
" Rank Name Networth Country \\\n",
"1909 1818 Tran Ba Duong & family 1.6 Vietnam \n",
"2099 2076 Mark Dixon 1.4 United Kingdom \n",
"1392 1341 Yingzhuo Xu 2.3 China \n",
"627 622 Bruce Flatt 4.6 Canada \n",
"527 523 Li Liangbin 5.2 China \n",
"... ... ... ... ... \n",
"84 85 Theo Albrecht, Jr. & family 18.7 Germany \n",
"633 622 Tony Tamer 4.6 United States \n",
"922 913 Bob Gaglardi 3.3 Canada \n",
"2178 2076 Eugene Wu 1.4 Taiwan \n",
"415 411 Leonard Stern 6.2 United States \n",
"\n",
" Source Industry Age_category \n",
"1909 automotive Automotive 60-70 \n",
"2099 office real estate Real Estate 60-70 \n",
"1392 agribusiness Food & Beverage 50-60 \n",
"627 money management Finance & Investments 50-60 \n",
"527 lithium Manufacturing 50-60 \n",
"... ... ... ... \n",
"84 Aldi, Trader Joe's Fashion & Retail 70-80 \n",
"633 private equity Finance & Investments 60-70 \n",
"922 hotels Real Estate 80+ \n",
"2178 finance Finance & Investments 70-80 \n",
"415 real estate Real Estate 80+ \n",
"\n",
"[2080 rows x 7 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Age_category | \n",
"
\n",
" \n",
" \n",
" \n",
" 1909 | \n",
" 60-70 | \n",
"
\n",
" \n",
" 2099 | \n",
" 60-70 | \n",
"
\n",
" \n",
" 1392 | \n",
" 50-60 | \n",
"
\n",
" \n",
" 627 | \n",
" 50-60 | \n",
"
\n",
" \n",
" 527 | \n",
" 50-60 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 84 | \n",
" 70-80 | \n",
"
\n",
" \n",
" 633 | \n",
" 60-70 | \n",
"
\n",
" \n",
" 922 | \n",
" 80+ | \n",
"
\n",
" \n",
" 2178 | \n",
" 70-80 | \n",
"
\n",
" \n",
" 415 | \n",
" 80+ | \n",
"
\n",
" \n",
"
\n",
"
2080 rows × 1 columns
\n",
"
"
],
"text/plain": [
" Age_category\n",
"1909 60-70\n",
"2099 60-70\n",
"1392 50-60\n",
"627 50-60\n",
"527 50-60\n",
"... ...\n",
"84 70-80\n",
"633 60-70\n",
"922 80+\n",
"2178 70-80\n",
"415 80+\n",
"\n",
"[2080 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Rank | \n",
" Name | \n",
" Networth | \n",
" Country | \n",
" Source | \n",
" Industry | \n",
" Age_category | \n",
"
\n",
" \n",
" \n",
" \n",
" 2075 | \n",
" 2076 | \n",
" Radhe Shyam Agarwal | \n",
" 1.4 | \n",
" India | \n",
" consumer goods | \n",
" Fashion & Retail | \n",
" 70-80 | \n",
"
\n",
" \n",
" 1529 | \n",
" 1513 | \n",
" Robert Duggan | \n",
" 2.0 | \n",
" United States | \n",
" pharmaceuticals | \n",
" Healthcare | \n",
" 70-80 | \n",
"
\n",
" \n",
" 1803 | \n",
" 1729 | \n",
" Yao Kuizhang | \n",
" 1.7 | \n",
" China | \n",
" beverages | \n",
" Food & Beverage | \n",
" 50-60 | \n",
"
\n",
" \n",
" 425 | \n",
" 424 | \n",
" Alexei Kuzmichev | \n",
" 6.0 | \n",
" Russia | \n",
" oil, banking, telecom | \n",
" Energy | \n",
" 50-60 | \n",
"
\n",
" \n",
" 2597 | \n",
" 2578 | \n",
" Ramesh Genomal | \n",
" 1.0 | \n",
" Philippines | \n",
" apparel | \n",
" Fashion & Retail | \n",
" 70-80 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 935 | \n",
" 913 | \n",
" Alfred Oetker | \n",
" 3.3 | \n",
" Germany | \n",
" consumer goods | \n",
" Fashion & Retail | \n",
" 50-60 | \n",
"
\n",
" \n",
" 1541 | \n",
" 1513 | \n",
" Thomas Lee | \n",
" 2.0 | \n",
" United States | \n",
" private equity | \n",
" Finance & Investments | \n",
" 70-80 | \n",
"
\n",
" \n",
" 1646 | \n",
" 1645 | \n",
" Roberto Angelini Rossi | \n",
" 1.8 | \n",
" Chile | \n",
" forestry, mining | \n",
" diversified | \n",
" 70-80 | \n",
"
\n",
" \n",
" 376 | \n",
" 375 | \n",
" Patrick Drahi | \n",
" 6.6 | \n",
" France | \n",
" telecom | \n",
" Telecom | \n",
" 50-60 | \n",
"
\n",
" \n",
" 1894 | \n",
" 1818 | \n",
" Gerald Schwartz | \n",
" 1.6 | \n",
" Canada | \n",
" finance | \n",
" Finance & Investments | \n",
" 80+ | \n",
"
\n",
" \n",
"
\n",
"
520 rows × 7 columns
\n",
"
"
],
"text/plain": [
" Rank Name Networth Country \\\n",
"2075 2076 Radhe Shyam Agarwal 1.4 India \n",
"1529 1513 Robert Duggan 2.0 United States \n",
"1803 1729 Yao Kuizhang 1.7 China \n",
"425 424 Alexei Kuzmichev 6.0 Russia \n",
"2597 2578 Ramesh Genomal 1.0 Philippines \n",
"... ... ... ... ... \n",
"935 913 Alfred Oetker 3.3 Germany \n",
"1541 1513 Thomas Lee 2.0 United States \n",
"1646 1645 Roberto Angelini Rossi 1.8 Chile \n",
"376 375 Patrick Drahi 6.6 France \n",
"1894 1818 Gerald Schwartz 1.6 Canada \n",
"\n",
" Source Industry Age_category \n",
"2075 consumer goods Fashion & Retail 70-80 \n",
"1529 pharmaceuticals Healthcare 70-80 \n",
"1803 beverages Food & Beverage 50-60 \n",
"425 oil, banking, telecom Energy 50-60 \n",
"2597 apparel Fashion & Retail 70-80 \n",
"... ... ... ... \n",
"935 consumer goods Fashion & Retail 50-60 \n",
"1541 private equity Finance & Investments 70-80 \n",
"1646 forestry, mining diversified 70-80 \n",
"376 telecom Telecom 50-60 \n",
"1894 finance Finance & Investments 80+ \n",
"\n",
"[520 rows x 7 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Age_category | \n",
"
\n",
" \n",
" \n",
" \n",
" 2075 | \n",
" 70-80 | \n",
"
\n",
" \n",
" 1529 | \n",
" 70-80 | \n",
"
\n",
" \n",
" 1803 | \n",
" 50-60 | \n",
"
\n",
" \n",
" 425 | \n",
" 50-60 | \n",
"
\n",
" \n",
" 2597 | \n",
" 70-80 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 935 | \n",
" 50-60 | \n",
"
\n",
" \n",
" 1541 | \n",
" 70-80 | \n",
"
\n",
" \n",
" 1646 | \n",
" 70-80 | \n",
"
\n",
" \n",
" 376 | \n",
" 50-60 | \n",
"
\n",
" \n",
" 1894 | \n",
" 80+ | \n",
"
\n",
" \n",
"
\n",
"
520 rows × 1 columns
\n",
"
"
],
"text/plain": [
" Age_category\n",
"2075 70-80\n",
"1529 70-80\n",
"1803 50-60\n",
"425 50-60\n",
"2597 70-80\n",
"... ...\n",
"935 50-60\n",
"1541 70-80\n",
"1646 70-80\n",
"376 50-60\n",
"1894 80+\n",
"\n",
"[520 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from utils import split_stratified_into_train_val_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Age_category\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Формирование конвейера для классификации данных\n",
"## preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"## preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"## features_preprocessing -- трансформер для предобработки признаков\n",
"## features_engineering -- трансформер для конструирования признаков\n",
"## drop_columns -- трансформер для удаления колонок\n",
"## pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" prepocessing_num__Networth | \n",
" prepocessing_cat__Country_Argentina | \n",
" prepocessing_cat__Country_Australia | \n",
" prepocessing_cat__Country_Austria | \n",
" prepocessing_cat__Country_Barbados | \n",
" prepocessing_cat__Country_Belgium | \n",
" prepocessing_cat__Country_Belize | \n",
" prepocessing_cat__Country_Brazil | \n",
" prepocessing_cat__Country_Bulgaria | \n",
" prepocessing_cat__Country_Canada | \n",
" ... | \n",
" prepocessing_cat__Industry_Logistics | \n",
" prepocessing_cat__Industry_Manufacturing | \n",
" prepocessing_cat__Industry_Media & Entertainment | \n",
" prepocessing_cat__Industry_Metals & Mining | \n",
" prepocessing_cat__Industry_Real Estate | \n",
" prepocessing_cat__Industry_Service | \n",
" prepocessing_cat__Industry_Sports | \n",
" prepocessing_cat__Industry_Technology | \n",
" prepocessing_cat__Industry_Telecom | \n",
" prepocessing_cat__Industry_diversified | \n",
"
\n",
" \n",
" \n",
" \n",
" 1909 | \n",
" -0.309917 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2099 | \n",
" -0.329245 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1392 | \n",
" -0.242268 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 627 | \n",
" -0.019995 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 527 | \n",
" 0.037990 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 84 | \n",
" 1.342637 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 633 | \n",
" -0.019995 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 922 | \n",
" -0.145628 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2178 | \n",
" -0.329245 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 415 | \n",
" 0.134630 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
2080 rows × 860 columns
\n",
"
"
],
"text/plain": [
" prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n",
"1909 -0.309917 0.0 \n",
"2099 -0.329245 0.0 \n",
"1392 -0.242268 0.0 \n",
"627 -0.019995 0.0 \n",
"527 0.037990 0.0 \n",
"... ... ... \n",
"84 1.342637 0.0 \n",
"633 -0.019995 0.0 \n",
"922 -0.145628 0.0 \n",
"2178 -0.329245 0.0 \n",
"415 0.134630 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 1.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 1.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"1909 ... 0.0 \n",
"2099 ... 0.0 \n",
"1392 ... 0.0 \n",
"627 ... 0.0 \n",
"527 ... 0.0 \n",
"... ... ... \n",
"84 ... 0.0 \n",
"633 ... 0.0 \n",
"922 ... 0.0 \n",
"2178 ... 0.0 \n",
"415 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 1.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"1909 0.0 \n",
"2099 1.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 1.0 \n",
"2178 0.0 \n",
"415 1.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"1909 0.0 0.0 \n",
"2099 0.0 0.0 \n",
"1392 0.0 0.0 \n",
"627 0.0 0.0 \n",
"527 0.0 0.0 \n",
"... ... ... \n",
"84 0.0 0.0 \n",
"633 0.0 0.0 \n",
"922 0.0 0.0 \n",
"2178 0.0 0.0 \n",
"415 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"1909 0.0 \n",
"2099 0.0 \n",
"1392 0.0 \n",
"627 0.0 \n",
"527 0.0 \n",
"... ... \n",
"84 0.0 \n",
"633 0.0 \n",
"922 0.0 \n",
"2178 0.0 \n",
"415 0.0 \n",
"\n",
"[2080 rows x 860 columns]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import pandas as pd\n",
"\n",
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
"columns_to_drop = [\"Age_category\", \"Rank \", \"Name\"]\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Предобработка числовых данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Предобработка категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Общая предобработка признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Формирование набора моделей для классификации\n",
"## logistic -- логистическая регрессия\n",
"## ridge -- гребневая регрессия\n",
"## decision_tree -- дерево решений\n",
"## knn -- k-ближайших соседей\n",
"## naive_bayes -- наивный Байесовский классификатор\n",
"## gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"## random_forest -- метод случайного леса (набор деревьев решений)\n",
"## mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Обучение моделей на обучающем наборе данных и оценка на тестовом"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"y_train['Age_category'] = y_train['Age_category'].cat.codes\n",
"y_test['Age_category'] = y_test['Age_category'].cat.codes"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\neighbors\\_classification.py:238: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return self._fit(X, y)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: decision_tree\n",
"Model: knn\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\validation.py:1339: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: naive_bayes\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_label.py:114: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: gradient_boosting\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: random_forest\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\neural_network\\_multilayer_perceptron.py:1105: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel().\n",
" y = column_or_1d(y, warn=True)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"# Fit every candidate classifier on the training split and record its\n",
"# predictions and quality metrics in class_models[model_name].\n",
"\n",
"# The targets are one-column DataFrames, while estimators and metric\n",
"# functions expect 1-D arrays; flattening them once removes the\n",
"# DataConversionWarning raised during fit.\n",
"y_train_flat = y_train.values.ravel()\n",
"y_test_flat = y_test.values.ravel()\n",
"\n",
"for model_name, model_info in class_models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"    model = model_info[\"model\"]\n",
"\n",
"    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
"    model_pipeline = model_pipeline.fit(X_train, y_train_flat)\n",
"\n",
"    y_train_predict = model_pipeline.predict(X_train)\n",
"    y_test_probs = model_pipeline.predict_proba(X_test)\n",
"    # argmax returns a column index of the probability matrix; it is used\n",
"    # directly as the class label, which presumably relies on the labels\n",
"    # being encoded as 0..K-1 -- TODO confirm against the encoding cell.\n",
"    y_test_predict = np.argmax(y_test_probs, axis=1)\n",
"\n",
"    # A single update keeps the original key order of the per-model dict.\n",
"    # zero_division=0 gives the same precision value as the default but\n",
"    # without an UndefinedMetricWarning for classes that are never predicted.\n",
"    model_info.update(\n",
"        {\n",
"            \"pipeline\": model_pipeline,\n",
"            \"probs\": y_test_probs,\n",
"            \"preds\": y_test_predict,\n",
"            \"Precision_train\": metrics.precision_score(\n",
"                y_train_flat, y_train_predict, average=\"macro\", zero_division=0\n",
"            ),\n",
"            \"Precision_test\": metrics.precision_score(\n",
"                y_test_flat, y_test_predict, average=\"macro\", zero_division=0\n",
"            ),\n",
"            \"Recall_train\": metrics.recall_score(\n",
"                y_train_flat, y_train_predict, average=\"macro\"\n",
"            ),\n",
"            \"Recall_test\": metrics.recall_score(\n",
"                y_test_flat, y_test_predict, average=\"macro\"\n",
"            ),\n",
"            \"Accuracy_train\": metrics.accuracy_score(y_train_flat, y_train_predict),\n",
"            \"Accuracy_test\": metrics.accuracy_score(y_test_flat, y_test_predict),\n",
"            \"ROC_AUC_test\": metrics.roc_auc_score(\n",
"                y_test_flat, y_test_probs, multi_class=\"ovr\"\n",
"            ),\n",
"            \"F1_train\": metrics.f1_score(\n",
"                y_train_flat, y_train_predict, average=\"macro\"\n",
"            ),\n",
"            \"F1_test\": metrics.f1_score(\n",
"                y_test_flat, y_test_predict, average=\"macro\"\n",
"            ),\n",
"            \"MCC_test\": metrics.matthews_corrcoef(y_test_flat, y_test_predict),\n",
"            \"Cohen_kappa_test\": metrics.cohen_kappa_score(\n",
"                y_test_flat, y_test_predict\n",
"            ),\n",
"            \"Confusion_matrix\": metrics.confusion_matrix(\n",
"                y_test_flat, y_test_predict\n",
"            ),\n",
"        }\n",
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Сводная таблица оценок качества для использованных моделей классификации"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# One confusion matrix per model, laid out on a two-column grid.\n",
"class_labels = [\"Under 30\", \"30-40\", \"40-50\", \"50-60\", \"60-70\", \"70-80\", \"80+\"]\n",
"n_cols = 2\n",
"# Round the row count up so an odd number of models still gets enough axes\n",
"# (plain integer division would also give zero rows for a single model).\n",
"n_rows = (len(class_models) + n_cols - 1) // n_cols\n",
"# squeeze=False keeps `ax` two-dimensional even when there is only one row.\n",
"_, ax = plt.subplots(\n",
"    n_rows, n_cols, figsize=(17, 17), sharex=False, sharey=False, squeeze=False\n",
")\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=class_labels\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide grid cells that have no model to show.\n",
"for unused_ax in ax.flat[len(class_models):]:\n",
"    unused_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" Precision_train | \n",
" Precision_test | \n",
" Recall_train | \n",
" Recall_test | \n",
" Accuracy_train | \n",
" Accuracy_test | \n",
" F1_train | \n",
" F1_test | \n",
"
\n",
" \n",
" \n",
" \n",
" logistic | \n",
" 0.567471 | \n",
" 0.278074 | \n",
" 0.444883 | \n",
" 0.232769 | \n",
" 0.618269 | \n",
" 0.353846 | \n",
" 0.465219 | \n",
" 0.237759 | \n",
"
\n",
" \n",
" gradient_boosting | \n",
" 0.836061 | \n",
" 0.287405 | \n",
" 0.725411 | \n",
" 0.235795 | \n",
" 0.689904 | \n",
" 0.344231 | \n",
" 0.760847 | \n",
" 0.240251 | \n",
"
\n",
" \n",
" knn | \n",
" 0.477783 | \n",
" 0.221788 | \n",
" 0.460090 | \n",
" 0.214239 | \n",
" 0.497115 | \n",
" 0.328846 | \n",
" 0.456182 | \n",
" 0.211556 | \n",
"
\n",
" \n",
" decision_tree | \n",
" 0.618281 | \n",
" 0.163157 | \n",
" 0.244223 | \n",
" 0.184231 | \n",
" 0.387981 | \n",
" 0.325000 | \n",
" 0.227570 | \n",
" 0.146479 | \n",
"
\n",
" \n",
" random_forest | \n",
" 0.581578 | \n",
" 0.236539 | \n",
" 0.735419 | \n",
" 0.246556 | \n",
" 0.627404 | \n",
" 0.288462 | \n",
" 0.599765 | \n",
" 0.231541 | \n",
"
\n",
" \n",
" ridge | \n",
" 0.518033 | \n",
" 0.238462 | \n",
" 0.695673 | \n",
" 0.247678 | \n",
" 0.556250 | \n",
" 0.284615 | \n",
" 0.553233 | \n",
" 0.226955 | \n",
"
\n",
" \n",
" mlp | \n",
" 0.035714 | \n",
" 0.035714 | \n",
" 0.142857 | \n",
" 0.142857 | \n",
" 0.250000 | \n",
" 0.250000 | \n",
" 0.057143 | \n",
" 0.057143 | \n",
"
\n",
" \n",
" naive_bayes | \n",
" 0.524162 | \n",
" 0.239277 | \n",
" 0.664585 | \n",
" 0.202308 | \n",
" 0.494231 | \n",
" 0.176923 | \n",
" 0.465319 | \n",
" 0.151713 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train/test precision, recall, accuracy and F1 of every model,\n",
"# ordered by test accuracy (best first).\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"        \"Accuracy_train\",\n",
"        \"Accuracy_test\",\n",
"        \"F1_train\",\n",
"        \"F1_test\",\n",
"    ],\n",
"]\n",
"# Two colour maps: one for accuracy/F1, another for precision/recall.\n",
"(\n",
"    class_metrics.sort_values(by=\"Accuracy_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Значения далеки от идеала, датасет так себе..."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting Jinja2\n",
" Downloading jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)\n",
"Collecting MarkupSafe>=2.0 (from Jinja2)\n",
" Downloading MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl.metadata (4.1 kB)\n",
"Downloading jinja2-3.1.4-py3-none-any.whl (133 kB)\n",
"Downloading MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl (15 kB)\n",
"Installing collected packages: MarkupSafe, Jinja2\n",
"Successfully installed Jinja2-3.1.4 MarkupSafe-3.0.2\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"[notice] A new release of pip is available: 24.2 -> 24.3.1\n",
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
]
}
],
"source": [
"# Jinja2 is needed for the DataFrame.style tables in this notebook.\n",
"# %pip installs into the environment of the running kernel; the version is\n",
"# pinned to the one recorded in this cell's output.\n",
"%pip install Jinja2==3.1.4"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" Accuracy_test | \n",
" F1_test | \n",
" ROC_AUC_test | \n",
" Cohen_kappa_test | \n",
" MCC_test | \n",
"
\n",
" \n",
" \n",
" \n",
" gradient_boosting | \n",
" 0.344231 | \n",
" 0.240251 | \n",
" 0.649816 | \n",
" 0.131708 | \n",
" 0.138628 | \n",
"
\n",
" \n",
" logistic | \n",
" 0.353846 | \n",
" 0.237759 | \n",
" 0.615478 | \n",
" 0.160238 | \n",
" 0.161282 | \n",
"
\n",
" \n",
" ridge | \n",
" 0.284615 | \n",
" 0.226955 | \n",
" 0.612260 | \n",
" 0.129672 | \n",
" 0.133551 | \n",
"
\n",
" \n",
" knn | \n",
" 0.328846 | \n",
" 0.211556 | \n",
" 0.602333 | \n",
" 0.128794 | \n",
" 0.130205 | \n",
"
\n",
" \n",
" random_forest | \n",
" 0.288462 | \n",
" 0.231541 | \n",
" 0.599541 | \n",
" 0.126828 | \n",
" 0.129917 | \n",
"
\n",
" \n",
" decision_tree | \n",
" 0.325000 | \n",
" 0.146479 | \n",
" 0.581718 | \n",
" 0.078698 | \n",
" 0.098279 | \n",
"
\n",
" \n",
" naive_bayes | \n",
" 0.176923 | \n",
" 0.151713 | \n",
" 0.562024 | \n",
" 0.071080 | \n",
" 0.079232 | \n",
"
\n",
" \n",
" mlp | \n",
" 0.250000 | \n",
" 0.057143 | \n",
" 0.554978 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set accuracy, F1, ROC AUC, Cohen's kappa and MCC of every model,\n",
"# ordered by ROC AUC (best first).\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"        \"ROC_AUC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"        \"MCC_test\",\n",
"    ],\n",
"]\n",
"(\n",
"    class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Accuracy_test\", \"F1_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# The model with the highest test-set MCC is taken as the best one.\n",
"models_by_mcc = class_metrics.sort_values(by=\"MCC_test\", ascending=False)\n",
"best_model = str(models_by_mcc.index[0])\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Вывод данных с ошибкой предсказания для оценки"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"'Error items count: 336'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Rank | \n",
" Predicted | \n",
" Name | \n",
" Networth | \n",
" Country | \n",
" Source | \n",
" Industry | \n",
" Age_category | \n",
"
\n",
" \n",
" \n",
" \n",
" 6 | \n",
" 7 | \n",
" 4 | \n",
" Sergey Brin | \n",
" 107.0 | \n",
" United States | \n",
" Google | \n",
" Technology | \n",
" 40-50 | \n",
"
\n",
" \n",
" 8 | \n",
" 9 | \n",
" 3 | \n",
" Steve Ballmer | \n",
" 91.4 | \n",
" United States | \n",
" Microsoft | \n",
" Technology | \n",
" 60-70 | \n",
"
\n",
" \n",
" 12 | \n",
" 13 | \n",
" 3 | \n",
" Carlos Slim Helu & family | \n",
" 81.2 | \n",
" Mexico | \n",
" telecom | \n",
" Telecom | \n",
" 80+ | \n",
"
\n",
" \n",
" 14 | \n",
" 15 | \n",
" 3 | \n",
" Mark Zuckerberg | \n",
" 67.3 | \n",
" United States | \n",
" Facebook | \n",
" Technology | \n",
" 30-40 | \n",
"
\n",
" \n",
" 22 | \n",
" 23 | \n",
" 5 | \n",
" Amancio Ortega | \n",
" 59.6 | \n",
" Spain | \n",
" Zara | \n",
" Fashion & Retail | \n",
" 80+ | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 2586 | \n",
" 2578 | \n",
" 3 | \n",
" Roy Chi Ping Chung | \n",
" 1.0 | \n",
" Hong Kong | \n",
" manufacturing | \n",
" Manufacturing | \n",
" 60-70 | \n",
"
\n",
" \n",
" 2588 | \n",
" 2578 | \n",
" 3 | \n",
" Ronald Clarke | \n",
" 1.0 | \n",
" United States | \n",
" payments technology | \n",
" Technology | \n",
" 60-70 | \n",
"
\n",
" \n",
" 2591 | \n",
" 2578 | \n",
" 5 | \n",
" Sefik Yilmaz Dizdar | \n",
" 1.0 | \n",
" Turkey | \n",
" fashion retail | \n",
" Fashion & Retail | \n",
" 80+ | \n",
"
\n",
" \n",
" 2593 | \n",
" 2578 | \n",
" 6 | \n",
" Larry Fink | \n",
" 1.0 | \n",
" United States | \n",
" money management | \n",
" Finance & Investments | \n",
" 60-70 | \n",
"
\n",
" \n",
" 2596 | \n",
" 2578 | \n",
" 5 | \n",
" Nari Genomal | \n",
" 1.0 | \n",
" Philippines | \n",
" apparel | \n",
" Fashion & Retail | \n",
" 80+ | \n",
"
\n",
" \n",
"
\n",
"
336 rows × 8 columns
\n",
"
"
],
"text/plain": [
" Rank Predicted Name Networth Country \\\n",
"6 7 4 Sergey Brin 107.0 United States \n",
"8 9 3 Steve Ballmer 91.4 United States \n",
"12 13 3 Carlos Slim Helu & family 81.2 Mexico \n",
"14 15 3 Mark Zuckerberg 67.3 United States \n",
"22 23 5 Amancio Ortega 59.6 Spain \n",
"... ... ... ... ... ... \n",
"2586 2578 3 Roy Chi Ping Chung 1.0 Hong Kong \n",
"2588 2578 3 Ronald Clarke 1.0 United States \n",
"2591 2578 5 Sefik Yilmaz Dizdar 1.0 Turkey \n",
"2593 2578 6 Larry Fink 1.0 United States \n",
"2596 2578 5 Nari Genomal 1.0 Philippines \n",
"\n",
" Source Industry Age_category \n",
"6 Google Technology 40-50 \n",
"8 Microsoft Technology 60-70 \n",
"12 telecom Telecom 80+ \n",
"14 Facebook Technology 30-40 \n",
"22 Zara Fashion & Retail 80+ \n",
"... ... ... ... \n",
"2586 manufacturing Manufacturing 60-70 \n",
"2588 payments technology Technology 60-70 \n",
"2591 fashion retail Fashion & Retail 80+ \n",
"2593 money management Finance & Investments 60-70 \n",
"2596 apparel Fashion & Retail 80+ \n",
"\n",
"[336 rows x 8 columns]"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preprocessed view of the test split. The frame reuses the index of X_test\n",
"# so that label-based lookups (preprocessed_df.loc[...]) return the same\n",
"# sample as X_test.loc[...]; without it the frame gets a positional\n",
"# 0..n-1 index and .loc silently picks a different row.\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result,\n",
"    columns=pipeline_end.get_feature_names_out(),\n",
"    index=X_test.index,\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"# Index labels of the test rows whose predicted class differs from the true one.\n",
"error_index = y_test[y_test[\"Age_category\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"# Show the misclassified rows with the predicted class next to the raw features.\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Многовато..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Пример использования обученной модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" Rank | \n",
" Name | \n",
" Networth | \n",
" Country | \n",
" Source | \n",
" Industry | \n",
" Age_category | \n",
"
\n",
" \n",
" \n",
" \n",
" 450 | \n",
" 438 | \n",
" Ruan Liping | \n",
" 5.8 | \n",
" Hong Kong | \n",
" power strips | \n",
" Manufacturing | \n",
" 50-60 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Rank Name Networth Country Source Industry \\\n",
"450 438 Ruan Liping 5.8 Hong Kong power strips Manufacturing \n",
"\n",
" Age_category \n",
"450 50-60 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" prepocessing_num__Networth | \n",
" prepocessing_cat__Country_Argentina | \n",
" prepocessing_cat__Country_Australia | \n",
" prepocessing_cat__Country_Austria | \n",
" prepocessing_cat__Country_Barbados | \n",
" prepocessing_cat__Country_Belgium | \n",
" prepocessing_cat__Country_Belize | \n",
" prepocessing_cat__Country_Brazil | \n",
" prepocessing_cat__Country_Bulgaria | \n",
" prepocessing_cat__Country_Canada | \n",
" ... | \n",
" prepocessing_cat__Industry_Logistics | \n",
" prepocessing_cat__Industry_Manufacturing | \n",
" prepocessing_cat__Industry_Media & Entertainment | \n",
" prepocessing_cat__Industry_Metals & Mining | \n",
" prepocessing_cat__Industry_Real Estate | \n",
" prepocessing_cat__Industry_Service | \n",
" prepocessing_cat__Industry_Sports | \n",
" prepocessing_cat__Industry_Technology | \n",
" prepocessing_cat__Industry_Telecom | \n",
" prepocessing_cat__Industry_diversified | \n",
"
\n",
" \n",
" \n",
" \n",
" 450 | \n",
" 0.289255 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
1 rows × 860 columns
\n",
"
"
],
"text/plain": [
" prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n",
"450 0.289255 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"450 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"450 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"450 1.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"450 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"450 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"450 0.0 \n",
"\n",
"[1 rows x 860 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/plain": [
"'predicted: 3 (proba: [0.00172036 0.04303104 0.02714323 0.36848158 0.19524859 0.2037863\\n 0.1605889 ])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 3'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Run the best pipeline on a single test sample and compare with the true class.\n",
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"# Double brackets keep the row as a one-row DataFrame with the original\n",
"# column dtypes (a transposed Series turns every column into object).\n",
"test = X_test.loc[[example_id]]\n",
"# preprocessed_df was built from pipeline_end.transform(X_test), so its rows\n",
"# follow the order of X_test; pick the matching row by position.\n",
"# Assumes the X_test index is unique -- TODO confirm.\n",
"example_pos = X_test.index.get_loc(example_id)\n",
"test_preprocessed = preprocessed_df.iloc[[example_pos]]\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Подбор гиперпараметров методом поиска по сетке"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 10,\n",
" 'model__max_features': 2,\n",
" 'model__n_estimators': 250}"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# Grid search over random-forest hyperparameters, using the already built\n",
"# preprocessing + random_forest pipeline as the estimator.\n",
"optimized_model_type = \"random_forest\"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"# The \"model__\" prefix addresses the pipeline step named \"model\".\n",
"param_grid = dict(\n",
"    model__n_estimators=[10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"    model__max_features=[\"sqrt\", \"log2\", 2],\n",
"    model__max_depth=list(range(2, 11)),\n",
"    model__criterion=[\"gini\", \"entropy\", \"log_loss\"],\n",
")\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
"    estimator=random_forest_model,\n",
"    param_grid=param_grid,\n",
"    n_jobs=-1,\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"# Refit the random forest with the hyperparameters reported by the grid\n",
"# search and collect the same metrics as for the baseline models.\n",
"optimized_model = ensemble.RandomForestClassifier(\n",
"    random_state=9,\n",
"    criterion=\"gini\",\n",
"    max_depth=10,\n",
"    max_features=2,\n",
"    n_estimators=250,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)\n",
"# Predictions must come from this model's own probabilities. The leftover\n",
"# y_test_probs variable holds the output of the last model fitted in the\n",
"# comparison loop, which made every test metric below describe that model\n",
"# instead of the tuned forest.\n",
"result[\"preds\"] = np.argmax(result[\"probs\"], axis=1)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"],average=\"macro\")\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"], average=\"macro\")\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"], average=\"macro\")\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"], average=\"macro\")\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"], multi_class=\"ovr\")\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"], average=\"macro\")\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"], average=\"macro\")\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Формирование данных для оценки старой и новой версии модели"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"# Two-row comparison frame: the baseline random forest (\"Old\") and the\n",
"# tuned one (\"New\"), with the keys of `result` as columns.\n",
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"for model_record in (class_models[optimized_model_type], result):\n",
"    # Append each record as the next row of the frame.\n",
"    optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=model_record)\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Оценка параметров старой и новой модели"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" Precision_train | \n",
" Precision_test | \n",
" Recall_train | \n",
" Recall_test | \n",
" Accuracy_train | \n",
" Accuracy_test | \n",
" F1_train | \n",
" F1_test | \n",
"
\n",
" \n",
" Name | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Old | \n",
" 0.581578 | \n",
" 0.236539 | \n",
" 0.735419 | \n",
" 0.246556 | \n",
" 0.627404 | \n",
" 0.288462 | \n",
" 0.599765 | \n",
" 0.231541 | \n",
"
\n",
" \n",
" New | \n",
" 0.181388 | \n",
" 0.035714 | \n",
" 0.157692 | \n",
" 0.142857 | \n",
" 0.306250 | \n",
" 0.250000 | \n",
" 0.090702 | \n",
" 0.057143 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train/test precision, recall, accuracy and F1: baseline vs tuned model.\n",
"(\n",
"    optimized_metrics[\n",
"        [\n",
"            \"Precision_train\",\n",
"            \"Precision_test\",\n",
"            \"Recall_train\",\n",
"            \"Recall_test\",\n",
"            \"Accuracy_train\",\n",
"            \"Accuracy_test\",\n",
"            \"F1_train\",\n",
"            \"F1_test\",\n",
"        ]\n",
"    ]\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Precision_train\", \"Precision_test\", \"Recall_train\", \"Recall_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
" \n",
" \n",
" | \n",
" Accuracy_test | \n",
" F1_test | \n",
" ROC_AUC_test | \n",
" Cohen_kappa_test | \n",
" MCC_test | \n",
"
\n",
" \n",
" Name | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" Old | \n",
" 0.288462 | \n",
" 0.231541 | \n",
" 0.599541 | \n",
" 0.126828 | \n",
" 0.129917 | \n",
"
\n",
" \n",
" New | \n",
" 0.250000 | \n",
" 0.057143 | \n",
" 0.605446 | \n",
" 0.000000 | \n",
" 0.000000 | \n",
"
\n",
" \n",
"
\n"
],
"text/plain": [
""
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test-set accuracy, F1, ROC AUC, Cohen's kappa and MCC: baseline vs tuned model.\n",
"(\n",
"    optimized_metrics[\n",
"        [\n",
"            \"Accuracy_test\",\n",
"            \"F1_test\",\n",
"            \"ROC_AUC_test\",\n",
"            \"Cohen_kappa_test\",\n",
"            \"MCC_test\",\n",
"        ]\n",
"    ]\n",
"    .style.background_gradient(\n",
"        cmap=\"plasma\",\n",
"        low=0.3,\n",
"        high=1,\n",
"        subset=[\"ROC_AUC_test\", \"MCC_test\", \"Cohen_kappa_test\"],\n",
"    )\n",
"    .background_gradient(\n",
"        cmap=\"viridis\",\n",
"        low=1,\n",
"        high=0.3,\n",
"        subset=[\"Accuracy_test\", \"F1_test\"],\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Confusion matrices of the rows of optimized_metrics, side by side.\n",
"class_labels = [\"Under 30\", \"30-40\", \"40-50\", \"50-60\", \"60-70\", \"70-80\", \"80+\"]\n",
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"for index in range(len(optimized_metrics)):\n",
"    c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=class_labels\n",
"    ).plot(ax=ax.flat[index])\n",
"    # Title each panel with the row label so the two models can be told apart.\n",
"    disp.ax_.set_title(str(optimized_metrics.index[index]))\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Задача регрессии"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"# Regression task: Networth is the target; Rank (the column name has a\n",
"# trailing space) and Name are dropped from the features together with it.\n",
"y = df['Networth']\n",
"X = df.drop(columns=['Networth', 'Rank ', 'Name'])\n",
"\n",
"# Hold out 20% of the rows for testing, with a fixed seed.\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.2, random_state=42\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" prepocessing_num__Age | \n",
" prepocessing_cat__Country_Argentina | \n",
" prepocessing_cat__Country_Australia | \n",
" prepocessing_cat__Country_Austria | \n",
" prepocessing_cat__Country_Barbados | \n",
" prepocessing_cat__Country_Belgium | \n",
" prepocessing_cat__Country_Belize | \n",
" prepocessing_cat__Country_Brazil | \n",
" prepocessing_cat__Country_Bulgaria | \n",
" prepocessing_cat__Country_Canada | \n",
" ... | \n",
" prepocessing_cat__Industry_Logistics | \n",
" prepocessing_cat__Industry_Manufacturing | \n",
" prepocessing_cat__Industry_Media & Entertainment | \n",
" prepocessing_cat__Industry_Metals & Mining | \n",
" prepocessing_cat__Industry_Real Estate | \n",
" prepocessing_cat__Industry_Service | \n",
" prepocessing_cat__Industry_Sports | \n",
" prepocessing_cat__Industry_Technology | \n",
" prepocessing_cat__Industry_Telecom | \n",
" prepocessing_cat__Industry_diversified | \n",
"
\n",
" \n",
" \n",
" \n",
" 582 | \n",
" -0.109934 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 48 | \n",
" 1.079079 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1772 | \n",
" 1.004766 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 964 | \n",
" -0.407187 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 2213 | \n",
" 1.302019 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 1638 | \n",
" 1.227706 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1095 | \n",
" 0.856139 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1130 | \n",
" 0.781826 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 1294 | \n",
" 0.335946 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" 860 | \n",
" 0.558886 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" ... | \n",
" 0.0 | \n",
" 1.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
"
\n",
"
2080 rows × 855 columns
\n",
"
"
],
"text/plain": [
" prepocessing_num__Age prepocessing_cat__Country_Argentina \\\n",
"582 -0.109934 0.0 \n",
"48 1.079079 0.0 \n",
"1772 1.004766 0.0 \n",
"964 -0.407187 0.0 \n",
"2213 1.302019 0.0 \n",
"... ... ... \n",
"1638 1.227706 0.0 \n",
"1095 0.856139 0.0 \n",
"1130 0.781826 0.0 \n",
"1294 0.335946 0.0 \n",
"860 0.558886 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 1.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 1.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 1.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"582 ... 0.0 \n",
"48 ... 0.0 \n",
"1772 ... 0.0 \n",
"964 ... 0.0 \n",
"2213 ... 0.0 \n",
"... ... ... \n",
"1638 ... 0.0 \n",
"1095 ... 0.0 \n",
"1130 ... 0.0 \n",
"1294 ... 0.0 \n",
"860 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"582 0.0 \n",
"48 1.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 1.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 1.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"582 1.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 1.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
"[2080 rows x 855 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import pandas as pd\n",
"\n",
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
"columns_to_drop = []\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Предобработка числовых данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Предобработка категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Общая предобработка признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training LogisticRegression...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 3 is smaller than n_iter=10. Running 3 iterations. For exhaustive searches, use GridSearchCV.\n",
" warnings.warn(\n"
]
},
{
"ename": "ValueError",
"evalue": "\nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[7], line 44\u001b[0m\n\u001b[0;32m 42\u001b[0m param_grid \u001b[38;5;241m=\u001b[39m param_grids_classification[name]\n\u001b[0;32m 43\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m RandomizedSearchCV(pipeline, param_grid, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf1\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 44\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;66;03m# Лучшая модель\u001b[39;00m\n\u001b[0;32m 47\u001b[0m best_model \u001b[38;5;241m=\u001b[39m grid_search\u001b[38;5;241m.\u001b[39mbest_estimator_\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[1;34m(self, X, y, **params)\u001b[0m\n\u001b[0;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[0;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[0;32m 1015\u001b[0m )\n\u001b[0;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[1;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[0;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[0;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1960\u001b[0m, in \u001b[0;36mRandomizedSearchCV._run_search\u001b[1;34m(self, evaluate_candidates)\u001b[0m\n\u001b[0;32m 1958\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[0;32m 1959\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search n_iter candidates from param_distributions\"\"\"\u001b[39;00m\n\u001b[1;32m-> 1960\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1961\u001b[0m \u001b[43m \u001b[49m\u001b[43mParameterSampler\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1962\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_distributions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom_state\u001b[49m\n\u001b[0;32m 1963\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1964\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:996\u001b[0m, in \u001b[0;36mBaseSearchCV.fit..evaluate_candidates\u001b[1;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[0;32m 989\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m!=\u001b[39m n_candidates \u001b[38;5;241m*\u001b[39m n_splits:\n\u001b[0;32m 990\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 991\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcv.split and cv.get_n_splits returned \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 992\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minconsistent results. Expected \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 993\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msplits, got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(n_splits, \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m n_candidates)\n\u001b[0;32m 994\u001b[0m )\n\u001b[1;32m--> 996\u001b[0m \u001b[43m_warn_or_raise_about_fit_failures\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;66;03m# For callable self.scoring, the return type is only know after\u001b[39;00m\n\u001b[0;32m 999\u001b[0m \u001b[38;5;66;03m# calling. If the return type is a dictionary, the error scores\u001b[39;00m\n\u001b[0;32m 1000\u001b[0m \u001b[38;5;66;03m# can now be inserted with the correct key. 
The type checking\u001b[39;00m\n\u001b[0;32m 1001\u001b[0m \u001b[38;5;66;03m# of out will be done in `_insert_error_scores`.\u001b[39;00m\n\u001b[0;32m 1002\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcallable\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscoring):\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:529\u001b[0m, in \u001b[0;36m_warn_or_raise_about_fit_failures\u001b[1;34m(results, error_score)\u001b[0m\n\u001b[0;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_failed_fits \u001b[38;5;241m==\u001b[39m num_fits:\n\u001b[0;32m 523\u001b[0m all_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 524\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mAll the \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 525\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIt is very likely that your model is misconfigured.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 526\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou can try to debug the error by setting error_score=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 528\u001b[0m )\n\u001b[1;32m--> 529\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(all_fits_failed_message)\n\u001b[0;32m 531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 532\u001b[0m some_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 533\u001b[0m 
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mnum_failed_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed out of a total of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 534\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe score on these train-test partitions for these parameters\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 539\u001b[0m )\n",
"\u001b[1;31mValueError\u001b[0m: \nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
"\n",
"\n",
"# Модели и параметры\n",
"models_classification = {\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"KNN\": KNeighborsClassifier()\n",
"}\n",
"\n",
"param_grids_classification = {\n",
" \"LogisticRegression\": {\n",
" 'model__C': [0.1, 1, 10]\n",
" },\n",
" \"RandomForestClassifier\": {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10, 20],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
" },\n",
" \"KNN\": {\n",
" 'model__n_neighbors': [3, 5, 7, 9, 11],\n",
" 'model__weights': ['uniform', 'distance']\n",
" }\n",
"}\n",
"\n",
"# Результаты\n",
"results_classification = {}\n",
"\n",
"# Перебор моделей\n",
"for name, model in models_classification.items():\n",
" print(f\"Training {name}...\")\n",
" pipeline = Pipeline(steps=[\n",
" ('features_preprocessing', features_preprocessing),\n",
" ('model', model)\n",
" ])\n",
" \n",
" param_grid = param_grids_classification[name]\n",
" grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='f1', n_jobs=-1)\n",
" grid_search.fit(X_train, y_train)\n",
"\n",
" # Лучшая модель\n",
" best_model = grid_search.best_estimator_\n",
" y_pred = best_model.predict(X_test)\n",
"\n",
" # Метрики\n",
" acc = accuracy_score(y_test, y_pred)\n",
" f1 = f1_score(y_test, y_pred)\n",
"\n",
" # Вычисление матрицы ошибок\n",
" c_matrix = confusion_matrix(y_test, y_pred)\n",
"\n",
" # Сохранение результатов\n",
" results_classification[name] = {\n",
" \"Best Params\": grid_search.best_params_,\n",
" \"Accuracy\": acc,\n",
" \"F1 Score\": f1,\n",
" \"Confusion_matrix\": c_matrix\n",
" }\n",
"\n",
"# Печать результатов\n",
"for name, metrics in results_classification.items():\n",
" print(f\"\\nModel: {name}\")\n",
" for metric, value in metrics.items():\n",
" print(f\"{metric}: {value}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}