коммит
This commit is contained in:
parent
f4ef300f9a
commit
3ee2967c60
716
lab_4/lab4.ipynb
716
lab_4/lab4.ipynb
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -25,8 +25,8 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Определим бизнес цели:\n",
|
||||
"## 1- Прогнозирование состояния миллиардера(регрессия)\n",
|
||||
"## 2- Прогнозирование возраста миллиардера(классификация)"
|
||||
"## 1- Прогнозирование возраста миллиардера(классификация)\n",
|
||||
"## 2- Прогнозирование состояния миллиардера(регрессия)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -3159,6 +3159,716 @@
|
||||
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Задача регрессии"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"X = df.drop(columns=['Networth','Rank ', 'Name']) # Признаки\n",
|
||||
"y = df['Networth'] # Целевая переменная для регрессии\n",
|
||||
"\n",
|
||||
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>prepocessing_num__Age</th>\n",
|
||||
" <th>prepocessing_cat__Country_Argentina</th>\n",
|
||||
" <th>prepocessing_cat__Country_Australia</th>\n",
|
||||
" <th>prepocessing_cat__Country_Austria</th>\n",
|
||||
" <th>prepocessing_cat__Country_Barbados</th>\n",
|
||||
" <th>prepocessing_cat__Country_Belgium</th>\n",
|
||||
" <th>prepocessing_cat__Country_Belize</th>\n",
|
||||
" <th>prepocessing_cat__Country_Brazil</th>\n",
|
||||
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
|
||||
" <th>prepocessing_cat__Country_Canada</th>\n",
|
||||
" <th>...</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Logistics</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Media & Entertainment</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Metals & Mining</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Service</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Sports</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Technology</th>\n",
|
||||
" <th>prepocessing_cat__Industry_Telecom</th>\n",
|
||||
" <th>prepocessing_cat__Industry_diversified</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>582</th>\n",
|
||||
" <td>-0.109934</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>48</th>\n",
|
||||
" <td>1.079079</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1772</th>\n",
|
||||
" <td>1.004766</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>964</th>\n",
|
||||
" <td>-0.407187</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2213</th>\n",
|
||||
" <td>1.302019</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1638</th>\n",
|
||||
" <td>1.227706</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1095</th>\n",
|
||||
" <td>0.856139</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1130</th>\n",
|
||||
" <td>0.781826</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1294</th>\n",
|
||||
" <td>0.335946</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>860</th>\n",
|
||||
" <td>0.558886</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>2080 rows × 855 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" prepocessing_num__Age prepocessing_cat__Country_Argentina \\\n",
|
||||
"582 -0.109934 0.0 \n",
|
||||
"48 1.079079 0.0 \n",
|
||||
"1772 1.004766 0.0 \n",
|
||||
"964 -0.407187 0.0 \n",
|
||||
"2213 1.302019 0.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 1.227706 0.0 \n",
|
||||
"1095 0.856139 0.0 \n",
|
||||
"1130 0.781826 0.0 \n",
|
||||
"1294 0.335946 0.0 \n",
|
||||
"860 0.558886 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
|
||||
"582 0.0 0.0 \n",
|
||||
"48 0.0 0.0 \n",
|
||||
"1772 1.0 0.0 \n",
|
||||
"964 0.0 0.0 \n",
|
||||
"2213 0.0 0.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 0.0 0.0 \n",
|
||||
"1095 0.0 0.0 \n",
|
||||
"1130 0.0 0.0 \n",
|
||||
"1294 0.0 0.0 \n",
|
||||
"860 0.0 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
|
||||
"582 0.0 0.0 \n",
|
||||
"48 0.0 0.0 \n",
|
||||
"1772 0.0 0.0 \n",
|
||||
"964 0.0 0.0 \n",
|
||||
"2213 0.0 0.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 0.0 0.0 \n",
|
||||
"1095 0.0 0.0 \n",
|
||||
"1130 0.0 0.0 \n",
|
||||
"1294 0.0 0.0 \n",
|
||||
"860 0.0 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
|
||||
"582 0.0 0.0 \n",
|
||||
"48 0.0 0.0 \n",
|
||||
"1772 0.0 0.0 \n",
|
||||
"964 0.0 0.0 \n",
|
||||
"2213 0.0 1.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 0.0 0.0 \n",
|
||||
"1095 0.0 1.0 \n",
|
||||
"1130 0.0 0.0 \n",
|
||||
"1294 0.0 0.0 \n",
|
||||
"860 0.0 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
|
||||
"582 0.0 0.0 \n",
|
||||
"48 0.0 0.0 \n",
|
||||
"1772 0.0 0.0 \n",
|
||||
"964 0.0 0.0 \n",
|
||||
"2213 0.0 0.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 0.0 0.0 \n",
|
||||
"1095 0.0 0.0 \n",
|
||||
"1130 0.0 0.0 \n",
|
||||
"1294 0.0 0.0 \n",
|
||||
"860 0.0 0.0 \n",
|
||||
"\n",
|
||||
" ... prepocessing_cat__Industry_Logistics \\\n",
|
||||
"582 ... 0.0 \n",
|
||||
"48 ... 0.0 \n",
|
||||
"1772 ... 0.0 \n",
|
||||
"964 ... 0.0 \n",
|
||||
"2213 ... 0.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 ... 0.0 \n",
|
||||
"1095 ... 0.0 \n",
|
||||
"1130 ... 0.0 \n",
|
||||
"1294 ... 0.0 \n",
|
||||
"860 ... 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Manufacturing \\\n",
|
||||
"582 0.0 \n",
|
||||
"48 1.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 1.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 0.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 1.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Media & Entertainment \\\n",
|
||||
"582 0.0 \n",
|
||||
"48 0.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 0.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 0.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Metals & Mining \\\n",
|
||||
"582 0.0 \n",
|
||||
"48 0.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 0.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 0.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Real Estate \\\n",
|
||||
"582 1.0 \n",
|
||||
"48 0.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 0.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 1.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
|
||||
"582 0.0 0.0 \n",
|
||||
"48 0.0 0.0 \n",
|
||||
"1772 0.0 0.0 \n",
|
||||
"964 0.0 0.0 \n",
|
||||
"2213 0.0 0.0 \n",
|
||||
"... ... ... \n",
|
||||
"1638 0.0 0.0 \n",
|
||||
"1095 0.0 0.0 \n",
|
||||
"1130 0.0 0.0 \n",
|
||||
"1294 0.0 0.0 \n",
|
||||
"860 0.0 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Technology \\\n",
|
||||
"582 0.0 \n",
|
||||
"48 0.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 0.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 0.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_Telecom \\\n",
|
||||
"582 0.0 \n",
|
||||
"48 0.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 0.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 0.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 0.0 \n",
|
||||
"\n",
|
||||
" prepocessing_cat__Industry_diversified \n",
|
||||
"582 0.0 \n",
|
||||
"48 0.0 \n",
|
||||
"1772 0.0 \n",
|
||||
"964 0.0 \n",
|
||||
"2213 0.0 \n",
|
||||
"... ... \n",
|
||||
"1638 0.0 \n",
|
||||
"1095 0.0 \n",
|
||||
"1130 0.0 \n",
|
||||
"1294 0.0 \n",
|
||||
"860 0.0 \n",
|
||||
"\n",
|
||||
"[2080 rows x 855 columns]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.compose import ColumnTransformer\n",
|
||||
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
||||
"from sklearn.impute import SimpleImputer\n",
|
||||
"from sklearn.pipeline import Pipeline\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
|
||||
"columns_to_drop = []\n",
|
||||
"\n",
|
||||
"num_columns = [\n",
|
||||
" column\n",
|
||||
" for column in X_train.columns\n",
|
||||
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
|
||||
"]\n",
|
||||
"cat_columns = [\n",
|
||||
" column\n",
|
||||
" for column in X_train.columns\n",
|
||||
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Предобработка числовых данных\n",
|
||||
"num_imputer = SimpleImputer(strategy=\"median\")\n",
|
||||
"num_scaler = StandardScaler()\n",
|
||||
"preprocessing_num = Pipeline(\n",
|
||||
" [\n",
|
||||
" (\"imputer\", num_imputer),\n",
|
||||
" (\"scaler\", num_scaler),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Предобработка категориальных данных\n",
|
||||
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
|
||||
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
|
||||
"preprocessing_cat = Pipeline(\n",
|
||||
" [\n",
|
||||
" (\"imputer\", cat_imputer),\n",
|
||||
" (\"encoder\", cat_encoder),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Общая предобработка признаков\n",
|
||||
"features_preprocessing = ColumnTransformer(\n",
|
||||
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
|
||||
" transformers=[\n",
|
||||
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
|
||||
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
|
||||
" ],\n",
|
||||
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Итоговый конвейер\n",
|
||||
"pipeline_end = Pipeline(\n",
|
||||
" [\n",
|
||||
" (\"features_preprocessing\", features_preprocessing),\n",
|
||||
" ]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Преобразуем данные\n",
|
||||
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
|
||||
"\n",
|
||||
"# Создаем DataFrame с правильными именами колонок\n",
|
||||
"preprocessed_df = pd.DataFrame(\n",
|
||||
" preprocessing_result,\n",
|
||||
" columns=pipeline_end.get_feature_names_out(),\n",
|
||||
" index=X_train.index, # Сохраняем индексы\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"preprocessed_df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training LogisticRegression...\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 3 is smaller than n_iter=10. Running 3 iterations. For exhaustive searches, use GridSearchCV.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "ValueError",
|
||||
"evalue": "\nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[1;32mIn[7], line 44\u001b[0m\n\u001b[0;32m 42\u001b[0m param_grid \u001b[38;5;241m=\u001b[39m param_grids_classification[name]\n\u001b[0;32m 43\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m RandomizedSearchCV(pipeline, param_grid, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf1\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 44\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;66;03m# Лучшая модель\u001b[39;00m\n\u001b[0;32m 47\u001b[0m best_model \u001b[38;5;241m=\u001b[39m grid_search\u001b[38;5;241m.\u001b[39mbest_estimator_\n",
|
||||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[1;34m(self, X, y, **params)\u001b[0m\n\u001b[0;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[0;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[0;32m 1015\u001b[0m )\n\u001b[0;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[1;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[0;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[0;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
|
||||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1960\u001b[0m, in \u001b[0;36mRandomizedSearchCV._run_search\u001b[1;34m(self, evaluate_candidates)\u001b[0m\n\u001b[0;32m 1958\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[0;32m 1959\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search n_iter candidates from param_distributions\"\"\"\u001b[39;00m\n\u001b[1;32m-> 1960\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1961\u001b[0m \u001b[43m \u001b[49m\u001b[43mParameterSampler\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1962\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_distributions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom_state\u001b[49m\n\u001b[0;32m 1963\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1964\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:996\u001b[0m, in \u001b[0;36mBaseSearchCV.fit.<locals>.evaluate_candidates\u001b[1;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[0;32m 989\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m!=\u001b[39m n_candidates \u001b[38;5;241m*\u001b[39m n_splits:\n\u001b[0;32m 990\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 991\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcv.split and cv.get_n_splits returned \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 992\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minconsistent results. Expected \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 993\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msplits, got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(n_splits, \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m n_candidates)\n\u001b[0;32m 994\u001b[0m )\n\u001b[1;32m--> 996\u001b[0m \u001b[43m_warn_or_raise_about_fit_failures\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;66;03m# For callable self.scoring, the return type is only know after\u001b[39;00m\n\u001b[0;32m 999\u001b[0m \u001b[38;5;66;03m# calling. If the return type is a dictionary, the error scores\u001b[39;00m\n\u001b[0;32m 1000\u001b[0m \u001b[38;5;66;03m# can now be inserted with the correct key. The type checking\u001b[39;00m\n\u001b[0;32m 1001\u001b[0m \u001b[38;5;66;03m# of out will be done in `_insert_error_scores`.\u001b[39;00m\n\u001b[0;32m 1002\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcallable\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscoring):\n",
|
||||
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:529\u001b[0m, in \u001b[0;36m_warn_or_raise_about_fit_failures\u001b[1;34m(results, error_score)\u001b[0m\n\u001b[0;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_failed_fits \u001b[38;5;241m==\u001b[39m num_fits:\n\u001b[0;32m 523\u001b[0m all_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 524\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mAll the \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 525\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIt is very likely that your model is misconfigured.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 526\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou can try to debug the error by setting error_score=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 528\u001b[0m )\n\u001b[1;32m--> 529\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(all_fits_failed_message)\n\u001b[0;32m 531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 532\u001b[0m some_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 533\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mnum_failed_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed out of a total of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 534\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe score on these train-test partitions for these parameters\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 539\u001b[0m )\n",
|
||||
"\u001b[1;31mValueError\u001b[0m: \nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"from sklearn.model_selection import RandomizedSearchCV\n",
|
||||
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||||
"from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Модели и параметры\n",
|
||||
"models_classification = {\n",
|
||||
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
|
||||
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
|
||||
" \"KNN\": KNeighborsClassifier()\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"param_grids_classification = {\n",
|
||||
" \"LogisticRegression\": {\n",
|
||||
" 'model__C': [0.1, 1, 10]\n",
|
||||
" },\n",
|
||||
" \"RandomForestClassifier\": {\n",
|
||||
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
|
||||
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
|
||||
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10, 20],\n",
|
||||
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
|
||||
" },\n",
|
||||
" \"KNN\": {\n",
|
||||
" 'model__n_neighbors': [3, 5, 7, 9, 11],\n",
|
||||
" 'model__weights': ['uniform', 'distance']\n",
|
||||
" }\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# Результаты\n",
|
||||
"results_classification = {}\n",
|
||||
"\n",
|
||||
"# Перебор моделей\n",
|
||||
"for name, model in models_classification.items():\n",
|
||||
" print(f\"Training {name}...\")\n",
|
||||
" pipeline = Pipeline(steps=[\n",
|
||||
" ('features_preprocessing', features_preprocessing),\n",
|
||||
" ('model', model)\n",
|
||||
" ])\n",
|
||||
" \n",
|
||||
" param_grid = param_grids_classification[name]\n",
|
||||
" grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='f1', n_jobs=-1)\n",
|
||||
" grid_search.fit(X_train, y_train)\n",
|
||||
"\n",
|
||||
" # Лучшая модель\n",
|
||||
" best_model = grid_search.best_estimator_\n",
|
||||
" y_pred = best_model.predict(X_test)\n",
|
||||
"\n",
|
||||
" # Метрики\n",
|
||||
" acc = accuracy_score(y_test, y_pred)\n",
|
||||
" f1 = f1_score(y_test, y_pred)\n",
|
||||
"\n",
|
||||
" # Вычисление матрицы ошибок\n",
|
||||
" c_matrix = confusion_matrix(y_test, y_pred)\n",
|
||||
"\n",
|
||||
" # Сохранение результатов\n",
|
||||
" results_classification[name] = {\n",
|
||||
" \"Best Params\": grid_search.best_params_,\n",
|
||||
" \"Accuracy\": acc,\n",
|
||||
" \"F1 Score\": f1,\n",
|
||||
" \"Confusion_matrix\": c_matrix\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"# Печать результатов\n",
|
||||
"for name, metrics in results_classification.items():\n",
|
||||
" print(f\"\\nModel: {name}\")\n",
|
||||
" for metric, value in metrics.items():\n",
|
||||
" print(f\"{metric}: {value}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
Loading…
x
Reference in New Issue
Block a user