1906 lines
224 KiB
Plaintext
Raw Normal View History

2024-11-15 22:35:48 +04:00
{
"cells": [
{
"cell_type": "code",
2024-11-29 00:30:06 +04:00
"execution_count": null,
2024-11-15 22:35:48 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Rank ', 'Name', 'Networth', 'Age', 'Country', 'Source', 'Industry'], dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n",
"print(df.columns)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Определим бизнес цели:\n",
2024-11-29 02:58:31 +04:00
"## 1- Прогнозирование состояния миллиардера(регрессия)\n",
"## 2- Прогнозирование возраста миллиардера(классификация)\n"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-29 02:58:31 +04:00
"# Проверим данные на пустые значения"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-15 22:37:33 +04:00
"Rank 0\n",
"Name 0\n",
"Networth 0\n",
"Age 0\n",
"Country 0\n",
"Source 0\n",
"Industry 0\n",
"dtype: int64\n",
"\n",
"Rank False\n",
"Name False\n",
"Networth False\n",
"Age False\n",
"Country False\n",
"Source False\n",
"Industry False\n",
"dtype: bool\n",
"\n"
2024-11-15 22:35:48 +04:00
]
}
],
"source": [
2024-11-15 22:37:33 +04:00
"print(df.isnull().sum())\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-15 22:37:33 +04:00
"print()\n",
"\n",
"# Есть ли пустые значения признаков\n",
"print(df.isnull().any())\n",
"\n",
"print()\n",
"\n",
"# Процент пустых значений признаков\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f\"{i} процент пустых значений: %{null_rate:.2f}\")"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "code",
2024-11-29 02:58:31 +04:00
"execution_count": 51,
2024-11-15 22:35:48 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-29 02:58:31 +04:00
"Collecting Jinja2\n",
" Downloading jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)\n",
"Collecting MarkupSafe>=2.0 (from Jinja2)\n",
" Downloading MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl.metadata (4.1 kB)\n",
"Downloading jinja2-3.1.4-py3-none-any.whl (133 kB)\n",
"Downloading MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl (15 kB)\n",
"Installing collected packages: MarkupSafe, Jinja2\n",
"Successfully installed Jinja2-3.1.4 MarkupSafe-3.0.2\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-11-15 22:35:48 +04:00
"\n",
2024-11-29 02:58:31 +04:00
"[notice] A new release of pip is available: 24.2 -> 24.3.1\n",
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
2024-11-15 22:35:48 +04:00
]
}
],
"source": [
2024-11-29 02:58:31 +04:00
"pip install Jinja2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Задача регрессии"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Создадим выборки"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n",
"X = df.drop(columns=['Networth','Rank ', 'Name']) # Признаки\n",
"y = df['Networth'] # Целевая переменная для регрессии\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-29 02:58:31 +04:00
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Формирование конвейера для классификации данных\n",
"## preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"## preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"## features_preprocessing -- трансформер для предобработки признаков\n",
"## features_engineering -- трансформер для конструирования признаков\n",
"## drop_columns -- трансформер для удаления колонок\n",
"## pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
2024-11-15 22:37:33 +04:00
]
},
{
"cell_type": "code",
2024-11-29 02:58:31 +04:00
"execution_count": 2,
2024-11-15 22:37:33 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
2024-11-29 02:58:31 +04:00
" <th>prepocessing_num__Age</th>\n",
" <th>prepocessing_cat__Country_Argentina</th>\n",
" <th>prepocessing_cat__Country_Australia</th>\n",
" <th>prepocessing_cat__Country_Austria</th>\n",
" <th>prepocessing_cat__Country_Barbados</th>\n",
" <th>prepocessing_cat__Country_Belgium</th>\n",
" <th>prepocessing_cat__Country_Belize</th>\n",
" <th>prepocessing_cat__Country_Brazil</th>\n",
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
" <th>prepocessing_cat__Country_Canada</th>\n",
" <th>...</th>\n",
" <th>prepocessing_cat__Industry_Logistics</th>\n",
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
" <th>prepocessing_cat__Industry_Media &amp; Entertainment</th>\n",
" <th>prepocessing_cat__Industry_Metals &amp; Mining</th>\n",
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
" <th>prepocessing_cat__Industry_Service</th>\n",
" <th>prepocessing_cat__Industry_Sports</th>\n",
" <th>prepocessing_cat__Industry_Technology</th>\n",
" <th>prepocessing_cat__Industry_Telecom</th>\n",
" <th>prepocessing_cat__Industry_diversified</th>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>582</th>\n",
" <td>-0.109934</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>48</th>\n",
" <td>1.079079</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1772</th>\n",
" <td>1.004766</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>964</th>\n",
" <td>-0.407187</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>2213</th>\n",
" <td>1.302019</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
2024-11-29 02:58:31 +04:00
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1638</th>\n",
" <td>1.227706</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" <td>...</td>\n",
2024-11-29 02:58:31 +04:00
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" <td>1.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1095</th>\n",
" <td>0.856139</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" <td>...</td>\n",
2024-11-29 02:58:31 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1130</th>\n",
" <td>0.781826</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1294</th>\n",
" <td>0.335946</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>860</th>\n",
" <td>0.558886</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-15 22:37:33 +04:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-11-29 02:58:31 +04:00
"<p>2080 rows × 855 columns</p>\n",
2024-11-15 22:37:33 +04:00
"</div>"
],
"text/plain": [
2024-11-29 02:58:31 +04:00
" prepocessing_num__Age prepocessing_cat__Country_Argentina \\\n",
"582 -0.109934 0.0 \n",
"48 1.079079 0.0 \n",
"1772 1.004766 0.0 \n",
"964 -0.407187 0.0 \n",
"2213 1.302019 0.0 \n",
"... ... ... \n",
"1638 1.227706 0.0 \n",
"1095 0.856139 0.0 \n",
"1130 0.781826 0.0 \n",
"1294 0.335946 0.0 \n",
"860 0.558886 0.0 \n",
2024-11-15 22:37:33 +04:00
"\n",
2024-11-29 02:58:31 +04:00
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 1.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
2024-11-15 22:37:33 +04:00
"\n",
2024-11-29 02:58:31 +04:00
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
2024-11-15 22:37:33 +04:00
"\n",
2024-11-29 02:58:31 +04:00
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 1.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 1.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
2024-11-15 22:37:33 +04:00
"\n",
2024-11-29 02:58:31 +04:00
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"582 ... 0.0 \n",
"48 ... 0.0 \n",
"1772 ... 0.0 \n",
"964 ... 0.0 \n",
"2213 ... 0.0 \n",
"... ... ... \n",
"1638 ... 0.0 \n",
"1095 ... 0.0 \n",
"1130 ... 0.0 \n",
"1294 ... 0.0 \n",
"860 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"582 0.0 \n",
"48 1.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 1.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 1.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"582 1.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 1.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
2024-11-15 22:37:33 +04:00
"\n",
2024-11-29 02:58:31 +04:00
"[2080 rows x 855 columns]"
2024-11-15 22:37:33 +04:00
]
},
2024-11-29 02:58:31 +04:00
"execution_count": 2,
2024-11-15 22:37:33 +04:00
"metadata": {},
2024-11-29 02:58:31 +04:00
"output_type": "execute_result"
2024-11-15 22:37:33 +04:00
}
],
"source": [
2024-11-29 02:58:31 +04:00
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import pandas as pd\n",
"\n",
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
"columns_to_drop = []\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
"]\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-29 02:58:31 +04:00
"# Предобработка числовых данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
2024-11-15 22:37:33 +04:00
")\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-29 02:58:31 +04:00
"# Предобработка категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
2024-11-15 22:35:48 +04:00
"\n",
2024-11-29 02:58:31 +04:00
"# Общая предобработка признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
2024-11-15 22:35:48 +04:00
"source": [
2024-11-29 02:58:31 +04:00
"# Формирование набора моделей\n",
"## LinearRegression -- логистическая регрессия\n",
"## RandomForestRegressor -- метод случайного леса (набор деревьев решений)\n",
"## GradientBoostingRegressor -- метод градиентного бустинга (набор деревьев решений)\n",
"# Обучение этих моделей с применением RandomizedSearchCV(для подбора гиперпараметров)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training LinearRegression...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 1 is smaller than n_iter=10. Running 1 iterations. For exhaustive searches, use GridSearchCV.\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training RandomForestRegressor...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training GradientBoostingRegressor...\n",
"\n",
"Model: LinearRegression\n",
"Best Params: {}\n",
"MAE: 18059903.80176681\n",
"RMSE: 411829080.6584508\n",
"R2: -7135788186375614.0\n",
"\n",
"Model: RandomForestRegressor\n",
"Best Params: {'model__n_estimators': 40, 'model__max_depth': 10}\n",
"MAE: 3.454630023161808\n",
"RMSE: 7.755775760541111\n",
"R2: -1.530803448377045\n",
"\n",
"Model: GradientBoostingRegressor\n",
"Best Params: {'model__n_estimators': 100, 'model__max_depth': 4, 'model__learning_rate': 0.4}\n",
"MAE: 3.585784679817764\n",
"RMSE: 10.312249036012052\n",
"R2: -3.474193004771121\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.model_selection import GridSearchCV, RandomizedSearchCV\n",
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
"import matplotlib.pyplot as plt\n",
"\n",
"random_state = 42\n",
"\n",
"# Модели и параметры\n",
"models_regression = {\n",
" \"LinearRegression\": LinearRegression(),\n",
" \"RandomForestRegressor\": RandomForestRegressor(random_state=random_state),\n",
" \"GradientBoostingRegressor\": GradientBoostingRegressor(random_state=random_state)\n",
"}\n",
"\n",
"param_grids_regression = {\n",
" \"LinearRegression\": {},\n",
" \"RandomForestRegressor\": {\n",
" 'model__n_estimators': [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" 'model__max_depth': [None, 2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" },\n",
" \"GradientBoostingRegressor\": {\n",
" 'model__n_estimators': [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" 'model__learning_rate': [0.01, 0.1, 0.2, 0.3, 0.4, 0.5],\n",
" 'model__max_depth': [2, 3, 4, 5, 6, 7, 8, 9 ,10]\n",
" }\n",
"}\n",
"\n",
"# Результаты\n",
"results_regression = {}\n",
"\n",
"# Перебор моделей\n",
"for name, model in models_regression.items():\n",
" print(f\"Training {name}...\")\n",
" pipeline = Pipeline(steps=[\n",
" ('features_preprocessing', features_preprocessing),\n",
" ('model', model)\n",
" ])\n",
" param_grid = param_grids_regression[name]\n",
" grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='neg_mean_absolute_error', n_jobs=-1)\n",
" grid_search.fit(X_train, y_train)\n",
"\n",
" # Лучшая модель\n",
" best_model = grid_search.best_estimator_\n",
" y_pred = best_model.predict(X_test)\n",
"\n",
" # Метрики\n",
" mae = mean_absolute_error(y_test, y_pred)\n",
" rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
" r2 = r2_score(y_test, y_pred)\n",
"\n",
" # Сохранение результатов\n",
" results_regression[name] = {\n",
" \"Best Params\": grid_search.best_params_,\n",
" \"MAE\": mae,\n",
" \"RMSE\": rmse,\n",
" \"R2\": r2\n",
" }\n",
"\n",
"# Печать результатов\n",
"for name, metrics in results_regression.items():\n",
" print(f\"\\nModel: {name}\")\n",
" for metric, value in metrics.items():\n",
" print(f\"{metric}: {value}\")"
2024-11-15 22:35:48 +04:00
]
},
{
"cell_type": "code",
2024-11-29 02:58:31 +04:00
"execution_count": 14,
2024-11-15 22:37:33 +04:00
"metadata": {},
2024-11-15 23:33:34 +04:00
"outputs": [
{
"data": {
"text/html": [
2024-11-29 02:58:31 +04:00
"<style type=\"text/css\">\n",
"#T_5e893_row0_col0, #T_5e893_row0_col1, #T_5e893_row1_col0, #T_5e893_row1_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_5e893_row0_col2, #T_5e893_row1_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_5e893_row2_col0, #T_5e893_row2_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_5e893_row2_col2 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
2024-11-15 23:33:34 +04:00
"</style>\n",
2024-11-29 02:58:31 +04:00
"<table id=\"T_5e893\">\n",
2024-11-15 23:33:34 +04:00
" <thead>\n",
2024-11-29 02:58:31 +04:00
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_5e893_level0_col0\" class=\"col_heading level0 col0\" >MAE</th>\n",
" <th id=\"T_5e893_level0_col1\" class=\"col_heading level0 col1\" >RMSE</th>\n",
" <th id=\"T_5e893_level0_col2\" class=\"col_heading level0 col2\" >R2</th>\n",
2024-11-15 23:33:34 +04:00
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th id=\"T_5e893_level0_row0\" class=\"row_heading level0 row0\" >RandomForestRegressor</th>\n",
" <td id=\"T_5e893_row0_col0\" class=\"data row0 col0\" >3.454630</td>\n",
" <td id=\"T_5e893_row0_col1\" class=\"data row0 col1\" >7.755776</td>\n",
" <td id=\"T_5e893_row0_col2\" class=\"data row0 col2\" >-1.530803</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_5e893_level0_row1\" class=\"row_heading level0 row1\" >GradientBoostingRegressor</th>\n",
" <td id=\"T_5e893_row1_col0\" class=\"data row1 col0\" >3.585785</td>\n",
" <td id=\"T_5e893_row1_col1\" class=\"data row1 col1\" >10.312249</td>\n",
" <td id=\"T_5e893_row1_col2\" class=\"data row1 col2\" >-3.474193</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_5e893_level0_row2\" class=\"row_heading level0 row2\" >LinearRegression</th>\n",
" <td id=\"T_5e893_row2_col0\" class=\"data row2 col0\" >18059903.801767</td>\n",
" <td id=\"T_5e893_row2_col1\" class=\"data row2 col1\" >411829080.658451</td>\n",
" <td id=\"T_5e893_row2_col2\" class=\"data row2 col2\" >-7135788186375614.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x182a6043ef0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Импортируем pandas для работы с таблицами\n",
"import pandas as pd\n",
"\n",
"# Формируем таблицу метрик\n",
"reg_metrics = pd.DataFrame.from_dict(results_regression, orient=\"index\")[\n",
" [\"MAE\", \"RMSE\", \"R2\"]\n",
"]\n",
"\n",
"# Визуализация результатов с помощью стилизации\n",
"styled_metrics = (\n",
" reg_metrics.sort_values(by=\"RMSE\")\n",
" .style.background_gradient(cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE\", \"MAE\"])\n",
" .background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"R2\"])\n",
")\n",
"\n",
"styled_metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Шикарный вывод: по стране, возрасту, сфере деятельности и источнику доходов невозможно предсказать состояние человека. Значит ли это, что кто угодно, где угодно, и в чём угодно может добиться успеха?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Классификация"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Категоризируем колонку возраста миллиардеров"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Rank Name Networth Country \\\n",
"0 1 Elon Musk 219.0 United States \n",
"1 2 Jeff Bezos 171.0 United States \n",
"2 3 Bernard Arnault & family 158.0 France \n",
"3 4 Bill Gates 129.0 United States \n",
"4 5 Warren Buffett 118.0 United States \n",
"\n",
" Source Industry Age_category \n",
"0 Tesla, SpaceX Automotive 50-60 \n",
"1 Amazon Technology 50-60 \n",
"2 LVMH Fashion & Retail 70-80 \n",
"3 Microsoft Technology 60-70 \n",
"4 Berkshire Hathaway Finance & Investments 80+ \n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"df = pd.read_csv(\"C://Users//annal//aim//static//csv//Forbes_Billionaires.csv\")\n",
"\n",
"bins = [0, 30, 40, 50, 60, 70, 80, 101] # границы для возрастных категорий\n",
"labels = ['Under 30', '30-40', '40-50', '50-60', '60-70', '70-80', '80+'] # метки для категорий\n",
"\n",
"df[\"Age_category\"] = pd.cut(df['Age'], bins=bins, labels=labels, right=False)\n",
"# Удаляем оригинальные колонки 'country', 'industry' и 'source' из исходного DataFrame\n",
"df.drop(columns=['Age'], inplace=True)\n",
"\n",
"# Просмотр результата\n",
"print(df.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Создадим выборки"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X = df.drop(columns=['Age_category','Rank ', 'Name']) # Признаки\n",
"# Целевая переменная для классификации\n",
"y_class = df['Age_category'] \n",
"\n",
"# Разделение данных\n",
"X_train_clf, X_test_clf, y_train_clf, y_test_clf = train_test_split(X, y_class, test_size=0.2, random_state=42)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Вновь запустим конвейер"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prepocessing_num__Networth</th>\n",
" <th>prepocessing_cat__Country_Argentina</th>\n",
" <th>prepocessing_cat__Country_Australia</th>\n",
" <th>prepocessing_cat__Country_Austria</th>\n",
" <th>prepocessing_cat__Country_Barbados</th>\n",
" <th>prepocessing_cat__Country_Belgium</th>\n",
" <th>prepocessing_cat__Country_Belize</th>\n",
" <th>prepocessing_cat__Country_Brazil</th>\n",
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
" <th>prepocessing_cat__Country_Canada</th>\n",
" <th>...</th>\n",
" <th>prepocessing_cat__Industry_Logistics</th>\n",
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
" <th>prepocessing_cat__Industry_Media &amp; Entertainment</th>\n",
" <th>prepocessing_cat__Industry_Metals &amp; Mining</th>\n",
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
" <th>prepocessing_cat__Industry_Service</th>\n",
" <th>prepocessing_cat__Industry_Sports</th>\n",
" <th>prepocessing_cat__Industry_Technology</th>\n",
" <th>prepocessing_cat__Industry_Telecom</th>\n",
" <th>prepocessing_cat__Industry_diversified</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>582</th>\n",
" <td>-0.013606</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>48</th>\n",
" <td>1.994083</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1772</th>\n",
" <td>-0.288162</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>964</th>\n",
" <td>-0.159464</td>\n",
" <td>0.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>2213</th>\n",
" <td>-0.322481</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>0.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1638</th>\n",
" <td>-0.271002</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1095</th>\n",
" <td>-0.193783</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1130</th>\n",
" <td>-0.193783</td>\n",
" <td>0.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>1294</th>\n",
" <td>-0.228103</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
2024-11-29 02:58:31 +04:00
" <th>860</th>\n",
" <td>-0.133724</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
2024-11-29 02:58:31 +04:00
" <td>1.0</td>\n",
2024-11-15 23:33:34 +04:00
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
2024-11-29 02:58:31 +04:00
"<p>2080 rows × 855 columns</p>\n",
2024-11-15 23:33:34 +04:00
"</div>"
],
"text/plain": [
" prepocessing_num__Networth prepocessing_cat__Country_Argentina \\\n",
2024-11-29 02:58:31 +04:00
"582 -0.013606 0.0 \n",
"48 1.994083 0.0 \n",
"1772 -0.288162 0.0 \n",
"964 -0.159464 0.0 \n",
"2213 -0.322481 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... ... \n",
2024-11-29 02:58:31 +04:00
"1638 -0.271002 0.0 \n",
"1095 -0.193783 0.0 \n",
"1130 -0.193783 0.0 \n",
"1294 -0.228103 0.0 \n",
"860 -0.133724 0.0 \n",
2024-11-15 23:33:34 +04:00
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
2024-11-29 02:58:31 +04:00
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 1.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... ... \n",
2024-11-29 02:58:31 +04:00
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
2024-11-15 23:33:34 +04:00
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
2024-11-29 02:58:31 +04:00
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... ... \n",
2024-11-29 02:58:31 +04:00
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
2024-11-15 23:33:34 +04:00
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
2024-11-29 02:58:31 +04:00
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 1.0 \n",
2024-11-15 23:33:34 +04:00
"... ... ... \n",
2024-11-29 02:58:31 +04:00
"1638 0.0 0.0 \n",
"1095 0.0 1.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
2024-11-15 23:33:34 +04:00
" ... prepocessing_cat__Industry_Logistics \\\n",
2024-11-29 02:58:31 +04:00
"582 ... 0.0 \n",
"48 ... 0.0 \n",
"1772 ... 0.0 \n",
"964 ... 0.0 \n",
"2213 ... 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... ... \n",
2024-11-29 02:58:31 +04:00
"1638 ... 0.0 \n",
"1095 ... 0.0 \n",
"1130 ... 0.0 \n",
"1294 ... 0.0 \n",
"860 ... 0.0 \n",
2024-11-15 23:33:34 +04:00
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
2024-11-29 02:58:31 +04:00
"582 0.0 \n",
"48 1.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... \n",
2024-11-29 02:58:31 +04:00
"1638 1.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 1.0 \n",
2024-11-15 23:33:34 +04:00
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
2024-11-29 02:58:31 +04:00
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... \n",
2024-11-29 02:58:31 +04:00
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
2024-11-15 23:33:34 +04:00
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
2024-11-29 02:58:31 +04:00
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
2024-11-15 23:33:34 +04:00
"... ... \n",
2024-11-29 02:58:31 +04:00
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"582 1.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 1.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
"[2080 rows x 855 columns]"
2024-11-29 00:53:22 +04:00
]
},
2024-11-29 02:58:31 +04:00
"execution_count": 6,
2024-11-29 00:53:22 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2024-11-29 02:58:31 +04:00
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
2024-11-29 00:53:22 +04:00
"import pandas as pd\n",
"\n",
2024-11-29 02:58:31 +04:00
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
"columns_to_drop = []\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in X_train_clf.columns\n",
" if column not in columns_to_drop and X_train_clf[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in X_train_clf.columns\n",
" if column not in columns_to_drop and X_train_clf[column].dtype == \"object\"\n",
2024-11-29 00:53:22 +04:00
"]\n",
"\n",
2024-11-29 02:58:31 +04:00
"# Предобработка числовых данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
2024-11-29 00:53:22 +04:00
")\n",
"\n",
2024-11-29 02:58:31 +04:00
"# Предобработка категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Общая предобработка признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train_clf)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train_clf.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
2024-11-29 01:37:09 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2024-11-29 02:58:31 +04:00
"# Формирование набора моделей\n",
"## LogisticRegression -- логистическая регрессия\n",
"## RandomForestClassifier -- метод случайного леса (набор деревьев решений)\n",
"## KNN -- k-ближайших соседей\n",
"# Обучение этих моделей с применением RandomizedSearchCV(для подбора гиперпараметров)"
2024-11-29 01:37:09 +04:00
]
},
{
"cell_type": "code",
2024-11-29 02:58:31 +04:00
"execution_count": 10,
2024-11-29 01:37:09 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-29 02:58:31 +04:00
"Training LogisticRegression...\n"
2024-11-29 01:37:09 +04:00
]
2024-11-29 02:58:31 +04:00
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 3 is smaller than n_iter=10. Running 3 iterations. For exhaustive searches, use GridSearchCV.\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [nan nan nan]\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
]
},
2024-11-29 01:37:09 +04:00
{
"name": "stdout",
"output_type": "stream",
"text": [
2024-11-29 02:58:31 +04:00
"Training RandomForestClassifier...\n"
2024-11-29 01:37:09 +04:00
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
2024-11-29 02:58:31 +04:00
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan]\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
2024-11-29 01:37:09 +04:00
" warnings.warn(\n"
]
},
{
2024-11-29 02:58:31 +04:00
"name": "stdout",
"output_type": "stream",
"text": [
"Training KNN...\n",
"\n",
"Model: LogisticRegression\n",
"Best Params: {'model__C': 0.1}\n",
"Accuracy: 0.3903846153846154\n",
"F1 Score: 0.20313635491500218\n",
"Confusion_matrix: [[ 0 1 2 6 1 0 0]\n",
" [ 0 1 27 18 7 0 0]\n",
" [ 0 1 82 35 13 3 0]\n",
" [ 0 1 45 80 34 4 0]\n",
" [ 0 0 15 51 37 4 0]\n",
" [ 0 0 5 28 14 3 0]\n",
" [ 0 0 0 2 0 0 0]]\n",
"\n",
"Model: RandomForestClassifier\n",
"Best Params: {'model__n_estimators': 200, 'model__max_features': 'sqrt', 'model__max_depth': 7, 'model__criterion': 'gini', 'model__class_weight': 'balanced'}\n",
"Accuracy: 0.29615384615384616\n",
"F1 Score: 0.23917948939202166\n",
"Confusion_matrix: [[ 2 3 1 1 0 1 2]\n",
" [ 1 21 11 4 2 14 0]\n",
" [ 1 18 65 7 12 31 0]\n",
" [ 2 23 35 12 20 70 2]\n",
" [ 1 4 12 3 20 65 2]\n",
" [ 0 5 1 5 5 34 0]\n",
" [ 1 0 0 1 0 0 0]]\n",
"\n",
"Model: KNN\n",
"Best Params: {'model__weights': 'uniform', 'model__n_neighbors': 3}\n",
"Accuracy: 0.32884615384615384\n",
"F1 Score: 0.23870853259159636\n",
"Confusion_matrix: [[ 3 0 4 2 1 0 0]\n",
" [ 4 19 13 10 6 1 0]\n",
" [ 8 14 65 27 15 5 0]\n",
" [ 9 14 49 53 29 10 0]\n",
" [ 8 8 28 25 24 14 0]\n",
" [ 0 4 9 18 12 7 0]\n",
" [ 1 0 0 1 0 0 0]]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan]\n",
" warnings.warn(\n",
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0, 1] during transform. These unknown categories will be encoded as all zeros\n",
" warnings.warn(\n"
2024-11-29 01:37:09 +04:00
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
2024-11-29 02:58:31 +04:00
"from sklearn.model_selection import GridSearchCV, RandomizedSearchCV\n",
2024-11-29 01:37:09 +04:00
"from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
2024-11-29 02:58:31 +04:00
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import Pipeline\n",
2024-11-29 01:37:09 +04:00
"\n",
"\n",
"# Модели и параметры\n",
"models_classification = {\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"KNN\": KNeighborsClassifier()\n",
"}\n",
"\n",
"param_grids_classification = {\n",
" \"LogisticRegression\": {\n",
" 'model__C': [0.1, 1, 10]\n",
" },\n",
" \"RandomForestClassifier\": {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10, 20],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
2024-11-29 02:58:31 +04:00
" \"model__class_weight\": [\"balanced\"]\n",
2024-11-29 01:37:09 +04:00
" },\n",
" \"KNN\": {\n",
" 'model__n_neighbors': [3, 5, 7, 9, 11],\n",
" 'model__weights': ['uniform', 'distance']\n",
" }\n",
"}\n",
"\n",
"# Результаты\n",
"results_classification = {}\n",
"\n",
"# Перебор моделей\n",
"for name, model in models_classification.items():\n",
" print(f\"Training {name}...\")\n",
" pipeline = Pipeline(steps=[\n",
" ('features_preprocessing', features_preprocessing),\n",
" ('model', model)\n",
" ])\n",
" param_grid = param_grids_classification[name]\n",
" grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='f1', n_jobs=-1)\n",
" grid_search.fit(X_train_clf, y_train_clf)\n",
"\n",
" # Лучшая модель\n",
" best_model = grid_search.best_estimator_\n",
" y_pred = best_model.predict(X_test_clf)\n",
"\n",
" # Метрики\n",
" acc = accuracy_score(y_test_clf, y_pred)\n",
2024-11-29 02:58:31 +04:00
" f1 = f1_score(y_test_clf, y_pred, average=\"macro\")\n",
2024-11-29 01:37:09 +04:00
"\n",
" # Вычисление матрицы ошибок\n",
" c_matrix = confusion_matrix(y_test_clf, y_pred)\n",
"\n",
" # Сохранение результатов\n",
" results_classification[name] = {\n",
" \"Best Params\": grid_search.best_params_,\n",
" \"Accuracy\": acc,\n",
" \"F1 Score\": f1,\n",
" \"Confusion_matrix\": c_matrix\n",
" }\n",
"\n",
"# Печать результатов\n",
"for name, metrics in results_classification.items():\n",
" print(f\"\\nModel: {name}\")\n",
" for metric, value in metrics.items():\n",
" print(f\"{metric}: {value}\")"
]
2024-11-29 02:58:31 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Покажем матрицы в виде диаграмм"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABaEAAAbVCAYAAAAtZQkZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU5dfG8XvTAyEJhJDQO4QqAlIEBBEICCgSQBCVIooUKSoqP0GKAoIoqBRFEbAgiiKCSBdQmkhTKdI7JJSQBAipu+8febO6JIS6mdnk+7muuXRnZnfODpvds2fPPI/FZrPZBAAAAAAAAACAE7gZHQAAAAAAAAAAIOeiCA0AAAAAAAAAcBqK0AAAAAAAAAAAp6EIDQAAAAAAAABwGorQAAAAAAAAAACnoQgNAAAAAAAAAHAaitAAAAAAAAAAAKehCA0AAAAAAAAAcBqK0AAAAAAAAAAAp6EIDQAAAAAAAABwGorQAAAAAAAAAHAXlSpVShaLJcPSr18/SVJCQoL69eunoKAg+fn5KSIiQlFRUQZH7TwWm81mMzoIAAAAAAAAAMgpzp07p9TUVPvtXbt2qXnz5lqzZo2aNGmiPn36aMmSJZo9e7YCAgLUv39/ubm5acOGDQZG7TwUoQEAAAAAAADAiQYNGqSffvpJBw4cUFxcnIKDgzV37lx16NBBkvTPP/+oUqVK2rRpk+rVq2dwtHefh9EBAAAAwLkSEhKUlJRkdBh2Xl5e8vHxMToMAAAAuAAz5bI2m00Wi8Vhnbe3t7y9vbO8X1JSkr788ku9+OKLslgs2rZtm5KTk9WsWTP7PmFhYSpRogRFaAAAALiehIQElS7pp8izqTfeOZuEhobqyJEjFKIBAACQJbPlsn5+frp8+bLDuhEjRmjkyJFZ3m/hwoWKiYlR9+7dJUmRkZHy8vJSYGCgw34hISGKjIy8ixGbB0VoAACAHCwpKUmRZ1N1bFsp+eczfk7quEtWlax1VElJSRShAQAAkCUz5bLpeeyJEyfk7+9vX3+jLmhJmjlzplq1aqUiRYo4M0RTowgNAACQC/jnc5N/PnejwwAAAABumZlyWX9/f4ci9I0cO3ZMq1at0oIFC+zrQkNDlZSUpJiYGIdu6KioKIWGht7NcE2DIjQAAEAuYJVNVlmNDkNWMSc2AAAAbo0ZctnbzWNnzZqlQoUKqXXr1vZ1tWrVkqenp1avXq2IiAhJ0r59+3T8+HHVr1//rsRrNhShAQAAAAAAAOAus1qtmjVrlrp16yYPj3/LsAEBAXrmmWf04osvqkCBAvL399cLL7yg+vXr58hJCSWK0AAAAAAAAABw161atUrHjx9Xz549M2ybNGmS3NzcFBERocTERIWHh2vatGkGRJk9LDabjWsiAQAAcqi4uDgFBATo7L6Shk/mIqVN6FKo4jHFxsbe0lh6AAAAyH3MlMuSx94Z47+JAAAAAAAAAAByLIrQAAAAAAAAAACnYUxoAACAXCBtRnHjR2EzQwwAAABwLWbIZY0+vqujExoAAAAAAAAA4DR0QgMAAOQCVlllNToIySRRAAAAwJWYIZc1PgLXRic0AAAAAAAAAMBpKEIDAADAtFJTUzV8+HCVLl1avr6+Klu2rN58803ZbP+OyWez2fTGG2+ocOHC8vX1VbNmzXTgwAEDowYAAADwXwzHAQAAkAuk2mxKtRk/mcqtxjB+/HhNnz5dc+bMUZUqVbR161b16NFDAQEBGjBggCRpwoQJ+uCDDzRnzhyVLl1aw4cPV3h4uPbs2SMfHx9nPA0AAABkIzPkskYf39VRhAYAAIBpbdy4UY8++qhat24tSSpVqpS+/vprbdmyRVJaF/TkyZM1bNgwPfroo5Kkzz//XCEhIVq4cKE6d+5sWOwAAAAA0jAcBwAAAEzr/vvv1+rVq7V//35J0p9//qn169erVatWkqQjR44oMjJSzZo1s98nICBAdevW1aZNmwyJGQAAAIAjOqEBAAByAatsssr4SwjTY4iLi3NY7+3tLW9v7wz7v/baa4qLi1NYWJjc3d2VmpqqMWPGqGvXrpKkyMhISVJISIjD/UJCQuzbAAAA4NrMkMsafXxXRyc0AAAAsl3x4sUVEBBgX8aNG5fpft9++62++uorzZ07V9u3b9ecOXM0ceJEzZkzJ5sjBgAAAHC76IQGAABAtjtx4oT8/f3ttzPrgpakIUOG6LXXXrOP7VytWjUdO3ZM48aNU7du3RQaGipJioqKUuHChe33i4qKUo0aNZz3BAAAAADcNIrQAAAAuYBVNqWa4BLC9MsY/f39HYrQ1xMfHy83N8eL99zd3WW1WiVJpUuXVmhoqFavXm0vOsfFxen3339Xnz597m7wAAAAMIQZclmG47gzFKEBAABgWm3bttWYMWNUokQJValSRTt27NB7772nnj17SpIsFosGDRqkt956S+XLl1fp0qU1fPhwFSlSRO3atTM2eAAAAACSKEIDAADkCmaYzCU9jlvx4Ycfavjw4erbt6/Onj2rIkWKqHfv3nrjjTfs+7zyyiu6cuWKnnvuOcXExKhhw4ZatmyZfHx87nb4AAAAMIAZclmjj+/qLDabjTMIAACQQ8XFxSkgIECH/glVvnzGz0l96ZJVZcMiFRsbe1PDcQAAACD3MlMuSx57Z4z/JgIAAAAAAAAAyLEYjgMAACAXSLXZlGqCC+DMEAMAAABcixlyWaOP7+rohAYAAAAAAAAAOA1FaAAAAAAAAACA0zAcBwAAQC5g/f/FaGaIAQAAAK7FDLms0cd3dXRCAwAAAAAAAACchiI0AAAAAAAAAMBpGI4DAAAgF0iVTakyfkZvM8QAAAAA12KGXNbo47s6OqEBAAAAAAAAAE5DJzQAAEAukGpLW4xmhhgAAADgWsyQyxp9fFdHJzQAAAAAAAAAwGkoQgMAAAAAAAAAnIbhOAAAAHIB6/8vRjNDDAAAAHAtZshljT6+q6MTGgAAAAAAAADgNBShAQAAAAAAAABOw3AcAAAAuYBVFqXKYnQYspogBgAAALgWM+Sy5LF3hk5oAAAAAAAAAIDTUIQGAAAAAAAAADgNw3EAAADkAlZb2mI0M8QAAAAA12KGXNbo47s6OqEBAAAAAAAAAE5DERoAAAAAAAAA4DQMxwEAAJALpJpgRvH0OAAAAIBbYYZc1ujjuzo6oQEAAAAAAAAATkMnNAAAQC5ghu6R9DgAAACAW2GGXNbo47s6OqEBAAAAAAAAAE5DERoAAAAAAAAA4DQMxwEAAJALWG0WWW3GX0JohhgAAADgWsyQyxp9fFdHJzQAAAAAAAAAwGkoQgMAAAAAAAAAnIbhOAAAAHIBM8wonh4HAAAAcCvMkMsafXxXRyc0AAAAAAAAAMBpKEIDAAAAAAAAAJyG4TgAAABygVS5KdUE/QepRgcAAAAAl2OGXJY89s4Y/00EAAAAAAAAAJBj0QkNAACQC9hsFlltxk+mYjNBDAAAAHAtZshlyWPvDJ3QAAAAAAAAAACnoQgNAAAAAAAAAHAahuMAAADIBVJlUaqMv4TQDDEAAADAtZghlzX6+K6OTmgAAAAAAAAAgNNQhAYAAAAAAAAAOA3DcQAAAOQCqTY3pdqM7z9ItRkdAQAAAFyNGXJZ8tg7Y/w3EQAAAAAAAABAjkURGgAAAAAAAADgNAzHAQAAkAtYZZHVBP0HVnEdIwAAAG6NGXJZ8tg7Y/w3EQAAAAAAAABAjkUnNAAAQC6QKotSZTE6DFPEAAAAANdihlzW6OO7OjqhAQAAAAAAAABOQxEaAAAAAAAAAOA0DMcBAACQC6Ta3JRqM77/INXGhC4AAAC4NWbIZclj74zx30QAAAAAAAAAADkWRWgAAAAAAAAAgNMwHAcAAEAuYJVFVhPM6G2GGAAAAOBazJDLGn18V0cnNAAAAAAAAADAaShCAwAAAAAAAACchuE4AAAAcgGr3JRqgv4Dq5hVHAAAALfGDLkseeydMf6bCAAAAAAAAAAgx6IIDQAAAAAAAABwGorQAHKMJk2aqEmTJnft8UqVKqXu3bvftcfLTdauXSuLxaK1a9caHQqA/5dqczPNAgA
"text/plain": [
"<Figure size 1700x1700 with 7 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"\n",
"\n",
"num_models = len(results_classification)\n",
"num_rows = (num_models // 2) + (num_models % 2) # Количество строк для подграфиков\n",
"_, ax = plt.subplots(num_rows, 2, figsize=(17, 17), sharex=False, sharey=False)\n",
"\n",
"for index, (name, metrics) in enumerate(results_classification.items()):\n",
" c_matrix = metrics[\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Under 30\", \"30-40\", \"40-50\", \"50-60\", \"60-70\", \"70-80\", \"80+\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(name)\n",
"\n",
"# Корректировка расположения графиков\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Вывод: возраст удалось предсказать чуть успешнее. Но всё же, датасет не имеет в себе необходимых данных для более точных предсказаний"
]
2024-11-15 22:35:48 +04:00
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}