коммит

This commit is contained in:
annalyovushkina@yandex.ru 2024-11-29 00:30:06 +04:00
parent f4ef300f9a
commit 3ee2967c60

View File

@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -25,8 +25,8 @@
"metadata": {},
"source": [
"# Определим бизнес цели:\n",
"## 1- Прогнозирование состояния миллиардера(регрессия)\n",
"## 2- Прогнозирование возраста миллиардера(классификация)"
"## 1- Прогнозирование возраста миллиардера(классификация)\n",
"## 2- Прогнозирование состояния миллиардера(регрессия)"
]
},
{
@ -3159,6 +3159,716 @@
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Задача регрессии"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X = df.drop(columns=['Networth','Rank ', 'Name']) # Признаки\n",
"y = df['Networth'] # Целевая переменная для регрессии\n",
"\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>prepocessing_num__Age</th>\n",
" <th>prepocessing_cat__Country_Argentina</th>\n",
" <th>prepocessing_cat__Country_Australia</th>\n",
" <th>prepocessing_cat__Country_Austria</th>\n",
" <th>prepocessing_cat__Country_Barbados</th>\n",
" <th>prepocessing_cat__Country_Belgium</th>\n",
" <th>prepocessing_cat__Country_Belize</th>\n",
" <th>prepocessing_cat__Country_Brazil</th>\n",
" <th>prepocessing_cat__Country_Bulgaria</th>\n",
" <th>prepocessing_cat__Country_Canada</th>\n",
" <th>...</th>\n",
" <th>prepocessing_cat__Industry_Logistics</th>\n",
" <th>prepocessing_cat__Industry_Manufacturing</th>\n",
" <th>prepocessing_cat__Industry_Media &amp; Entertainment</th>\n",
" <th>prepocessing_cat__Industry_Metals &amp; Mining</th>\n",
" <th>prepocessing_cat__Industry_Real Estate</th>\n",
" <th>prepocessing_cat__Industry_Service</th>\n",
" <th>prepocessing_cat__Industry_Sports</th>\n",
" <th>prepocessing_cat__Industry_Technology</th>\n",
" <th>prepocessing_cat__Industry_Telecom</th>\n",
" <th>prepocessing_cat__Industry_diversified</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>582</th>\n",
" <td>-0.109934</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>48</th>\n",
" <td>1.079079</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1772</th>\n",
" <td>1.004766</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>964</th>\n",
" <td>-0.407187</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2213</th>\n",
" <td>1.302019</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1638</th>\n",
" <td>1.227706</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1095</th>\n",
" <td>0.856139</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1130</th>\n",
" <td>0.781826</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1294</th>\n",
" <td>0.335946</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>0.558886</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>2080 rows × 855 columns</p>\n",
"</div>"
],
"text/plain": [
" prepocessing_num__Age prepocessing_cat__Country_Argentina \\\n",
"582 -0.109934 0.0 \n",
"48 1.079079 0.0 \n",
"1772 1.004766 0.0 \n",
"964 -0.407187 0.0 \n",
"2213 1.302019 0.0 \n",
"... ... ... \n",
"1638 1.227706 0.0 \n",
"1095 0.856139 0.0 \n",
"1130 0.781826 0.0 \n",
"1294 0.335946 0.0 \n",
"860 0.558886 0.0 \n",
"\n",
" prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 1.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 1.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 1.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" ... prepocessing_cat__Industry_Logistics \\\n",
"582 ... 0.0 \n",
"48 ... 0.0 \n",
"1772 ... 0.0 \n",
"964 ... 0.0 \n",
"2213 ... 0.0 \n",
"... ... ... \n",
"1638 ... 0.0 \n",
"1095 ... 0.0 \n",
"1130 ... 0.0 \n",
"1294 ... 0.0 \n",
"860 ... 0.0 \n",
"\n",
" prepocessing_cat__Industry_Manufacturing \\\n",
"582 0.0 \n",
"48 1.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 1.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 1.0 \n",
"\n",
" prepocessing_cat__Industry_Media & Entertainment \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Metals & Mining \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Real Estate \\\n",
"582 1.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 1.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
"582 0.0 0.0 \n",
"48 0.0 0.0 \n",
"1772 0.0 0.0 \n",
"964 0.0 0.0 \n",
"2213 0.0 0.0 \n",
"... ... ... \n",
"1638 0.0 0.0 \n",
"1095 0.0 0.0 \n",
"1130 0.0 0.0 \n",
"1294 0.0 0.0 \n",
"860 0.0 0.0 \n",
"\n",
" prepocessing_cat__Industry_Technology \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_Telecom \\\n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
" prepocessing_cat__Industry_diversified \n",
"582 0.0 \n",
"48 0.0 \n",
"1772 0.0 \n",
"964 0.0 \n",
"2213 0.0 \n",
"... ... \n",
"1638 0.0 \n",
"1095 0.0 \n",
"1130 0.0 \n",
"1294 0.0 \n",
"860 0.0 \n",
"\n",
"[2080 rows x 855 columns]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"import pandas as pd\n",
"\n",
"# Исправляем ColumnTransformer с сохранением имен колонок\n",
"columns_to_drop = []\n",
"\n",
"num_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in X_train.columns\n",
" if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
"]\n",
"\n",
"# Предобработка числовых данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Предобработка категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Общая предобработка признаков\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=True, # Сохраняем имена колонок\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
")\n",
"\n",
"# Итоговый конвейер\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" ]\n",
")\n",
"\n",
"# Преобразуем данные\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"\n",
"# Создаем DataFrame с правильными именами колонок\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
" index=X_train.index, # Сохраняем индексы\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training LogisticRegression...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 3 is smaller than n_iter=10. Running 3 iterations. For exhaustive searches, use GridSearchCV.\n",
" warnings.warn(\n"
]
},
{
"ename": "ValueError",
"evalue": "\nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[7], line 44\u001b[0m\n\u001b[0;32m 42\u001b[0m param_grid \u001b[38;5;241m=\u001b[39m param_grids_classification[name]\n\u001b[0;32m 43\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m RandomizedSearchCV(pipeline, param_grid, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf1\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 44\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;66;03m# Лучшая модель\u001b[39;00m\n\u001b[0;32m 47\u001b[0m best_model \u001b[38;5;241m=\u001b[39m grid_search\u001b[38;5;241m.\u001b[39mbest_estimator_\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context.<locals>.decorator.<locals>.wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[1;34m(self, X, y, **params)\u001b[0m\n\u001b[0;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[0;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[0;32m 1015\u001b[0m )\n\u001b[0;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[1;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[0;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[0;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1960\u001b[0m, in \u001b[0;36mRandomizedSearchCV._run_search\u001b[1;34m(self, evaluate_candidates)\u001b[0m\n\u001b[0;32m 1958\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[0;32m 1959\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search n_iter candidates from param_distributions\"\"\"\u001b[39;00m\n\u001b[1;32m-> 1960\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1961\u001b[0m \u001b[43m \u001b[49m\u001b[43mParameterSampler\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1962\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_distributions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom_state\u001b[49m\n\u001b[0;32m 1963\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1964\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:996\u001b[0m, in \u001b[0;36mBaseSearchCV.fit.<locals>.evaluate_candidates\u001b[1;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[0;32m 989\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m!=\u001b[39m n_candidates \u001b[38;5;241m*\u001b[39m n_splits:\n\u001b[0;32m 990\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 991\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcv.split and cv.get_n_splits returned \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 992\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minconsistent results. Expected \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 993\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msplits, got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(n_splits, \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m n_candidates)\n\u001b[0;32m 994\u001b[0m )\n\u001b[1;32m--> 996\u001b[0m \u001b[43m_warn_or_raise_about_fit_failures\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;66;03m# For callable self.scoring, the return type is only know after\u001b[39;00m\n\u001b[0;32m 999\u001b[0m \u001b[38;5;66;03m# calling. If the return type is a dictionary, the error scores\u001b[39;00m\n\u001b[0;32m 1000\u001b[0m \u001b[38;5;66;03m# can now be inserted with the correct key. The type checking\u001b[39;00m\n\u001b[0;32m 1001\u001b[0m \u001b[38;5;66;03m# of out will be done in `_insert_error_scores`.\u001b[39;00m\n\u001b[0;32m 1002\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcallable\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscoring):\n",
"File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:529\u001b[0m, in \u001b[0;36m_warn_or_raise_about_fit_failures\u001b[1;34m(results, error_score)\u001b[0m\n\u001b[0;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_failed_fits \u001b[38;5;241m==\u001b[39m num_fits:\n\u001b[0;32m 523\u001b[0m all_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 524\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mAll the \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 525\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIt is very likely that your model is misconfigured.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 526\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou can try to debug the error by setting error_score=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 528\u001b[0m )\n\u001b[1;32m--> 529\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(all_fits_failed_message)\n\u001b[0;32m 531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 532\u001b[0m some_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 533\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mnum_failed_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed out of a total of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 534\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe score on these train-test partitions for these parameters\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 539\u001b[0m )\n",
"\u001b[1;31mValueError\u001b[0m: \nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import RandomizedSearchCV\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
"\n",
"\n",
"# Модели и параметры\n",
"models_classification = {\n",
" \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
" \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
" \"KNN\": KNeighborsClassifier()\n",
"}\n",
"\n",
"param_grids_classification = {\n",
" \"LogisticRegression\": {\n",
" 'model__C': [0.1, 1, 10]\n",
" },\n",
" \"RandomForestClassifier\": {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10, 20],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
" },\n",
" \"KNN\": {\n",
" 'model__n_neighbors': [3, 5, 7, 9, 11],\n",
" 'model__weights': ['uniform', 'distance']\n",
" }\n",
"}\n",
"\n",
"# Результаты\n",
"results_classification = {}\n",
"\n",
"# Перебор моделей\n",
"for name, model in models_classification.items():\n",
" print(f\"Training {name}...\")\n",
" pipeline = Pipeline(steps=[\n",
" ('features_preprocessing', features_preprocessing),\n",
" ('model', model)\n",
" ])\n",
" \n",
" param_grid = param_grids_classification[name]\n",
" grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='f1', n_jobs=-1)\n",
" grid_search.fit(X_train, y_train)\n",
"\n",
" # Лучшая модель\n",
" best_model = grid_search.best_estimator_\n",
" y_pred = best_model.predict(X_test)\n",
"\n",
" # Метрики\n",
" acc = accuracy_score(y_test, y_pred)\n",
" f1 = f1_score(y_test, y_pred)\n",
"\n",
" # Вычисление матрицы ошибок\n",
" c_matrix = confusion_matrix(y_test, y_pred)\n",
"\n",
" # Сохранение результатов\n",
" results_classification[name] = {\n",
" \"Best Params\": grid_search.best_params_,\n",
" \"Accuracy\": acc,\n",
" \"F1 Score\": f1,\n",
" \"Confusion_matrix\": c_matrix\n",
" }\n",
"\n",
"# Печать результатов\n",
"for name, metrics in results_classification.items():\n",
" print(f\"\\nModel: {name}\")\n",
" for metric, value in metrics.items():\n",
" print(f\"{metric}: {value}\")"
]
}
],
"metadata": {