diff --git a/lab_4/lab4.ipynb b/lab_4/lab4.ipynb
index c240afb..c681fca 100644
--- a/lab_4/lab4.ipynb
+++ b/lab_4/lab4.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -25,8 +25,8 @@
"metadata": {},
"source": [
"# Определим бизнес цели:\n",
- "## 1- Прогнозирование состояния миллиардера(регрессия)\n",
- "## 2- Прогнозирование возраста миллиардера(классификация)"
+ "## 1- Прогнозирование возраста миллиардера(классификация)\n",
+ "## 2- Прогнозирование состояния миллиардера(регрессия)"
]
},
{
@@ -3159,6 +3159,716 @@
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Задача регрессии"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "X = df.drop(columns=['Networth','Rank ', 'Name']) # Признаки\n",
+ "y = df['Networth'] # Целевая переменная для регрессии\n",
+ "\n",
+ "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)"
+ ]
+ },
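+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check of the split (a minimal sketch; it assumes the cell above has already produced X_train, X_test, y_train and y_test): the 80/20 shapes and the spread of the continuous target."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: verify the 80/20 split and inspect the regression target\n",
+    "print(f\"X_train: {X_train.shape}, X_test: {X_test.shape}\")\n",
+    "print(f\"y_train: {y_train.shape}, y_test: {y_test.shape}\")\n",
+    "y_train.describe()  # distribution of Networth in the training part"
+   ]
+  },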
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " prepocessing_num__Age | \n",
+ " prepocessing_cat__Country_Argentina | \n",
+ " prepocessing_cat__Country_Australia | \n",
+ " prepocessing_cat__Country_Austria | \n",
+ " prepocessing_cat__Country_Barbados | \n",
+ " prepocessing_cat__Country_Belgium | \n",
+ " prepocessing_cat__Country_Belize | \n",
+ " prepocessing_cat__Country_Brazil | \n",
+ " prepocessing_cat__Country_Bulgaria | \n",
+ " prepocessing_cat__Country_Canada | \n",
+ " ... | \n",
+ " prepocessing_cat__Industry_Logistics | \n",
+ " prepocessing_cat__Industry_Manufacturing | \n",
+ " prepocessing_cat__Industry_Media & Entertainment | \n",
+ " prepocessing_cat__Industry_Metals & Mining | \n",
+ " prepocessing_cat__Industry_Real Estate | \n",
+ " prepocessing_cat__Industry_Service | \n",
+ " prepocessing_cat__Industry_Sports | \n",
+ " prepocessing_cat__Industry_Technology | \n",
+ " prepocessing_cat__Industry_Telecom | \n",
+ " prepocessing_cat__Industry_diversified | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 582 | \n",
+ " -0.109934 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 1.079079 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1772 | \n",
+ " 1.004766 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 964 | \n",
+ " -0.407187 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 2213 | \n",
+ " 1.302019 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 1638 | \n",
+ " 1.227706 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1095 | \n",
+ " 0.856139 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1130 | \n",
+ " 0.781826 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 1294 | \n",
+ " 0.335946 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 860 | \n",
+ " 0.558886 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
2080 rows × 855 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " prepocessing_num__Age prepocessing_cat__Country_Argentina \\\n",
+ "582 -0.109934 0.0 \n",
+ "48 1.079079 0.0 \n",
+ "1772 1.004766 0.0 \n",
+ "964 -0.407187 0.0 \n",
+ "2213 1.302019 0.0 \n",
+ "... ... ... \n",
+ "1638 1.227706 0.0 \n",
+ "1095 0.856139 0.0 \n",
+ "1130 0.781826 0.0 \n",
+ "1294 0.335946 0.0 \n",
+ "860 0.558886 0.0 \n",
+ "\n",
+ " prepocessing_cat__Country_Australia prepocessing_cat__Country_Austria \\\n",
+ "582 0.0 0.0 \n",
+ "48 0.0 0.0 \n",
+ "1772 1.0 0.0 \n",
+ "964 0.0 0.0 \n",
+ "2213 0.0 0.0 \n",
+ "... ... ... \n",
+ "1638 0.0 0.0 \n",
+ "1095 0.0 0.0 \n",
+ "1130 0.0 0.0 \n",
+ "1294 0.0 0.0 \n",
+ "860 0.0 0.0 \n",
+ "\n",
+ " prepocessing_cat__Country_Barbados prepocessing_cat__Country_Belgium \\\n",
+ "582 0.0 0.0 \n",
+ "48 0.0 0.0 \n",
+ "1772 0.0 0.0 \n",
+ "964 0.0 0.0 \n",
+ "2213 0.0 0.0 \n",
+ "... ... ... \n",
+ "1638 0.0 0.0 \n",
+ "1095 0.0 0.0 \n",
+ "1130 0.0 0.0 \n",
+ "1294 0.0 0.0 \n",
+ "860 0.0 0.0 \n",
+ "\n",
+ " prepocessing_cat__Country_Belize prepocessing_cat__Country_Brazil \\\n",
+ "582 0.0 0.0 \n",
+ "48 0.0 0.0 \n",
+ "1772 0.0 0.0 \n",
+ "964 0.0 0.0 \n",
+ "2213 0.0 1.0 \n",
+ "... ... ... \n",
+ "1638 0.0 0.0 \n",
+ "1095 0.0 1.0 \n",
+ "1130 0.0 0.0 \n",
+ "1294 0.0 0.0 \n",
+ "860 0.0 0.0 \n",
+ "\n",
+ " prepocessing_cat__Country_Bulgaria prepocessing_cat__Country_Canada \\\n",
+ "582 0.0 0.0 \n",
+ "48 0.0 0.0 \n",
+ "1772 0.0 0.0 \n",
+ "964 0.0 0.0 \n",
+ "2213 0.0 0.0 \n",
+ "... ... ... \n",
+ "1638 0.0 0.0 \n",
+ "1095 0.0 0.0 \n",
+ "1130 0.0 0.0 \n",
+ "1294 0.0 0.0 \n",
+ "860 0.0 0.0 \n",
+ "\n",
+ " ... prepocessing_cat__Industry_Logistics \\\n",
+ "582 ... 0.0 \n",
+ "48 ... 0.0 \n",
+ "1772 ... 0.0 \n",
+ "964 ... 0.0 \n",
+ "2213 ... 0.0 \n",
+ "... ... ... \n",
+ "1638 ... 0.0 \n",
+ "1095 ... 0.0 \n",
+ "1130 ... 0.0 \n",
+ "1294 ... 0.0 \n",
+ "860 ... 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Manufacturing \\\n",
+ "582 0.0 \n",
+ "48 1.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 1.0 \n",
+ "1095 0.0 \n",
+ "1130 0.0 \n",
+ "1294 0.0 \n",
+ "860 1.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Media & Entertainment \\\n",
+ "582 0.0 \n",
+ "48 0.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 0.0 \n",
+ "1095 0.0 \n",
+ "1130 0.0 \n",
+ "1294 0.0 \n",
+ "860 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Metals & Mining \\\n",
+ "582 0.0 \n",
+ "48 0.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 0.0 \n",
+ "1095 0.0 \n",
+ "1130 0.0 \n",
+ "1294 0.0 \n",
+ "860 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Real Estate \\\n",
+ "582 1.0 \n",
+ "48 0.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 0.0 \n",
+ "1095 0.0 \n",
+ "1130 1.0 \n",
+ "1294 0.0 \n",
+ "860 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Service prepocessing_cat__Industry_Sports \\\n",
+ "582 0.0 0.0 \n",
+ "48 0.0 0.0 \n",
+ "1772 0.0 0.0 \n",
+ "964 0.0 0.0 \n",
+ "2213 0.0 0.0 \n",
+ "... ... ... \n",
+ "1638 0.0 0.0 \n",
+ "1095 0.0 0.0 \n",
+ "1130 0.0 0.0 \n",
+ "1294 0.0 0.0 \n",
+ "860 0.0 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Technology \\\n",
+ "582 0.0 \n",
+ "48 0.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 0.0 \n",
+ "1095 0.0 \n",
+ "1130 0.0 \n",
+ "1294 0.0 \n",
+ "860 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_Telecom \\\n",
+ "582 0.0 \n",
+ "48 0.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 0.0 \n",
+ "1095 0.0 \n",
+ "1130 0.0 \n",
+ "1294 0.0 \n",
+ "860 0.0 \n",
+ "\n",
+ " prepocessing_cat__Industry_diversified \n",
+ "582 0.0 \n",
+ "48 0.0 \n",
+ "1772 0.0 \n",
+ "964 0.0 \n",
+ "2213 0.0 \n",
+ "... ... \n",
+ "1638 0.0 \n",
+ "1095 0.0 \n",
+ "1130 0.0 \n",
+ "1294 0.0 \n",
+ "860 0.0 \n",
+ "\n",
+ "[2080 rows x 855 columns]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from sklearn.compose import ColumnTransformer\n",
+ "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
+ "from sklearn.impute import SimpleImputer\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "import pandas as pd\n",
+ "\n",
+ "# Исправляем ColumnTransformer с сохранением имен колонок\n",
+ "columns_to_drop = []\n",
+ "\n",
+ "num_columns = [\n",
+ " column\n",
+ " for column in X_train.columns\n",
+ " if column not in columns_to_drop and X_train[column].dtype != \"object\"\n",
+ "]\n",
+ "cat_columns = [\n",
+ " column\n",
+ " for column in X_train.columns\n",
+ " if column not in columns_to_drop and X_train[column].dtype == \"object\"\n",
+ "]\n",
+ "\n",
+ "# Предобработка числовых данных\n",
+ "num_imputer = SimpleImputer(strategy=\"median\")\n",
+ "num_scaler = StandardScaler()\n",
+ "preprocessing_num = Pipeline(\n",
+ " [\n",
+ " (\"imputer\", num_imputer),\n",
+ " (\"scaler\", num_scaler),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "# Предобработка категориальных данных\n",
+ "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
+ "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
+ "preprocessing_cat = Pipeline(\n",
+ " [\n",
+ " (\"imputer\", cat_imputer),\n",
+ " (\"encoder\", cat_encoder),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "# Общая предобработка признаков\n",
+ "features_preprocessing = ColumnTransformer(\n",
+ " verbose_feature_names_out=True, # Сохраняем имена колонок\n",
+ " transformers=[\n",
+ " (\"prepocessing_num\", preprocessing_num, num_columns),\n",
+ " (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
+ " ],\n",
+ " remainder=\"drop\" # Убираем неиспользуемые столбцы\n",
+ ")\n",
+ "\n",
+ "# Итоговый конвейер\n",
+ "pipeline_end = Pipeline(\n",
+ " [\n",
+ " (\"features_preprocessing\", features_preprocessing),\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "# Преобразуем данные\n",
+ "preprocessing_result = pipeline_end.fit_transform(X_train)\n",
+ "\n",
+ "# Создаем DataFrame с правильными именами колонок\n",
+ "preprocessed_df = pd.DataFrame(\n",
+ " preprocessing_result,\n",
+ " columns=pipeline_end.get_feature_names_out(),\n",
+ " index=X_train.index, # Сохраняем индексы\n",
+ ")\n",
+ "\n",
+ "preprocessed_df"
+ ]
+ },
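+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To evaluate on held-out data without leakage, the preprocessing fitted on X_train should be applied to X_test with transform rather than fit_transform. A minimal sketch, assuming pipeline_end has been fitted in the cell above:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: reuse the pipeline fitted on X_train for the test features (no refitting)\n",
+    "preprocessed_test_df = pd.DataFrame(\n",
+    "    pipeline_end.transform(X_test),\n",
+    "    columns=pipeline_end.get_feature_names_out(),\n",
+    "    index=X_test.index,\n",
+    ")\n",
+    "preprocessed_test_df.head()"
+   ]
+  },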
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training LogisticRegression...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:320: UserWarning: The total space of parameters 3 is smaller than n_iter=10. Running 3 iterations. For exhaustive searches, use GridSearchCV.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "ename": "ValueError",
+ "evalue": "\nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[1;32mIn[7], line 44\u001b[0m\n\u001b[0;32m 42\u001b[0m param_grid \u001b[38;5;241m=\u001b[39m param_grids_classification[name]\n\u001b[0;32m 43\u001b[0m grid_search \u001b[38;5;241m=\u001b[39m RandomizedSearchCV(pipeline, param_grid, cv\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, scoring\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mf1\u001b[39m\u001b[38;5;124m'\u001b[39m, n_jobs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m---> 44\u001b[0m \u001b[43mgrid_search\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;66;03m# Лучшая модель\u001b[39;00m\n\u001b[0;32m 47\u001b[0m best_model \u001b[38;5;241m=\u001b[39m grid_search\u001b[38;5;241m.\u001b[39mbest_estimator_\n",
+ "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473\u001b[0m, in \u001b[0;36m_fit_context..decorator..wrapper\u001b[1;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1466\u001b[0m estimator\u001b[38;5;241m.\u001b[39m_validate_params()\n\u001b[0;32m 1468\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m config_context(\n\u001b[0;32m 1469\u001b[0m skip_parameter_validation\u001b[38;5;241m=\u001b[39m(\n\u001b[0;32m 1470\u001b[0m prefer_skip_nested_validation \u001b[38;5;129;01mor\u001b[39;00m global_skip_validation\n\u001b[0;32m 1471\u001b[0m )\n\u001b[0;32m 1472\u001b[0m ):\n\u001b[1;32m-> 1473\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfit_method\u001b[49m\u001b[43m(\u001b[49m\u001b[43mestimator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1019\u001b[0m, in \u001b[0;36mBaseSearchCV.fit\u001b[1;34m(self, X, y, **params)\u001b[0m\n\u001b[0;32m 1013\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_results(\n\u001b[0;32m 1014\u001b[0m all_candidate_params, n_splits, all_out, all_more_results\n\u001b[0;32m 1015\u001b[0m )\n\u001b[0;32m 1017\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m results\n\u001b[1;32m-> 1019\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mevaluate_candidates\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1021\u001b[0m \u001b[38;5;66;03m# multimetric is determined here because in the case of a callable\u001b[39;00m\n\u001b[0;32m 1022\u001b[0m \u001b[38;5;66;03m# self.scoring the return type is only known after calling\u001b[39;00m\n\u001b[0;32m 1023\u001b[0m first_test_score \u001b[38;5;241m=\u001b[39m all_out[\u001b[38;5;241m0\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_scores\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
+ "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1960\u001b[0m, in \u001b[0;36mRandomizedSearchCV._run_search\u001b[1;34m(self, evaluate_candidates)\u001b[0m\n\u001b[0;32m 1958\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_run_search\u001b[39m(\u001b[38;5;28mself\u001b[39m, evaluate_candidates):\n\u001b[0;32m 1959\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Search n_iter candidates from param_distributions\"\"\"\u001b[39;00m\n\u001b[1;32m-> 1960\u001b[0m \u001b[43mevaluate_candidates\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1961\u001b[0m \u001b[43m \u001b[49m\u001b[43mParameterSampler\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 1962\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparam_distributions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_iter\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom_state\u001b[49m\n\u001b[0;32m 1963\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1964\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:996\u001b[0m, in \u001b[0;36mBaseSearchCV.fit..evaluate_candidates\u001b[1;34m(candidate_params, cv, more_results)\u001b[0m\n\u001b[0;32m 989\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m!=\u001b[39m n_candidates \u001b[38;5;241m*\u001b[39m n_splits:\n\u001b[0;32m 990\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[0;32m 991\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcv.split and cv.get_n_splits returned \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 992\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minconsistent results. Expected \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 993\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msplits, got \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(n_splits, \u001b[38;5;28mlen\u001b[39m(out) \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m n_candidates)\n\u001b[0;32m 994\u001b[0m )\n\u001b[1;32m--> 996\u001b[0m \u001b[43m_warn_or_raise_about_fit_failures\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43merror_score\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 998\u001b[0m \u001b[38;5;66;03m# For callable self.scoring, the return type is only know after\u001b[39;00m\n\u001b[0;32m 999\u001b[0m \u001b[38;5;66;03m# calling. If the return type is a dictionary, the error scores\u001b[39;00m\n\u001b[0;32m 1000\u001b[0m \u001b[38;5;66;03m# can now be inserted with the correct key. The type checking\u001b[39;00m\n\u001b[0;32m 1001\u001b[0m \u001b[38;5;66;03m# of out will be done in `_insert_error_scores`.\u001b[39;00m\n\u001b[0;32m 1002\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcallable\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscoring):\n",
+ "File \u001b[1;32mc:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:529\u001b[0m, in \u001b[0;36m_warn_or_raise_about_fit_failures\u001b[1;34m(results, error_score)\u001b[0m\n\u001b[0;32m 522\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_failed_fits \u001b[38;5;241m==\u001b[39m num_fits:\n\u001b[0;32m 523\u001b[0m all_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 524\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mAll the \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 525\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIt is very likely that your model is misconfigured.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 526\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou can try to debug the error by setting error_score=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mraise\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 528\u001b[0m )\n\u001b[1;32m--> 529\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(all_fits_failed_message)\n\u001b[0;32m 531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 532\u001b[0m some_fits_failed_message \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m 533\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mnum_failed_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m fits failed out of a total of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_fits\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 534\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe score on these train-test partitions for these parameters\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 538\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBelow are more details about the failures:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mfit_errors_summary\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 539\u001b[0m )\n",
+ "\u001b[1;31mValueError\u001b[0m: \nAll the 15 fits failed.\nIt is very likely that your model is misconfigured.\nYou can try to debug the error by setting error_score='raise'.\n\nBelow are more details about the failures:\n--------------------------------------------------------------------------------\n15 fits failed with the following error:\nTraceback (most recent call last):\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n estimator.fit(X_train, y_train, **fit_params)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\pipeline.py\", line 473, in fit\n self._final_estimator.fit(Xt, y, **last_step_params[\"fit\"])\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\base.py\", line 1473, in wrapper\n return fit_method(estimator, *args, **kwargs)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py\", line 1231, in fit\n check_classification_targets(y)\n File \"c:\\Users\\annal\\aim\\.venv\\Lib\\site-packages\\sklearn\\utils\\multiclass.py\", line 219, in check_classification_targets\n raise ValueError(\nValueError: Unknown label type: continuous. Maybe you are trying to fit a classifier, which expects discrete classes on a regression target with continuous values.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.model_selection import RandomizedSearchCV\n",
+ "from sklearn.neighbors import KNeighborsClassifier\n",
+ "from sklearn.metrics import accuracy_score, confusion_matrix, f1_score\n",
+ "\n",
+ "\n",
+ "# Модели и параметры\n",
+ "models_classification = {\n",
+ " \"LogisticRegression\": LogisticRegression(max_iter=1000),\n",
+ " \"RandomForestClassifier\": RandomForestClassifier(random_state=42),\n",
+ " \"KNN\": KNeighborsClassifier()\n",
+ "}\n",
+ "\n",
+ "param_grids_classification = {\n",
+ " \"LogisticRegression\": {\n",
+ " 'model__C': [0.1, 1, 10]\n",
+ " },\n",
+ " \"RandomForestClassifier\": {\n",
+ " \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
+ " \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
+ " \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10, 20],\n",
+ " \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
+ " },\n",
+ " \"KNN\": {\n",
+ " 'model__n_neighbors': [3, 5, 7, 9, 11],\n",
+ " 'model__weights': ['uniform', 'distance']\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "# Результаты\n",
+ "results_classification = {}\n",
+ "\n",
+ "# Перебор моделей\n",
+ "for name, model in models_classification.items():\n",
+ " print(f\"Training {name}...\")\n",
+ " pipeline = Pipeline(steps=[\n",
+ " ('features_preprocessing', features_preprocessing),\n",
+ " ('model', model)\n",
+ " ])\n",
+ " \n",
+ " param_grid = param_grids_classification[name]\n",
+ " grid_search = RandomizedSearchCV(pipeline, param_grid, cv=5, scoring='f1', n_jobs=-1)\n",
+ " grid_search.fit(X_train, y_train)\n",
+ "\n",
+ " # Лучшая модель\n",
+ " best_model = grid_search.best_estimator_\n",
+ " y_pred = best_model.predict(X_test)\n",
+ "\n",
+ " # Метрики\n",
+ " acc = accuracy_score(y_test, y_pred)\n",
+ " f1 = f1_score(y_test, y_pred)\n",
+ "\n",
+ " # Вычисление матрицы ошибок\n",
+ " c_matrix = confusion_matrix(y_test, y_pred)\n",
+ "\n",
+ " # Сохранение результатов\n",
+ " results_classification[name] = {\n",
+ " \"Best Params\": grid_search.best_params_,\n",
+ " \"Accuracy\": acc,\n",
+ " \"F1 Score\": f1,\n",
+ " \"Confusion_matrix\": c_matrix\n",
+ " }\n",
+ "\n",
+ "# Печать результатов\n",
+ "for name, metrics in results_classification.items():\n",
+ " print(f\"\\nModel: {name}\")\n",
+ " for metric, value in metrics.items():\n",
+ " print(f\"{metric}: {value}\")"
+ ]
}
],
"metadata": {