From 3863672121382f4b167c410aef759d36b4d10b11 Mon Sep 17 00:00:00 2001
From: dex_moth
Date: Sat, 21 Dec 2024 12:33:06 +0400
Subject: [PATCH] correct lab4

---
 lab_4/Lab4.ipynb | 256 +++++++++++++----------------------------------
 1 file changed, 70 insertions(+), 186 deletions(-)

diff --git a/lab_4/Lab4.ipynb b/lab_4/Lab4.ipynb
index 0b8116e..62032e7 100644
--- a/lab_4/Lab4.ipynb
+++ b/lab_4/Lab4.ipynb
@@ -23,6 +23,13 @@
 "import matplotlib.pyplot as plt\n",
 "import seaborn as sns\n",
 "from matplotlib.ticker import FuncFormatter\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.compose import ColumnTransformer\n",
+ "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
+ "from sklearn.model_selection import GridSearchCV\n",
+ "from sklearn.linear_model import Lasso\n",
+ "from sklearn.ensemble import GradientBoostingRegressor\n",
+ "from sklearn.neighbors import KNeighborsRegressor\n",
 "\n",
 "df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n",
 "print(df.columns)"
@@ -293,6 +300,50 @@
 "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)"
 ]
 },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create the pipeline\n",
+ "\n",
+ "# Process the data\n",
+ "# Define the categorical and numerical features\n",
+ "categorical_features = ['Gender', 'City', 'Dietary Habits', 'Degree', 'Have you ever had suicidal thoughts ?', 'Profession', 'Family History of Mental Illness', 'Sleep Duration']\n",
+ "numerical_features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction', 'Work/Study Hours', 'Financial Stress']\n",
+ "\n",
+ "categorical_transformer = Pipeline(steps=[\n",
+ "    ('onehot', OneHotEncoder(handle_unknown='ignore'))\n",
+ "])\n",
+ "\n",
+ "numerical_transformer = Pipeline(steps=[\n",
+ "    ('scaler', StandardScaler())\n",
+ "])\n",
+ "\n",
+ "preprocessor = ColumnTransformer(\n",
+ "    transformers=[\n",
+ "        ('num', numerical_transformer, numerical_features),\n",
+ "        ('cat', categorical_transformer, categorical_features)\n",
+ "    ])\n",
+ "\n",
+ "# Build the models\n",
+ "pipeline_lasso = Pipeline(steps=[\n",
+ "    ('preprocessor', preprocessor),\n",
+ "    ('model', Lasso())\n",
+ "])\n",
+ "\n",
+ "pipeline_gb = Pipeline(steps=[\n",
+ "    ('preprocessor', preprocessor),\n",
+ "    ('model', GradientBoostingRegressor())\n",
+ "])\n",
+ "\n",
+ "pipeline_knn = Pipeline(steps=[\n",
+ "    ('preprocessor', preprocessor),\n",
+ "    ('model', KNeighborsRegressor())\n",
+ "])"
+ ]
+ },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
@@ -302,7 +353,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 15,
+ "execution_count": 1,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Best hyperparameters for Lasso:\n",
- "{'alpha': 0.01, 'fit_intercept': False}\n"
+ "{'model__alpha': 0.01, 'model__fit_intercept': False}\n"
 ]
 }
 ],
 "source": [
 "from sklearn.linear_model import Lasso\n",
 "\n",
 "param_grid_lasso = {\n",
- "    'alpha': [0.01, 0.1, 1.0, 10.0],\n",
- "    'fit_intercept': [True, False],\n",
+ "    'model__alpha': [0.01, 0.1, 1.0, 10.0],\n",
+ "    'model__fit_intercept': [True, False],\n",
 "}\n",
 "\n",
 "# Create the GridSearchCV object\n",
@@ -347,193 +398,28 @@
 },
 {
 "cell_type": "code",
- "execution_count": 14,
+ "execution_count": 2,
 "metadata": {},
 "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:540: FitFailedWarning: \n", - "1215 fits failed out of a total of 3645.\n", - "The score on these train-test partitions for these parameters will be set to nan.\n", - "If these failures are not expected, you can try to debug them by setting error_score='raise'.\n", - "\n", - "Below are more details about the failures:\n", - "--------------------------------------------------------------------------------\n", - "978 fits failed with the following error:\n", - "Traceback (most recent call last):\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n", - " estimator.fit(X_train, y_train, **fit_params)\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n", - " estimator._validate_params()\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n", - " validate_parameter_constraints(\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n", - " raise InvalidParameterError(\n", - "sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'sqrt', 'log2'} or None. Got 'auto' instead.\n", - "\n", - "--------------------------------------------------------------------------------\n", - "237 fits failed with the following error:\n", - "Traceback (most recent call last):\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n", - " estimator.fit(X_train, y_train, **fit_params)\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n", - " estimator._validate_params()\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n", - " validate_parameter_constraints(\n", - " File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n", - " raise InvalidParameterError(\n", - "sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'log2', 'sqrt'} or None. 
Got 'auto' instead.\n", - "\n", - " warnings.warn(some_fits_failed_message, FitFailedWarning)\n", - "e:\\AIM1.5\\Scripts\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n", - " _data = np.array(data, dtype=dtype, copy=copy,\n", - "e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan -0.18767441 -0.15799837 -0.13080278\n", - " -0.18762913 -0.15792709 -0.13056114 -0.18792038 -0.15737146 -0.130218\n", - " -0.18725961 -0.157967 -0.13047453 -0.18766583 -0.15779565 -0.13094863\n", - " -0.18798705 -0.15693978 -0.13061215 -0.18766317 -0.15746848 -0.13072918\n", - " -0.18864158 -0.15666133 -0.13095037 -0.18817206 -0.15805489 -0.13086126\n", - " -0.18707465 -0.15864932 -0.13104947 -0.18818902 -0.15828572 -0.13063871\n", - " -0.18701628 -0.15853864 -0.13019458 -0.18740927 -0.15836397 -0.13065455\n", - " -0.18768748 -0.15828297 -0.1309458 -0.18845004 -0.15696395 -0.13023062\n", - " -0.18754854 -0.15899615 -0.13061707 -0.18831427 -0.15819939 -0.13096524\n", - " -0.18662963 -0.15815869 -0.13089186 nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " -0.1758914 -0.1442684 -0.12093344 -0.1758927 -0.14423731 -0.12084543\n", - " -0.17573339 -0.14419842 -0.12076166 -0.17512045 -0.14435454 -0.1207299\n", - " -0.17669645 -0.14397965 -0.12087019 -0.17605424 -0.1438664 -0.12091068\n", - " -0.17582192 -0.1443651 -0.12097165 -0.17588422 -0.14421003 -0.12081764\n", - " -0.17522742 -0.14424357 -0.12086484 -0.17530986 -0.14433713 -0.12091757\n", - " -0.17565647 -0.14408902 -0.12075918 -0.17561884 -0.14426355 -0.12094066\n", - " -0.17522371 -0.1439869 -0.12099023 -0.17619772 -0.14396131 -0.12079667\n", - " -0.17710789 -0.1448419 -0.12087822 -0.17608534 -0.14416684 -0.12087865\n", - " -0.1754675 -0.1442258 -0.12068226 -0.17611334 -0.14433552 -0.12093556\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan -0.16938321 -0.13763002 -0.11703902\n", - " -0.16953091 -0.13736586 -0.11695779 -0.16881837 -0.1375676 -0.11694438\n", - " -0.16927898 -0.13748177 -0.11689982 -0.16921265 -0.13757375 -0.11682524\n", - " -0.16915872 -0.13727377 -0.11694336 -0.16939766 -0.13734972 -0.1167447\n", - " -0.16924214 -0.1373768 -0.11674816 -0.16918278 -0.13746085 -0.1169816\n", - " -0.16927003 -0.13740063 -0.1169564 -0.16916501 -0.13752074 -0.11687641\n", - " -0.16928973 -0.13751536 -0.11697948 -0.16934836 -0.13727436 -0.11693615\n", - " -0.16912453 -0.13748699 -0.11693425 -0.1692788 -0.13750784 -0.11694655\n", - " -0.16919354 -0.13747437 -0.11708782 -0.16940009 -0.13757749 -0.11700586\n", - " -0.1692801 -0.13725384 -0.11684394 nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " -0.11606052 -0.1140225 -0.11403709 -0.11627212 -0.1139982 -0.11402075\n", - " -0.11613561 -0.11407941 -0.11420487 -0.11666225 -0.11462523 -0.11431901\n", - " -0.11604817 -0.11456211 -0.11392092 -0.11609343 -0.11394228 -0.11414071\n", - " -0.11611685 -0.11420178 -0.11405459 -0.11594404 -0.11408614 -0.11391662\n", - " -0.11590886 -0.11396465 -0.11389125 -0.11616694 -0.11441846 -0.11417015\n", - " -0.11617368 
-0.11429765 -0.1139636 -0.11616763 -0.11433984 -0.11412121\n", - " -0.11625618 -0.11402999 -0.11419791 -0.11613603 -0.114206 -0.11423922\n", - " -0.1160801 -0.11431896 -0.11416734 -0.11608923 -0.11455498 -0.11417448\n", - " -0.11605165 -0.11427773 -0.11392205 -0.11606243 -0.11408421 -0.11395292\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan -0.11281447 -0.11245904 -0.11308822\n", - " -0.11256366 -0.11230094 -0.1130767 -0.11282651 -0.1121034 -0.11283479\n", - " -0.11260704 -0.1125136 -0.11288977 -0.11278304 -0.11242278 -0.11268564\n", - " -0.11263359 -0.11236227 -0.11329411 -0.11231603 -0.1124533 -0.11278826\n", - " -0.11291545 -0.11241223 -0.11250702 -0.11246481 -0.11228665 -0.11348916\n", - " -0.11250694 -0.11250274 -0.11298019 -0.11277323 -0.11248601 -0.11301753\n", - " -0.11259486 -0.1124685 -0.11285441 -0.11274424 -0.11232891 -0.11316456\n", - " -0.11274575 -0.11256149 -0.11252293 -0.11293524 -0.11261757 -0.11305628\n", - " -0.11253063 -0.11237109 -0.11278518 -0.1124074 -0.11276905 -0.11296684\n", - " -0.11258689 -0.11228467 -0.11331342 nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " -0.11292265 -0.11395193 -0.11564599 -0.11244356 -0.11338947 -0.1148266\n", - " -0.11295702 -0.11353862 -0.11510521 -0.11244347 -0.11387967 -0.11512396\n", - " -0.11269802 -0.11364442 -0.1151339 -0.11238356 -0.11364301 -0.11496543\n", - " -0.11229193 -0.11340926 -0.11550744 -0.11215818 -0.11367944 -0.11552889\n", - " -0.11240305 -0.11352309 -0.115412 -0.1128402 -0.11338749 -0.1153551\n", - " -0.11250042 -0.11347275 -0.11548445 -0.11271132 -0.11377527 -0.11558066\n", - " -0.11318598 -0.11325792 -0.11499103 -0.11253099 -0.1129829 -0.11530949\n", - " -0.11239074 -0.11329625 -0.11544761 -0.11262484 -0.11323392 -0.1151936\n", - " -0.11253889 -0.11382403 -0.11511129 -0.11250854 -0.11339898 -0.11536332\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan -0.11542253 -0.11498664 -0.11428517\n", - " -0.11503783 -0.11473447 -0.11458687 -0.11483866 -0.1154254 -0.11479037\n", - " -0.11533015 -0.11515195 -0.11460571 -0.11563491 -0.11433835 -0.11437413\n", - " -0.11510849 -0.11472156 -0.11516494 -0.11545009 -0.115001 -0.11479743\n", - " -0.11461761 -0.11537461 -0.11497109 -0.1155148 -0.11567353 -0.11431184\n", - " -0.11546067 -0.11462564 -0.11450721 -0.11511 -0.11487988 -0.11466523\n", - " -0.11585756 -0.11462611 -0.11433121 -0.11538152 -0.11463425 -0.11527088\n", - " -0.11509145 -0.11493588 -0.11484324 -0.11528905 -0.11426327 -0.11476508\n", - " -0.11499562 -0.11451299 -0.11466765 -0.11525918 -0.11469718 -0.11476983\n", - " -0.11467865 -0.1145067 -0.11479425 nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " nan nan nan nan nan nan\n", - " -0.11352917 -0.1145882 -0.11643688 -0.11418115 -0.11442858 -0.11635549\n", - " -0.11408502 -0.11458383 -0.1163013 -0.1135842 -0.11453566 -0.11575264\n", - " -0.11341863 -0.11481638 -0.11635685 -0.1132144 -0.11438018 -0.11666005\n", - " -0.11311482 -0.11500883 -0.11594984 -0.11409228 -0.11464061 -0.1158012\n", - " -0.11389399 -0.11454081 -0.1157428 -0.11333869 -0.11438896 -0.11676006\n", - " -0.11382523 -0.11443669 -0.11606569 -0.11424726 -0.11464652 -0.11608159\n", - " -0.11396605 -0.11473188 -0.1167532 -0.1136805 -0.11455875 
- " warnings.warn(\n"
- ]
- },
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Best hyperparameters for Gradient Boosting:\n",
- "{'learning_rate': 0.1, 'max_depth': 5, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 10, 'n_estimators': 100}\n"
+ "{'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__min_samples_leaf': 2, 'model__min_samples_split': 5, 'model__n_estimators': 100}\n"
 ]
 }
 ],
 "source": [
- "\n",
 "from sklearn.ensemble import GradientBoostingRegressor\n",
 "\n",
 "param_grid_gb = {\n",
- "    'n_estimators': [50, 100, 200],\n",
- "    'learning_rate': [0.01, 0.1, 0.2],\n",
- "    'max_depth': [3, 5, 7],\n",
- "    'min_samples_split': [2, 5, 10],\n",
- "    'min_samples_leaf': [1, 2, 4],\n",
- "    'max_features': ['auto', 'sqrt', 'log2']\n",
+ "    'model__n_estimators': [50, 100, 200],\n",
+ "    'model__learning_rate': [0.01, 0.1, 0.2],\n",
+ "    'model__max_depth': [3, 5, 7],\n",
+ "    'model__min_samples_split': [2, 5, 10],\n",
+ "    'model__min_samples_leaf': [1, 2, 4],\n",
+ "    'model__max_features': ['sqrt', 'log2', None]\n",
 "}\n",
 "\n",
 "grid_search_gb = GridSearchCV(\n",
@@ -577,10 +463,10 @@
 "from sklearn.model_selection import GridSearchCV\n",
 "\n",
 "param_grid_knn = {\n",
- "    'n_neighbors': [3, 5, 7, 10],\n",
- "    'weights': ['uniform', 'distance'],\n",
- "    'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
- "    'p': [1, 2]\n",
+ "    'model__n_neighbors': [3, 5, 7, 10],\n",
+ "    'model__weights': ['uniform', 'distance'],\n",
+ "    'model__algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
+ "    'model__p': [1, 2]\n",
 "}\n",
 "\n",
 "grid_search_knn = GridSearchCV(\n",
@@ -611,11 +497,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "y_pred = model.predict(x_test)\n",
- "y_pred_forest = model_forest.predict(x_test)\n",
- "y_pred_lasso = model_lasso.predict(x_test)\n",
- "y_pred_gb = model_gb.predict(x_test)\n",
- "y_pred_neighbors = model_knn.predict(x_test)"
+ "y_pred_lasso = grid_search_lasso.predict(x_test)\n",
+ "y_pred_forest = grid_search_gb.predict(x_test)\n",
+ "y_pred_neighbors = grid_search_knn.predict(x_test)"
 ]
 },
 {
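Following the prediction cell above, a minimal sketch of how the three tuned pipelines could be scored on the held-out split; it assumes the fitted grid_search_lasso, grid_search_gb and grid_search_knn objects and the x_test / y_test split created earlier in the notebook, and uses standard sklearn.metrics functions.

from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Assumed to already exist after running the notebook:
# the three fitted GridSearchCV objects and the hold-out split.
tuned_models = {
    'Lasso': grid_search_lasso,
    'Gradient Boosting': grid_search_gb,
    'KNN': grid_search_knn,
}

for name, search in tuned_models.items():
    y_pred = search.predict(x_test)  # predicts with the refitted best_estimator_
    print(
        f"{name}: MAE={mean_absolute_error(y_test, y_pred):.4f}, "
        f"MSE={mean_squared_error(y_test, y_pred):.4f}, "
        f"R2={r2_score(y_test, y_pred):.4f}"
    )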