{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Index(['Rank ', 'Name', 'Networth', 'Age', 'Country', 'Source', 'Industry'], dtype='object')\n" ] } ], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n", "print(df.columns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Определим бизнес цели:\n", "## 1- Прогнозирование состояния миллиардера(регрессия)\n", "## 2- Прогнозирование возраста миллиардера(классификация)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Подготовим данные: категоризируем колонку age" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Rank 0\n", "Name 0\n", "Networth 0\n", "Age 0\n", "Country 0\n", "Source 0\n", "Industry 0\n", "dtype: int64\n", "\n", "Rank False\n", "Name False\n", "Networth False\n", "Age False\n", "Country False\n", "Source False\n", "Industry False\n", "dtype: bool\n", "\n" ] } ], "source": [ "print(df.isnull().sum())\n", "\n", "print()\n", "\n", "# Есть ли пустые значения признаков\n", "print(df.isnull().any())\n", "\n", "print()\n", "\n", "# Процент пустых значений признаков\n", "for i in df.columns:\n", " null_rate = df[i].isnull().sum() / len(df) * 100\n", " if null_rate > 0:\n", " print(f\"{i} процент пустых значений: %{null_rate:.2f}\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Rank Name Networth Country \\\n", "0 1 Elon Musk 219.0 United States \n", "1 2 Jeff Bezos 171.0 United States \n", "2 3 Bernard Arnault & family 158.0 France \n", "3 4 Bill Gates 129.0 United States \n", "4 5 Warren Buffett 118.0 United States \n", "\n", " Source Industry Age_category \n", "0 Tesla, SpaceX Automotive 50-60 \n", "1 Amazon 
Technology 50-60 \n", "2 LVMH Fashion & Retail 70-80 \n", "3 Microsoft Technology 60-70 \n", "4 Berkshire Hathaway Finance & Investments 80+ \n" ] } ], "source": [ "\n", "\n", "bins = [0, 30, 40, 50, 60, 70, 80, 101] # границы для возрастных категорий\n", "labels = ['Under 30', '30-40', '40-50', '50-60', '60-70', '70-80', '80+'] # метки для категорий\n", "\n", "df[\"Age_category\"] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)\n", "# Удаляем оригинальные колонки 'country', 'industry' и 'source' из исходного DataFrame\n", "df.drop(columns=['Age'], inplace=True)\n", "\n", "# Просмотр результата\n", "print(df.head())" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'X_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RankNameNetworthCountrySourceIndustryAge_category
19091818Tran Ba Duong & family1.6VietnamautomotiveAutomotive60-70
20992076Mark Dixon1.4United Kingdomoffice real estateReal Estate60-70
13921341Yingzhuo Xu2.3ChinaagribusinessFood & Beverage50-60
627622Bruce Flatt4.6Canadamoney managementFinance & Investments50-60
527523Li Liangbin5.2ChinalithiumManufacturing50-60
........................
8485Theo Albrecht, Jr. & family18.7GermanyAldi, Trader Joe'sFashion & Retail70-80
633622Tony Tamer4.6United Statesprivate equityFinance & Investments60-70
922913Bob Gaglardi3.3CanadahotelsReal Estate80+
21782076Eugene Wu1.4TaiwanfinanceFinance & Investments70-80
415411Leonard Stern6.2United Statesreal estateReal Estate80+
\n", "

2080 rows × 7 columns

\n", "
" ], "text/plain": [ " Rank Name Networth Country \\\n", "1909 1818 Tran Ba Duong & family 1.6 Vietnam \n", "2099 2076 Mark Dixon 1.4 United Kingdom \n", "1392 1341 Yingzhuo Xu 2.3 China \n", "627 622 Bruce Flatt 4.6 Canada \n", "527 523 Li Liangbin 5.2 China \n", "... ... ... ... ... \n", "84 85 Theo Albrecht, Jr. & family 18.7 Germany \n", "633 622 Tony Tamer 4.6 United States \n", "922 913 Bob Gaglardi 3.3 Canada \n", "2178 2076 Eugene Wu 1.4 Taiwan \n", "415 411 Leonard Stern 6.2 United States \n", "\n", " Source Industry Age_category \n", "1909 automotive Automotive 60-70 \n", "2099 office real estate Real Estate 60-70 \n", "1392 agribusiness Food & Beverage 50-60 \n", "627 money management Finance & Investments 50-60 \n", "527 lithium Manufacturing 50-60 \n", "... ... ... ... \n", "84 Aldi, Trader Joe's Fashion & Retail 70-80 \n", "633 private equity Finance & Investments 60-70 \n", "922 hotels Real Estate 80+ \n", "2178 finance Finance & Investments 70-80 \n", "415 real estate Real Estate 80+ \n", "\n", "[2080 rows x 7 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Age_category
190960-70
209960-70
139250-60
62750-60
52750-60
......
8470-80
63360-70
92280+
217870-80
41580+
\n", "

2080 rows × 1 columns

\n", "
" ], "text/plain": [ " Age_category\n", "1909 60-70\n", "2099 60-70\n", "1392 50-60\n", "627 50-60\n", "527 50-60\n", "... ...\n", "84 70-80\n", "633 60-70\n", "922 80+\n", "2178 70-80\n", "415 80+\n", "\n", "[2080 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'X_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
RankNameNetworthCountrySourceIndustryAge_category
20752076Radhe Shyam Agarwal1.4Indiaconsumer goodsFashion & Retail70-80
15291513Robert Duggan2.0United StatespharmaceuticalsHealthcare70-80
18031729Yao Kuizhang1.7ChinabeveragesFood & Beverage50-60
425424Alexei Kuzmichev6.0Russiaoil, banking, telecomEnergy50-60
25972578Ramesh Genomal1.0PhilippinesapparelFashion & Retail70-80
........................
935913Alfred Oetker3.3Germanyconsumer goodsFashion & Retail50-60
15411513Thomas Lee2.0United Statesprivate equityFinance & Investments70-80
16461645Roberto Angelini Rossi1.8Chileforestry, miningdiversified70-80
376375Patrick Drahi6.6FrancetelecomTelecom50-60
18941818Gerald Schwartz1.6CanadafinanceFinance & Investments80+
\n", "

520 rows × 7 columns

\n", "
" ], "text/plain": [ " Rank Name Networth Country \\\n", "2075 2076 Radhe Shyam Agarwal 1.4 India \n", "1529 1513 Robert Duggan 2.0 United States \n", "1803 1729 Yao Kuizhang 1.7 China \n", "425 424 Alexei Kuzmichev 6.0 Russia \n", "2597 2578 Ramesh Genomal 1.0 Philippines \n", "... ... ... ... ... \n", "935 913 Alfred Oetker 3.3 Germany \n", "1541 1513 Thomas Lee 2.0 United States \n", "1646 1645 Roberto Angelini Rossi 1.8 Chile \n", "376 375 Patrick Drahi 6.6 France \n", "1894 1818 Gerald Schwartz 1.6 Canada \n", "\n", " Source Industry Age_category \n", "2075 consumer goods Fashion & Retail 70-80 \n", "1529 pharmaceuticals Healthcare 70-80 \n", "1803 beverages Food & Beverage 50-60 \n", "425 oil, banking, telecom Energy 50-60 \n", "2597 apparel Fashion & Retail 70-80 \n", "... ... ... ... \n", "935 consumer goods Fashion & Retail 50-60 \n", "1541 private equity Finance & Investments 70-80 \n", "1646 forestry, mining diversified 70-80 \n", "376 telecom Telecom 50-60 \n", "1894 finance Finance & Investments 80+ \n", "\n", "[520 rows x 7 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Age_category
207570-80
152970-80
180350-60
42550-60
259770-80
......
93550-60
154170-80
164670-80
37650-60
189480+
\n", "

520 rows × 1 columns

\n", "
" ], "text/plain": [ " Age_category\n", "2075 70-80\n", "1529 70-80\n", "1803 50-60\n", "425 50-60\n", "2597 70-80\n", "... ...\n", "935 50-60\n", "1541 70-80\n", "1646 70-80\n", "376 50-60\n", "1894 80+\n", "\n", "[520 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from utils import split_stratified_into_train_val_test\n", "\n", "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n", " df, stratify_colname=\"Age_category\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n", ")\n", "\n", "display(\"X_train\", X_train)\n", "display(\"y_train\", y_train)\n", "\n", "display(\"X_test\", X_test)\n", "display(\"y_test\", y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Формирование конвейера для классификации данных\n", "## preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n", "## preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n", "## features_preprocessing -- трансформер для предобработки признаков\n", "## features_engineering -- трансформер для конструирования признаков\n", "## drop_columns -- трансформер для удаления колонок\n", "## pipeline_end -- основной конвейер предобработки данных и конструирования признаков" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
prepocessing_num__Networthprepocessing_cat__Country_Argentinaprepocessing_cat__Country_Australiaprepocessing_cat__Country_Austriaprepocessing_cat__Country_Barbadosprepocessing_cat__Country_Belgiumprepocessing_cat__Country_Belizeprepocessing_cat__Country_Brazilprepocessing_cat__Country_Bulgariaprepocessing_cat__Country_Canada...prepocessing_cat__Industry_Logisticsprepocessing_cat__Industry_Manufacturingprepocessing_cat__Industry_Media & Entertainmentprepocessing_cat__Industry_Metals & Miningprepocessing_cat__Industry_Real Estateprepocessing_cat__Industry_Serviceprepocessing_cat__Industry_Sportsprepocessing_cat__Industry_Technologyprepocessing_cat__Industry_Telecomprepocessing_cat__Industry_diversified
1909-0.3099170.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
2099-0.3292450.00.00.00.00.00.00.00.00.0...0.00.00.00.01.00.00.00.00.00.0
1392-0.2422680.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
627-0.0199950.00.00.00.00.00.00.00.01.0...0.00.00.00.00.00.00.00.00.00.0
5270.0379900.00.00.00.00.00.00.00.00.0...0.01.00.00.00.00.00.00.00.00.0
..................................................................
841.3426370.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
633-0.0199950.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
922-0.1456280.00.00.00.00.00.00.00.01.0...0.00.00.00.01.00.00.00.00.00.0
2178-0.3292450.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
4150.1346300.00.00.00.00.00.00.00.00.0...0.00.00.00.01.00.00.00.00.00.0
\n", "

2080 rows × 860 columns

\n", "
" ], "text/plain": [ " prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n", "1909 -0.309917 0.0 \n", "2099 -0.329245 0.0 \n", "1392 -0.242268 0.0 \n", "627 -0.019995 0.0 \n", "527 0.037990 0.0 \n", "... ... ... \n", "84 1.342637 0.0 \n", "633 -0.019995 0.0 \n", "922 -0.145628 0.0 \n", "2178 -0.329245 0.0 \n", "415 0.134630 0.0 \n", "\n", " prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n", "1909 0.0 0.0 \n", "2099 0.0 0.0 \n", "1392 0.0 0.0 \n", "627 0.0 0.0 \n", "527 0.0 0.0 \n", "... ... ... \n", "84 0.0 0.0 \n", "633 0.0 0.0 \n", "922 0.0 0.0 \n", "2178 0.0 0.0 \n", "415 0.0 0.0 \n", "\n", " prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n", "1909 0.0 0.0 \n", "2099 0.0 0.0 \n", "1392 0.0 0.0 \n", "627 0.0 0.0 \n", "527 0.0 0.0 \n", "... ... ... \n", "84 0.0 0.0 \n", "633 0.0 0.0 \n", "922 0.0 0.0 \n", "2178 0.0 0.0 \n", "415 0.0 0.0 \n", "\n", " prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n", "1909 0.0 0.0 \n", "2099 0.0 0.0 \n", "1392 0.0 0.0 \n", "627 0.0 0.0 \n", "527 0.0 0.0 \n", "... ... ... \n", "84 0.0 0.0 \n", "633 0.0 0.0 \n", "922 0.0 0.0 \n", "2178 0.0 0.0 \n", "415 0.0 0.0 \n", "\n", " prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n", "1909 0.0 0.0 \n", "2099 0.0 0.0 \n", "1392 0.0 0.0 \n", "627 0.0 1.0 \n", "527 0.0 0.0 \n", "... ... ... \n", "84 0.0 0.0 \n", "633 0.0 0.0 \n", "922 0.0 1.0 \n", "2178 0.0 0.0 \n", "415 0.0 0.0 \n", "\n", " ... prepocessing_cat__Industry_Logistics \\\n", "1909 ... 0.0 \n", "2099 ... 0.0 \n", "1392 ... 0.0 \n", "627 ... 0.0 \n", "527 ... 0.0 \n", "... ... ... \n", "84 ... 0.0 \n", "633 ... 0.0 \n", "922 ... 0.0 \n", "2178 ... 0.0 \n", "415 ... 0.0 \n", "\n", " prepocessing_cat__Industry_Manufacturing \\\n", "1909 0.0 \n", "2099 0.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 1.0 \n", "... ... 
\n", "84 0.0 \n", "633 0.0 \n", "922 0.0 \n", "2178 0.0 \n", "415 0.0 \n", "\n", " prepocessing_cat__Industry_Media & Entertainment \\\n", "1909 0.0 \n", "2099 0.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 0.0 \n", "... ... \n", "84 0.0 \n", "633 0.0 \n", "922 0.0 \n", "2178 0.0 \n", "415 0.0 \n", "\n", " prepocessing_cat__Industry_Metals & Mining \\\n", "1909 0.0 \n", "2099 0.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 0.0 \n", "... ... \n", "84 0.0 \n", "633 0.0 \n", "922 0.0 \n", "2178 0.0 \n", "415 0.0 \n", "\n", " prepocessing_cat__Industry_Real Estate \\\n", "1909 0.0 \n", "2099 1.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 0.0 \n", "... ... \n", "84 0.0 \n", "633 0.0 \n", "922 1.0 \n", "2178 0.0 \n", "415 1.0 \n", "\n", " prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n", "1909 0.0 0.0 \n", "2099 0.0 0.0 \n", "1392 0.0 0.0 \n", "627 0.0 0.0 \n", "527 0.0 0.0 \n", "... ... ... \n", "84 0.0 0.0 \n", "633 0.0 0.0 \n", "922 0.0 0.0 \n", "2178 0.0 0.0 \n", "415 0.0 0.0 \n", "\n", " prepocessing_cat__Industry_Technology \\\n", "1909 0.0 \n", "2099 0.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 0.0 \n", "... ... \n", "84 0.0 \n", "633 0.0 \n", "922 0.0 \n", "2178 0.0 \n", "415 0.0 \n", "\n", " prepocessing_cat__Industry_Telecom \\\n", "1909 0.0 \n", "2099 0.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 0.0 \n", "... ... \n", "84 0.0 \n", "633 0.0 \n", "922 0.0 \n", "2178 0.0 \n", "415 0.0 \n", "\n", " prepocessing_cat__Industry_diversified \n", "1909 0.0 \n", "2099 0.0 \n", "1392 0.0 \n", "627 0.0 \n", "527 0.0 \n", "... ... 
\n", "84 0.0 \n", "633 0.0 \n", "922 0.0 \n", "2178 0.0 \n", "415 0.0 \n", "\n", "[2080 rows x 860 columns]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "from sklearn.compose import ColumnTransformer\n", "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.pipeline import Pipeline\n", "import pandas as pd\n", "\n", "# Исправляем ColumnTransformer с сохранением имен колонок\n", "columns_to_drop = [\"Age_category\", \"Rank \", \"Name\"]\n", "\n", "num_columns = [\n", " column\n", " for column in X_train.columns\n", " if column not in columns_to_drop and X_train[column].dtype != \"object\"\n", "]\n", "cat_columns = [\n", " column\n", " for column in X_train.columns\n", " if column not in columns_to_drop and X_train[column].dtype == \"object\"\n", "]\n", "\n", "# Предобработка числовых данных\n", "num_imputer = SimpleImputer(strategy=\"median\")\n", "num_scaler = StandardScaler()\n", "preprocessing_num = Pipeline(\n", " [\n", " (\"imputer\", num_imputer),\n", " (\"scaler\", num_scaler),\n", " ]\n", ")\n", "\n", "# Предобработка категориальных данных\n", "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n", "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n", "preprocessing_cat = Pipeline(\n", " [\n", " (\"imputer\", cat_imputer),\n", " (\"encoder\", cat_encoder),\n", " ]\n", ")\n", "\n", "# Общая предобработка признаков\n", "features_preprocessing = ColumnTransformer(\n", " verbose_feature_names_out=True, # Сохраняем имена колонок\n", " transformers=[\n", " (\"prepocessing_num\", preprocessing_num, num_columns),\n", " (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n", " ],\n", " remainder=\"drop\" # Убираем неиспользуемые столбцы\n", ")\n", "\n", "# Итоговый конвейер\n", "pipeline_end = Pipeline(\n", " [\n", " (\"features_preprocessing\", 
features_preprocessing),\n", " ]\n", ")\n", "\n", "# Преобразуем данные\n", "preprocessing_result = pipeline_end.fit_transform(X_train)\n", "\n", "# Создаем DataFrame с правильными именами колонок\n", "preprocessed_df = pd.DataFrame(\n", " preprocessing_result,\n", " columns=pipeline_end.get_feature_names_out(),\n", " index=X_train.index, # Сохраняем индексы\n", ")\n", "\n", "preprocessed_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Формирование набора моделей для классификации\n", "## logistic -- логистическая регрессия\n", "## ridge -- гребневая регрессия\n", "## decision_tree -- дерево решений\n", "## knn -- k-ближайших соседей\n", "## naive_bayes -- наивный Байесовский классификатор\n", "## gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n", "## random_forest -- метод случайного леса (набор деревьев решений)\n", "## mlp -- многослойный персептрон (нейронная сеть)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n", "\n", "class_models = {\n", " \"logistic\": {\"model\": linear_model.LogisticRegression()},\n", " # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n", " \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n", " \"decision_tree\": {\n", " \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n", " },\n", " \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n", " \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n", " \"gradient_boosting\": {\n", " \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n", " },\n", " \"random_forest\": {\n", " \"model\": ensemble.RandomForestClassifier(\n", " max_depth=11, class_weight=\"balanced\", random_state=9\n", " )\n", " },\n", " \"mlp\": {\n", " \"model\": neural_network.MLPClassifier(\n", " 
hidden_layer_sizes=(7,),\n", " max_iter=500,\n", " early_stopping=True,\n", " random_state=9,\n", " )\n", " },\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Обучение моделей на обучающем наборе данных и оценка на тестовом" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: logistic\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "ename": "ValueError", "evalue": "Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted'].", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[40], line 19\u001b[0m\n\u001b[0;32m 16\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mprobs\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m y_test_probs\n\u001b[0;32m 17\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpreds\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m y_test_predict\n\u001b[1;32m---> 19\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision_train\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mmetrics\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_score\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train_predict\u001b[49m\n\u001b[0;32m 21\u001b[0m 
\u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 22\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPrecision_test\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m metrics\u001b[38;5;241m.\u001b[39mprecision_score(\n\u001b[0;32m 23\u001b[0m y_test, y_test_predict\n\u001b[0;32m 24\u001b[0m )\n\u001b[0;32m 25\u001b[0m class_models[model_name][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRecall_train\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m metrics\u001b[38;5;241m.\u001b[39mrecall_score(\n\u001b[0;32m 26\u001b[0m y_train, y_train_predict\n\u001b[0;32m 27\u001b[0m )\n", "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py:213\u001b[0m, in \u001b[0;36mvalidate_params..decorator..wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 207\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 208\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 209\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 210\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 211\u001b[0m )\n\u001b[0;32m 212\u001b[0m ):\n\u001b[1;32m--> 213\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 214\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m InvalidParameterError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m 215\u001b[0m \u001b[38;5;66;03m# When the function is just a wrapper around an estimator, we allow\u001b[39;00m\n\u001b[0;32m 216\u001b[0m \u001b[38;5;66;03m# the function to delegate validation to the estimator, but we replace\u001b[39;00m\n\u001b[0;32m 217\u001b[0m 
\u001b[38;5;66;03m# the name of the estimator by the name of the function in the error\u001b[39;00m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;66;03m# message to avoid confusion.\u001b[39;00m\n\u001b[0;32m 219\u001b[0m msg \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msub(\n\u001b[0;32m 220\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124mw+ must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 221\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameter of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must be\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 222\u001b[0m \u001b[38;5;28mstr\u001b[39m(e),\n\u001b[0;32m 223\u001b[0m )\n", "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:2204\u001b[0m, in \u001b[0;36mprecision_score\u001b[1;34m(y_true, y_pred, labels, pos_label, average, sample_weight, zero_division)\u001b[0m\n\u001b[0;32m 2037\u001b[0m \u001b[38;5;129m@validate_params\u001b[39m(\n\u001b[0;32m 2038\u001b[0m {\n\u001b[0;32m 2039\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_true\u001b[39m\u001b[38;5;124m\"\u001b[39m: [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124marray-like\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msparse matrix\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2064\u001b[0m zero_division\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwarn\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m 2065\u001b[0m ):\n\u001b[0;32m 2066\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute the precision.\u001b[39;00m\n\u001b[0;32m 2067\u001b[0m \n\u001b[0;32m 2068\u001b[0m \u001b[38;5;124;03m The precision is the 
ratio ``tp / (tp + fp)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 2202\u001b[0m \u001b[38;5;124;03m array([0.5, 1. , 1. ])\u001b[39;00m\n\u001b[0;32m 2203\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m-> 2204\u001b[0m p, _, _, _ \u001b[38;5;241m=\u001b[39m \u001b[43mprecision_recall_fscore_support\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 2205\u001b[0m \u001b[43m \u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2206\u001b[0m \u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2207\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2208\u001b[0m \u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpos_label\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2209\u001b[0m \u001b[43m \u001b[49m\u001b[43maverage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2210\u001b[0m \u001b[43m \u001b[49m\u001b[43mwarn_for\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mprecision\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2211\u001b[0m \u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2212\u001b[0m \u001b[43m \u001b[49m\u001b[43mzero_division\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mzero_division\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 2213\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 2214\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m p\n", "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py:186\u001b[0m, in 
\u001b[0;36mvalidate_params..decorator..wrapper\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 184\u001b[0m global_skip_validation \u001b[38;5;241m=\u001b[39m get_config()[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mskip_parameter_validation\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m 185\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m global_skip_validation:\n\u001b[1;32m--> 186\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 188\u001b[0m func_sig \u001b[38;5;241m=\u001b[39m signature(func)\n\u001b[0;32m 190\u001b[0m \u001b[38;5;66;03m# Map *args/**kwargs to the function signature\u001b[39;00m\n", "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1789\u001b[0m, in \u001b[0;36mprecision_recall_fscore_support\u001b[1;34m(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight, zero_division)\u001b[0m\n\u001b[0;32m 1626\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Compute precision, recall, F-measure and support for each class.\u001b[39;00m\n\u001b[0;32m 1627\u001b[0m \n\u001b[0;32m 1628\u001b[0m \u001b[38;5;124;03mThe precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of\u001b[39;00m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1786\u001b[0m \u001b[38;5;124;03m array([2, 2, 2]))\u001b[39;00m\n\u001b[0;32m 1787\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 1788\u001b[0m _check_zero_division(zero_division)\n\u001b[1;32m-> 1789\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[43m_check_set_wise_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43my_true\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_pred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43maverage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos_label\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1791\u001b[0m \u001b[38;5;66;03m# Calculate tp_sum, pred_sum, true_sum ###\u001b[39;00m\n\u001b[0;32m 1792\u001b[0m samplewise \u001b[38;5;241m=\u001b[39m average \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1578\u001b[0m, in \u001b[0;36m_check_set_wise_labels\u001b[1;34m(y_true, y_pred, average, labels, pos_label)\u001b[0m\n\u001b[0;32m 1576\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m y_type \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmulticlass\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m 1577\u001b[0m average_options\u001b[38;5;241m.\u001b[39mremove(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msamples\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m-> 1578\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 1579\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTarget is \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m but average=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m. 
Please \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1580\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mchoose another average setting, one of \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m (y_type, average_options)\n\u001b[0;32m 1581\u001b[0m )\n\u001b[0;32m 1582\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m pos_label \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m1\u001b[39m):\n\u001b[0;32m 1583\u001b[0m warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[0;32m 1584\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNote that pos_label (set to \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m) is ignored when \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 1585\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124maverage != \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbinary\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m (got \u001b[39m\u001b[38;5;132;01m%r\u001b[39;00m\u001b[38;5;124m). You may use \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 1588\u001b[0m \u001b[38;5;167;01mUserWarning\u001b[39;00m,\n\u001b[0;32m 1589\u001b[0m )\n", "\u001b[1;31mValueError\u001b[0m: Target is multiclass but average='binary'. Please choose another average setting, one of [None, 'micro', 'macro', 'weighted']." 
] } ], "source": [
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "\n",
    "# Age_category is a multiclass target, so the binary-only recipe\n",
    "# (predict_proba(...)[:, 1], a 0.5 threshold, the default average=\"binary\")\n",
    "# raises \"Target is multiclass but average='binary'\" and would also yield\n",
    "# 0/1 predictions that never match the category labels. Instead: take labels\n",
    "# from predict(), keep the full probability matrix for one-vs-rest ROC AUC,\n",
    "# and average the per-class scores explicitly.\n",
    "AVERAGE = \"macro\"  # unweighted mean over classes; \"weighted\" accounts for class imbalance\n",
    "\n",
    "for model_name in class_models.keys():\n",
    "    print(f\"Model: {model_name}\")\n",
    "    model = class_models[model_name][\"model\"]\n",
    "\n",
    "    # Flatten the targets once so fit() and every metric get 1-D label arrays\n",
    "    y_train_true = np.ravel(y_train)\n",
    "    y_test_true = np.ravel(y_test)\n",
    "\n",
    "    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
    "    model_pipeline = model_pipeline.fit(X_train, y_train_true)\n",
    "\n",
    "    y_train_predict = model_pipeline.predict(X_train)\n",
    "    y_test_predict = model_pipeline.predict(X_test)\n",
    "    # One probability column per class, in the order of model_pipeline.classes_\n",
    "    y_test_probs = model_pipeline.predict_proba(X_test)\n",
    "\n",
    "    class_models[model_name][\"pipeline\"] = model_pipeline\n",
    "    class_models[model_name][\"probs\"] = y_test_probs\n",
    "    class_models[model_name][\"preds\"] = y_test_predict\n",
    "\n",
    "    class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
    "        y_train_true, y_train_predict, average=AVERAGE\n",
    "    )\n",
    "    class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
    "        y_test_true, y_test_predict, average=AVERAGE\n",
    "    )\n",
    "    class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
    "        y_train_true, y_train_predict, average=AVERAGE\n",
    "    )\n",
    "    class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
    "        y_test_true, y_test_predict, average=AVERAGE\n",
    "    )\n",
    "    class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
    "        y_train_true, y_train_predict\n",
    "    )\n",
    "    class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
    "        y_test_true, y_test_predict\n",
    "    )\n",
    "    # Multiclass ROC AUC needs the full probability matrix; labels pins the\n",
    "    # column order to the fitted classes.\n",
    "    class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
    "        y_test_true,\n",
    "        y_test_probs,\n",
    "        multi_class=\"ovr\",\n",
    "        average=AVERAGE,\n",
    "        labels=model_pipeline.classes_,\n",
    "    )\n",
    "    class_models[model_name][\"F1_train\"] = metrics.f1_score(\n",
    "        y_train_true, y_train_predict, average=AVERAGE\n",
    "    )\n",
    "    class_models[model_name][\"F1_test\"] = metrics.f1_score(\n",
    "        y_test_true, y_test_predict, average=AVERAGE\n",
    "    )\n",
    "    class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
    "        y_test_true, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
    "        y_test_true, y_test_predict\n",
    "    )\n",
    "    class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
    "        y_test_true, y_test_predict, labels=model_pipeline.classes_\n",
    "    )"
   ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.6" } }, "nbformat": 4, "nbformat_minor": 2 }