correct lab4

dex_moth 2024-12-21 12:33:06 +04:00
parent 2fca9fd006
commit 3863672121


@ -23,6 +23,13 @@
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import seaborn as sns\n", "import seaborn as sns\n",
"from matplotlib.ticker import FuncFormatter\n", "from matplotlib.ticker import FuncFormatter\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.linear_model import Lasso\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n", "\n",
"df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n", "df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n",
"print(df.columns)" "print(df.columns)"
@ -293,6 +300,50 @@
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)" "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Создание конвейера\n",
"\n",
"# Обработаем данные\n",
"# Определим категориальные и числовые признаки\n",
"categorical_features = ['Gender', 'City', 'Dietary Habits', 'Degree', 'Have you ever had suicidal thoughts ?', 'Profession', 'Family History of Mental Illness', 'Sleep Duration']\n",
"numerical_features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction', 'Work/Study Hours', 'Financial Stress']\n",
"\n",
"categorical_transformer = Pipeline(steps=[\n",
" ('onehot', OneHotEncoder(handle_unknown='ignore'))\n",
"])\n",
"\n",
"numerical_transformer = Pipeline(steps=[\n",
" ('scaler', StandardScaler())\n",
"])\n",
"\n",
"preprocessor = ColumnTransformer(\n",
" transformers=[\n",
" ('num', numerical_transformer, numerical_features),\n",
" ('cat', categorical_transformer, categorical_features)\n",
" ])\n",
"\n",
"# Построим модели\n",
"pipeline_lasso = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', Lasso())\n",
"])\n",
"\n",
"pipeline_gb = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', GradientBoostingRegressor())\n",
"])\n",
"\n",
"pipeline_knn = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', KNeighborsRegressor())\n",
"])"
]
},
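A minimal usage sketch (not part of the commit): fitting the three pipelines above directly and comparing their test scores, assuming the x_train/x_test/y_train/y_test split produced by the train_test_split cell earlier in the notebook.

# Sketch only: fit each pipeline on the training split and report R^2 on the
# held-out split; x_train, x_test, y_train, y_test come from the earlier cell.
for name, pipe in [("Lasso", pipeline_lasso),
                   ("Gradient Boosting", pipeline_gb),
                   ("KNN", pipeline_knn)]:
    pipe.fit(x_train, y_train)
    print(f"{name}: R^2 on test = {pipe.score(x_test, y_test):.3f}")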
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -302,7 +353,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -310,7 +361,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Лучшие гиперпараметры для Lasso:\n", "Лучшие гиперпараметры для Lasso:\n",
"{'alpha': 0.01, 'fit_intercept': False}\n" "{'model__alpha': 0.01, 'model__fit_intercept': False}\n"
] ]
} }
], ],
@ -318,8 +369,8 @@
"from sklearn.linear_model import Lasso\n", "from sklearn.linear_model import Lasso\n",
"\n", "\n",
"param_grid_lasso = {\n", "param_grid_lasso = {\n",
" 'alpha': [0.01, 0.1, 1.0, 10.0],\n", " 'model__alpha': [0.01, 0.1, 1.0, 10.0],\n",
" 'fit_intercept': [True, False],\n", " 'model__fit_intercept': [True, False],\n",
"}\n", "}\n",
"\n", "\n",
"# Создание объекта GridSearchCV\n", "# Создание объекта GridSearchCV\n",
@ -347,193 +398,28 @@
}, },
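The renamed keys in this hunk follow the Pipeline addressing rule: when the estimator handed to GridSearchCV is a Pipeline, every tuned parameter must be prefixed with the name of the step it belongs to, joined by a double underscore. A short sketch of the rule on a throwaway pipeline:

# Parameters of a Pipeline step are addressed as <step name>__<parameter>, which is
# why 'alpha' becomes 'model__alpha' once Lasso sits behind a step named 'model'.
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Lasso

demo = Pipeline(steps=[('model', Lasso())])
print('model__alpha' in demo.get_params())   # True
demo.set_params(model__alpha=0.01)           # the same addressing GridSearchCV uses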
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:540: FitFailedWarning: \n",
"1215 fits failed out of a total of 3645.\n",
"The score on these train-test partitions for these parameters will be set to nan.\n",
"If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
"\n",
"Below are more details about the failures:\n",
"--------------------------------------------------------------------------------\n",
"978 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n",
" estimator._validate_params()\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n",
" validate_parameter_constraints(\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n",
" raise InvalidParameterError(\n",
"sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'sqrt', 'log2'} or None. Got 'auto' instead.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"237 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n",
" estimator._validate_params()\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n",
" validate_parameter_constraints(\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n",
" raise InvalidParameterError(\n",
"sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'log2', 'sqrt'} or None. Got 'auto' instead.\n",
"\n",
" warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n",
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.18767441 -0.15799837 -0.13080278\n",
" -0.18762913 -0.15792709 -0.13056114 -0.18792038 -0.15737146 -0.130218\n",
" -0.18725961 -0.157967 -0.13047453 -0.18766583 -0.15779565 -0.13094863\n",
" -0.18798705 -0.15693978 -0.13061215 -0.18766317 -0.15746848 -0.13072918\n",
" -0.18864158 -0.15666133 -0.13095037 -0.18817206 -0.15805489 -0.13086126\n",
" -0.18707465 -0.15864932 -0.13104947 -0.18818902 -0.15828572 -0.13063871\n",
" -0.18701628 -0.15853864 -0.13019458 -0.18740927 -0.15836397 -0.13065455\n",
" -0.18768748 -0.15828297 -0.1309458 -0.18845004 -0.15696395 -0.13023062\n",
" -0.18754854 -0.15899615 -0.13061707 -0.18831427 -0.15819939 -0.13096524\n",
" -0.18662963 -0.15815869 -0.13089186 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.1758914 -0.1442684 -0.12093344 -0.1758927 -0.14423731 -0.12084543\n",
" -0.17573339 -0.14419842 -0.12076166 -0.17512045 -0.14435454 -0.1207299\n",
" -0.17669645 -0.14397965 -0.12087019 -0.17605424 -0.1438664 -0.12091068\n",
" -0.17582192 -0.1443651 -0.12097165 -0.17588422 -0.14421003 -0.12081764\n",
" -0.17522742 -0.14424357 -0.12086484 -0.17530986 -0.14433713 -0.12091757\n",
" -0.17565647 -0.14408902 -0.12075918 -0.17561884 -0.14426355 -0.12094066\n",
" -0.17522371 -0.1439869 -0.12099023 -0.17619772 -0.14396131 -0.12079667\n",
" -0.17710789 -0.1448419 -0.12087822 -0.17608534 -0.14416684 -0.12087865\n",
" -0.1754675 -0.1442258 -0.12068226 -0.17611334 -0.14433552 -0.12093556\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.16938321 -0.13763002 -0.11703902\n",
" -0.16953091 -0.13736586 -0.11695779 -0.16881837 -0.1375676 -0.11694438\n",
" -0.16927898 -0.13748177 -0.11689982 -0.16921265 -0.13757375 -0.11682524\n",
" -0.16915872 -0.13727377 -0.11694336 -0.16939766 -0.13734972 -0.1167447\n",
" -0.16924214 -0.1373768 -0.11674816 -0.16918278 -0.13746085 -0.1169816\n",
" -0.16927003 -0.13740063 -0.1169564 -0.16916501 -0.13752074 -0.11687641\n",
" -0.16928973 -0.13751536 -0.11697948 -0.16934836 -0.13727436 -0.11693615\n",
" -0.16912453 -0.13748699 -0.11693425 -0.1692788 -0.13750784 -0.11694655\n",
" -0.16919354 -0.13747437 -0.11708782 -0.16940009 -0.13757749 -0.11700586\n",
" -0.1692801 -0.13725384 -0.11684394 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.11606052 -0.1140225 -0.11403709 -0.11627212 -0.1139982 -0.11402075\n",
" -0.11613561 -0.11407941 -0.11420487 -0.11666225 -0.11462523 -0.11431901\n",
" -0.11604817 -0.11456211 -0.11392092 -0.11609343 -0.11394228 -0.11414071\n",
" -0.11611685 -0.11420178 -0.11405459 -0.11594404 -0.11408614 -0.11391662\n",
" -0.11590886 -0.11396465 -0.11389125 -0.11616694 -0.11441846 -0.11417015\n",
" -0.11617368 -0.11429765 -0.1139636 -0.11616763 -0.11433984 -0.11412121\n",
" -0.11625618 -0.11402999 -0.11419791 -0.11613603 -0.114206 -0.11423922\n",
" -0.1160801 -0.11431896 -0.11416734 -0.11608923 -0.11455498 -0.11417448\n",
" -0.11605165 -0.11427773 -0.11392205 -0.11606243 -0.11408421 -0.11395292\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.11281447 -0.11245904 -0.11308822\n",
" -0.11256366 -0.11230094 -0.1130767 -0.11282651 -0.1121034 -0.11283479\n",
" -0.11260704 -0.1125136 -0.11288977 -0.11278304 -0.11242278 -0.11268564\n",
" -0.11263359 -0.11236227 -0.11329411 -0.11231603 -0.1124533 -0.11278826\n",
" -0.11291545 -0.11241223 -0.11250702 -0.11246481 -0.11228665 -0.11348916\n",
" -0.11250694 -0.11250274 -0.11298019 -0.11277323 -0.11248601 -0.11301753\n",
" -0.11259486 -0.1124685 -0.11285441 -0.11274424 -0.11232891 -0.11316456\n",
" -0.11274575 -0.11256149 -0.11252293 -0.11293524 -0.11261757 -0.11305628\n",
" -0.11253063 -0.11237109 -0.11278518 -0.1124074 -0.11276905 -0.11296684\n",
" -0.11258689 -0.11228467 -0.11331342 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.11292265 -0.11395193 -0.11564599 -0.11244356 -0.11338947 -0.1148266\n",
" -0.11295702 -0.11353862 -0.11510521 -0.11244347 -0.11387967 -0.11512396\n",
" -0.11269802 -0.11364442 -0.1151339 -0.11238356 -0.11364301 -0.11496543\n",
" -0.11229193 -0.11340926 -0.11550744 -0.11215818 -0.11367944 -0.11552889\n",
" -0.11240305 -0.11352309 -0.115412 -0.1128402 -0.11338749 -0.1153551\n",
" -0.11250042 -0.11347275 -0.11548445 -0.11271132 -0.11377527 -0.11558066\n",
" -0.11318598 -0.11325792 -0.11499103 -0.11253099 -0.1129829 -0.11530949\n",
" -0.11239074 -0.11329625 -0.11544761 -0.11262484 -0.11323392 -0.1151936\n",
" -0.11253889 -0.11382403 -0.11511129 -0.11250854 -0.11339898 -0.11536332\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.11542253 -0.11498664 -0.11428517\n",
" -0.11503783 -0.11473447 -0.11458687 -0.11483866 -0.1154254 -0.11479037\n",
" -0.11533015 -0.11515195 -0.11460571 -0.11563491 -0.11433835 -0.11437413\n",
" -0.11510849 -0.11472156 -0.11516494 -0.11545009 -0.115001 -0.11479743\n",
" -0.11461761 -0.11537461 -0.11497109 -0.1155148 -0.11567353 -0.11431184\n",
" -0.11546067 -0.11462564 -0.11450721 -0.11511 -0.11487988 -0.11466523\n",
" -0.11585756 -0.11462611 -0.11433121 -0.11538152 -0.11463425 -0.11527088\n",
" -0.11509145 -0.11493588 -0.11484324 -0.11528905 -0.11426327 -0.11476508\n",
" -0.11499562 -0.11451299 -0.11466765 -0.11525918 -0.11469718 -0.11476983\n",
" -0.11467865 -0.1145067 -0.11479425 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.11352917 -0.1145882 -0.11643688 -0.11418115 -0.11442858 -0.11635549\n",
" -0.11408502 -0.11458383 -0.1163013 -0.1135842 -0.11453566 -0.11575264\n",
" -0.11341863 -0.11481638 -0.11635685 -0.1132144 -0.11438018 -0.11666005\n",
" -0.11311482 -0.11500883 -0.11594984 -0.11409228 -0.11464061 -0.1158012\n",
" -0.11389399 -0.11454081 -0.1157428 -0.11333869 -0.11438896 -0.11676006\n",
" -0.11382523 -0.11443669 -0.11606569 -0.11424726 -0.11464652 -0.11608159\n",
" -0.11396605 -0.11473188 -0.1167532 -0.1136805 -0.11455875 -0.11615814\n",
" -0.11372286 -0.11442829 -0.11590895 -0.1136509 -0.11368863 -0.11660073\n",
" -0.1136605 -0.1141187 -0.11613806 -0.11326355 -0.11427399 -0.11676148\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.11573534 -0.11897501 -0.1226239\n",
" -0.1162633 -0.11939573 -0.12255715 -0.11636411 -0.11878021 -0.12306277\n",
" -0.11535113 -0.11813967 -0.1230085 -0.11594119 -0.11812955 -0.12217928\n",
" -0.11523023 -0.11843291 -0.12228252 -0.1159457 -0.11840108 -0.12181337\n",
" -0.11600134 -0.11790484 -0.12203724 -0.11579998 -0.11787918 -0.12317219\n",
" -0.11578704 -0.11837798 -0.12379234 -0.1155279 -0.11865384 -0.12319867\n",
" -0.11597008 -0.11886814 -0.12291788 -0.1162282 -0.11918752 -0.12363613\n",
" -0.11571473 -0.11805225 -0.12250506 -0.11640247 -0.11823175 -0.1226976\n",
" -0.11571549 -0.11813327 -0.12229009 -0.11621545 -0.11793769 -0.1229533\n",
" -0.11528287 -0.1183919 -0.12121653]\n",
" warnings.warn(\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Лучшие гиперпараметры для Gradient Boosting:\n", "Лучшие гиперпараметры для Gradient Boosting:\n",
"{'learning_rate': 0.1, 'max_depth': 5, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 10, 'n_estimators': 100}\n" "{'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__min_samples_leaf': 2, 'model__min_samples_split': 5, 'model__n_estimators': 100}\n"
] ]
} }
], ],
"source": [ "source": [
"\n",
"from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.ensemble import GradientBoostingRegressor\n",
"\n", "\n",
"param_grid_gb = {\n", "param_grid_gb = {\n",
" 'n_estimators': [50, 100, 200],\n", " 'model__n_estimators': [50, 100, 200],\n",
" 'learning_rate': [0.01, 0.1, 0.2],\n", " 'model__learning_rate': [0.01, 0.1, 0.2],\n",
" 'max_depth': [3, 5, 7],\n", " 'model__max_depth': [3, 5, 7],\n",
" 'min_samples_split': [2, 5, 10],\n", " 'model__min_samples_split': [2, 5, 10],\n",
" 'min_samples_leaf': [1, 2, 4],\n", " 'model__min_samples_leaf': [1, 2, 4],\n",
" 'max_features': ['auto', 'sqrt', 'log2']\n", " 'model__max_features': ['auto', 'sqrt', 'log2']\n",
"}\n", "}\n",
"\n", "\n",
"grid_search_gb = GridSearchCV(\n", "grid_search_gb = GridSearchCV(\n",
@ -577,10 +463,10 @@
"from sklearn.model_selection import GridSearchCV\n", "from sklearn.model_selection import GridSearchCV\n",
"\n", "\n",
"param_grid_knn = {\n", "param_grid_knn = {\n",
" 'n_neighbors': [3, 5, 7, 10],\n", " 'model__n_neighbors': [3, 5, 7, 10],\n",
" 'weights': ['uniform', 'distance'],\n", " 'model__weights': ['uniform', 'distance'],\n",
" 'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n", " 'model__algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
" 'p': [1, 2]\n", " 'model__p': [1, 2]\n",
"}\n", "}\n",
"\n", "\n",
"grid_search_knn = GridSearchCV(\n", "grid_search_knn = GridSearchCV(\n",
@ -611,11 +497,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_pred = model.predict(x_test)\n", "y_pred_lasso = grid_search_lasso.predict(x_test)\n",
"y_pred_forest = model_forest.predict(x_test)\n", "y_pred_forest = grid_search_gb.predict(x_test)\n",
"y_pred_lasso = model_lasso.predict(x_test)\n", "y_pred_neighbors = grid_search_knn.predict(x_test)"
"y_pred_gb = model_gb.predict(x_test)\n",
"y_pred_neighbors = model_knn.predict(x_test)"
] ]
}, },
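A minimal sketch (not part of the commit) of scoring the three prediction arrays above with standard regression metrics, assuming y_test from the earlier split.

# Sketch: compare the tuned models on the held-out split; y_test and the y_pred_*
# arrays come from the cells above.
from sklearn.metrics import mean_squared_error, r2_score

for name, pred in [("Lasso", y_pred_lasso),
                   ("Gradient Boosting", y_pred_forest),
                   ("KNN", y_pred_neighbors)]:
    print(f"{name}: MSE = {mean_squared_error(y_test, pred):.4f}, "
          f"R2 = {r2_score(y_test, pred):.4f}")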
{ {