correct lab4
This commit is contained in:
parent
2fca9fd006
commit
3863672121
256
lab_4/Lab4.ipynb
256
lab_4/Lab4.ipynb
@ -23,6 +23,13 @@
|
|||||||
"import matplotlib.pyplot as plt\n",
|
"import matplotlib.pyplot as plt\n",
|
||||||
"import seaborn as sns\n",
|
"import seaborn as sns\n",
|
||||||
"from matplotlib.ticker import FuncFormatter\n",
|
"from matplotlib.ticker import FuncFormatter\n",
|
||||||
|
"from sklearn.pipeline import Pipeline\n",
|
||||||
|
"from sklearn.compose import ColumnTransformer\n",
|
||||||
|
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
||||||
|
"from sklearn.model_selection import GridSearchCV\n",
|
||||||
|
"from sklearn.linear_model import Lasso\n",
|
||||||
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
||||||
|
"from sklearn.neighbors import KNeighborsRegressor\n",
|
||||||
"\n",
|
"\n",
|
||||||
"df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n",
|
"df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n",
|
||||||
"print(df.columns)"
|
"print(df.columns)"
|
||||||
@ -293,6 +300,50 @@
|
|||||||
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)"
|
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Создание конвейера\n",
|
||||||
|
"\n",
|
||||||
|
"# Обработаем данные\n",
|
||||||
|
"# Определим категориальные и числовые признаки\n",
|
||||||
|
"categorical_features = ['Gender', 'City', 'Dietary Habits', 'Degree', 'Have you ever had suicidal thoughts ?', 'Profession', 'Family History of Mental Illness', 'Sleep Duration']\n",
|
||||||
|
"numerical_features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction', 'Work/Study Hours', 'Financial Stress']\n",
|
||||||
|
"\n",
|
||||||
|
"categorical_transformer = Pipeline(steps=[\n",
|
||||||
|
" ('onehot', OneHotEncoder(handle_unknown='ignore'))\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"numerical_transformer = Pipeline(steps=[\n",
|
||||||
|
" ('scaler', StandardScaler())\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"preprocessor = ColumnTransformer(\n",
|
||||||
|
" transformers=[\n",
|
||||||
|
" ('num', numerical_transformer, numerical_features),\n",
|
||||||
|
" ('cat', categorical_transformer, categorical_features)\n",
|
||||||
|
" ])\n",
|
||||||
|
"\n",
|
||||||
|
"# Построим модели\n",
|
||||||
|
"pipeline_lasso = Pipeline(steps=[\n",
|
||||||
|
" ('preprocessor', preprocessor),\n",
|
||||||
|
" ('model', Lasso())\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"pipeline_gb = Pipeline(steps=[\n",
|
||||||
|
" ('preprocessor', preprocessor),\n",
|
||||||
|
" ('model', GradientBoostingRegressor())\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"pipeline_knn = Pipeline(steps=[\n",
|
||||||
|
" ('preprocessor', preprocessor),\n",
|
||||||
|
" ('model', KNeighborsRegressor())\n",
|
||||||
|
"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@ -302,7 +353,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
@ -310,7 +361,7 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Лучшие гиперпараметры для Lasso:\n",
|
"Лучшие гиперпараметры для Lasso:\n",
|
||||||
"{'alpha': 0.01, 'fit_intercept': False}\n"
|
"{'model__alpha': 0.01, 'model__fit_intercept': False}\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -318,8 +369,8 @@
|
|||||||
"from sklearn.linear_model import Lasso\n",
|
"from sklearn.linear_model import Lasso\n",
|
||||||
"\n",
|
"\n",
|
||||||
"param_grid_lasso = {\n",
|
"param_grid_lasso = {\n",
|
||||||
" 'alpha': [0.01, 0.1, 1.0, 10.0],\n",
|
" 'model__alpha': [0.01, 0.1, 1.0, 10.0],\n",
|
||||||
" 'fit_intercept': [True, False],\n",
|
" 'model__fit_intercept': [True, False],\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Создание объекта GridSearchCV\n",
|
"# Создание объекта GridSearchCV\n",
|
||||||
@ -347,193 +398,28 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 2,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:540: FitFailedWarning: \n",
|
|
||||||
"1215 fits failed out of a total of 3645.\n",
|
|
||||||
"The score on these train-test partitions for these parameters will be set to nan.\n",
|
|
||||||
"If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
|
|
||||||
"\n",
|
|
||||||
"Below are more details about the failures:\n",
|
|
||||||
"--------------------------------------------------------------------------------\n",
|
|
||||||
"978 fits failed with the following error:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
|
|
||||||
" estimator.fit(X_train, y_train, **fit_params)\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n",
|
|
||||||
" estimator._validate_params()\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n",
|
|
||||||
" validate_parameter_constraints(\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n",
|
|
||||||
" raise InvalidParameterError(\n",
|
|
||||||
"sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'sqrt', 'log2'} or None. Got 'auto' instead.\n",
|
|
||||||
"\n",
|
|
||||||
"--------------------------------------------------------------------------------\n",
|
|
||||||
"237 fits failed with the following error:\n",
|
|
||||||
"Traceback (most recent call last):\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
|
|
||||||
" estimator.fit(X_train, y_train, **fit_params)\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n",
|
|
||||||
" estimator._validate_params()\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n",
|
|
||||||
" validate_parameter_constraints(\n",
|
|
||||||
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n",
|
|
||||||
" raise InvalidParameterError(\n",
|
|
||||||
"sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'log2', 'sqrt'} or None. Got 'auto' instead.\n",
|
|
||||||
"\n",
|
|
||||||
" warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
|
|
||||||
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
|
|
||||||
" _data = np.array(data, dtype=dtype, copy=copy,\n",
|
|
||||||
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan -0.18767441 -0.15799837 -0.13080278\n",
|
|
||||||
" -0.18762913 -0.15792709 -0.13056114 -0.18792038 -0.15737146 -0.130218\n",
|
|
||||||
" -0.18725961 -0.157967 -0.13047453 -0.18766583 -0.15779565 -0.13094863\n",
|
|
||||||
" -0.18798705 -0.15693978 -0.13061215 -0.18766317 -0.15746848 -0.13072918\n",
|
|
||||||
" -0.18864158 -0.15666133 -0.13095037 -0.18817206 -0.15805489 -0.13086126\n",
|
|
||||||
" -0.18707465 -0.15864932 -0.13104947 -0.18818902 -0.15828572 -0.13063871\n",
|
|
||||||
" -0.18701628 -0.15853864 -0.13019458 -0.18740927 -0.15836397 -0.13065455\n",
|
|
||||||
" -0.18768748 -0.15828297 -0.1309458 -0.18845004 -0.15696395 -0.13023062\n",
|
|
||||||
" -0.18754854 -0.15899615 -0.13061707 -0.18831427 -0.15819939 -0.13096524\n",
|
|
||||||
" -0.18662963 -0.15815869 -0.13089186 nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" -0.1758914 -0.1442684 -0.12093344 -0.1758927 -0.14423731 -0.12084543\n",
|
|
||||||
" -0.17573339 -0.14419842 -0.12076166 -0.17512045 -0.14435454 -0.1207299\n",
|
|
||||||
" -0.17669645 -0.14397965 -0.12087019 -0.17605424 -0.1438664 -0.12091068\n",
|
|
||||||
" -0.17582192 -0.1443651 -0.12097165 -0.17588422 -0.14421003 -0.12081764\n",
|
|
||||||
" -0.17522742 -0.14424357 -0.12086484 -0.17530986 -0.14433713 -0.12091757\n",
|
|
||||||
" -0.17565647 -0.14408902 -0.12075918 -0.17561884 -0.14426355 -0.12094066\n",
|
|
||||||
" -0.17522371 -0.1439869 -0.12099023 -0.17619772 -0.14396131 -0.12079667\n",
|
|
||||||
" -0.17710789 -0.1448419 -0.12087822 -0.17608534 -0.14416684 -0.12087865\n",
|
|
||||||
" -0.1754675 -0.1442258 -0.12068226 -0.17611334 -0.14433552 -0.12093556\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan -0.16938321 -0.13763002 -0.11703902\n",
|
|
||||||
" -0.16953091 -0.13736586 -0.11695779 -0.16881837 -0.1375676 -0.11694438\n",
|
|
||||||
" -0.16927898 -0.13748177 -0.11689982 -0.16921265 -0.13757375 -0.11682524\n",
|
|
||||||
" -0.16915872 -0.13727377 -0.11694336 -0.16939766 -0.13734972 -0.1167447\n",
|
|
||||||
" -0.16924214 -0.1373768 -0.11674816 -0.16918278 -0.13746085 -0.1169816\n",
|
|
||||||
" -0.16927003 -0.13740063 -0.1169564 -0.16916501 -0.13752074 -0.11687641\n",
|
|
||||||
" -0.16928973 -0.13751536 -0.11697948 -0.16934836 -0.13727436 -0.11693615\n",
|
|
||||||
" -0.16912453 -0.13748699 -0.11693425 -0.1692788 -0.13750784 -0.11694655\n",
|
|
||||||
" -0.16919354 -0.13747437 -0.11708782 -0.16940009 -0.13757749 -0.11700586\n",
|
|
||||||
" -0.1692801 -0.13725384 -0.11684394 nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" -0.11606052 -0.1140225 -0.11403709 -0.11627212 -0.1139982 -0.11402075\n",
|
|
||||||
" -0.11613561 -0.11407941 -0.11420487 -0.11666225 -0.11462523 -0.11431901\n",
|
|
||||||
" -0.11604817 -0.11456211 -0.11392092 -0.11609343 -0.11394228 -0.11414071\n",
|
|
||||||
" -0.11611685 -0.11420178 -0.11405459 -0.11594404 -0.11408614 -0.11391662\n",
|
|
||||||
" -0.11590886 -0.11396465 -0.11389125 -0.11616694 -0.11441846 -0.11417015\n",
|
|
||||||
" -0.11617368 -0.11429765 -0.1139636 -0.11616763 -0.11433984 -0.11412121\n",
|
|
||||||
" -0.11625618 -0.11402999 -0.11419791 -0.11613603 -0.114206 -0.11423922\n",
|
|
||||||
" -0.1160801 -0.11431896 -0.11416734 -0.11608923 -0.11455498 -0.11417448\n",
|
|
||||||
" -0.11605165 -0.11427773 -0.11392205 -0.11606243 -0.11408421 -0.11395292\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan -0.11281447 -0.11245904 -0.11308822\n",
|
|
||||||
" -0.11256366 -0.11230094 -0.1130767 -0.11282651 -0.1121034 -0.11283479\n",
|
|
||||||
" -0.11260704 -0.1125136 -0.11288977 -0.11278304 -0.11242278 -0.11268564\n",
|
|
||||||
" -0.11263359 -0.11236227 -0.11329411 -0.11231603 -0.1124533 -0.11278826\n",
|
|
||||||
" -0.11291545 -0.11241223 -0.11250702 -0.11246481 -0.11228665 -0.11348916\n",
|
|
||||||
" -0.11250694 -0.11250274 -0.11298019 -0.11277323 -0.11248601 -0.11301753\n",
|
|
||||||
" -0.11259486 -0.1124685 -0.11285441 -0.11274424 -0.11232891 -0.11316456\n",
|
|
||||||
" -0.11274575 -0.11256149 -0.11252293 -0.11293524 -0.11261757 -0.11305628\n",
|
|
||||||
" -0.11253063 -0.11237109 -0.11278518 -0.1124074 -0.11276905 -0.11296684\n",
|
|
||||||
" -0.11258689 -0.11228467 -0.11331342 nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" -0.11292265 -0.11395193 -0.11564599 -0.11244356 -0.11338947 -0.1148266\n",
|
|
||||||
" -0.11295702 -0.11353862 -0.11510521 -0.11244347 -0.11387967 -0.11512396\n",
|
|
||||||
" -0.11269802 -0.11364442 -0.1151339 -0.11238356 -0.11364301 -0.11496543\n",
|
|
||||||
" -0.11229193 -0.11340926 -0.11550744 -0.11215818 -0.11367944 -0.11552889\n",
|
|
||||||
" -0.11240305 -0.11352309 -0.115412 -0.1128402 -0.11338749 -0.1153551\n",
|
|
||||||
" -0.11250042 -0.11347275 -0.11548445 -0.11271132 -0.11377527 -0.11558066\n",
|
|
||||||
" -0.11318598 -0.11325792 -0.11499103 -0.11253099 -0.1129829 -0.11530949\n",
|
|
||||||
" -0.11239074 -0.11329625 -0.11544761 -0.11262484 -0.11323392 -0.1151936\n",
|
|
||||||
" -0.11253889 -0.11382403 -0.11511129 -0.11250854 -0.11339898 -0.11536332\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan -0.11542253 -0.11498664 -0.11428517\n",
|
|
||||||
" -0.11503783 -0.11473447 -0.11458687 -0.11483866 -0.1154254 -0.11479037\n",
|
|
||||||
" -0.11533015 -0.11515195 -0.11460571 -0.11563491 -0.11433835 -0.11437413\n",
|
|
||||||
" -0.11510849 -0.11472156 -0.11516494 -0.11545009 -0.115001 -0.11479743\n",
|
|
||||||
" -0.11461761 -0.11537461 -0.11497109 -0.1155148 -0.11567353 -0.11431184\n",
|
|
||||||
" -0.11546067 -0.11462564 -0.11450721 -0.11511 -0.11487988 -0.11466523\n",
|
|
||||||
" -0.11585756 -0.11462611 -0.11433121 -0.11538152 -0.11463425 -0.11527088\n",
|
|
||||||
" -0.11509145 -0.11493588 -0.11484324 -0.11528905 -0.11426327 -0.11476508\n",
|
|
||||||
" -0.11499562 -0.11451299 -0.11466765 -0.11525918 -0.11469718 -0.11476983\n",
|
|
||||||
" -0.11467865 -0.1145067 -0.11479425 nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" -0.11352917 -0.1145882 -0.11643688 -0.11418115 -0.11442858 -0.11635549\n",
|
|
||||||
" -0.11408502 -0.11458383 -0.1163013 -0.1135842 -0.11453566 -0.11575264\n",
|
|
||||||
" -0.11341863 -0.11481638 -0.11635685 -0.1132144 -0.11438018 -0.11666005\n",
|
|
||||||
" -0.11311482 -0.11500883 -0.11594984 -0.11409228 -0.11464061 -0.1158012\n",
|
|
||||||
" -0.11389399 -0.11454081 -0.1157428 -0.11333869 -0.11438896 -0.11676006\n",
|
|
||||||
" -0.11382523 -0.11443669 -0.11606569 -0.11424726 -0.11464652 -0.11608159\n",
|
|
||||||
" -0.11396605 -0.11473188 -0.1167532 -0.1136805 -0.11455875 -0.11615814\n",
|
|
||||||
" -0.11372286 -0.11442829 -0.11590895 -0.1136509 -0.11368863 -0.11660073\n",
|
|
||||||
" -0.1136605 -0.1141187 -0.11613806 -0.11326355 -0.11427399 -0.11676148\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan nan nan nan\n",
|
|
||||||
" nan nan nan -0.11573534 -0.11897501 -0.1226239\n",
|
|
||||||
" -0.1162633 -0.11939573 -0.12255715 -0.11636411 -0.11878021 -0.12306277\n",
|
|
||||||
" -0.11535113 -0.11813967 -0.1230085 -0.11594119 -0.11812955 -0.12217928\n",
|
|
||||||
" -0.11523023 -0.11843291 -0.12228252 -0.1159457 -0.11840108 -0.12181337\n",
|
|
||||||
" -0.11600134 -0.11790484 -0.12203724 -0.11579998 -0.11787918 -0.12317219\n",
|
|
||||||
" -0.11578704 -0.11837798 -0.12379234 -0.1155279 -0.11865384 -0.12319867\n",
|
|
||||||
" -0.11597008 -0.11886814 -0.12291788 -0.1162282 -0.11918752 -0.12363613\n",
|
|
||||||
" -0.11571473 -0.11805225 -0.12250506 -0.11640247 -0.11823175 -0.1226976\n",
|
|
||||||
" -0.11571549 -0.11813327 -0.12229009 -0.11621545 -0.11793769 -0.1229533\n",
|
|
||||||
" -0.11528287 -0.1183919 -0.12121653]\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Лучшие гиперпараметры для Gradient Boosting:\n",
|
"Лучшие гиперпараметры для Gradient Boosting:\n",
|
||||||
"{'learning_rate': 0.1, 'max_depth': 5, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 10, 'n_estimators': 100}\n"
|
"{'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__min_samples_leaf': 2, 'model__min_samples_split': 5, 'model__n_estimators': 100}\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"\n",
|
|
||||||
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
||||||
"\n",
|
"\n",
|
||||||
"param_grid_gb = {\n",
|
"param_grid_gb = {\n",
|
||||||
" 'n_estimators': [50, 100, 200],\n",
|
" 'model__n_estimators': [50, 100, 200],\n",
|
||||||
" 'learning_rate': [0.01, 0.1, 0.2],\n",
|
" 'model__learning_rate': [0.01, 0.1, 0.2],\n",
|
||||||
" 'max_depth': [3, 5, 7],\n",
|
" 'model__max_depth': [3, 5, 7],\n",
|
||||||
" 'min_samples_split': [2, 5, 10],\n",
|
" 'model__min_samples_split': [2, 5, 10],\n",
|
||||||
" 'min_samples_leaf': [1, 2, 4],\n",
|
" 'model__min_samples_leaf': [1, 2, 4],\n",
|
||||||
" 'max_features': ['auto', 'sqrt', 'log2']\n",
|
" 'model__max_features': ['auto', 'sqrt', 'log2']\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"grid_search_gb = GridSearchCV(\n",
|
"grid_search_gb = GridSearchCV(\n",
|
||||||
@ -577,10 +463,10 @@
|
|||||||
"from sklearn.model_selection import GridSearchCV\n",
|
"from sklearn.model_selection import GridSearchCV\n",
|
||||||
"\n",
|
"\n",
|
||||||
"param_grid_knn = {\n",
|
"param_grid_knn = {\n",
|
||||||
" 'n_neighbors': [3, 5, 7, 10],\n",
|
" 'model__n_neighbors': [3, 5, 7, 10],\n",
|
||||||
" 'weights': ['uniform', 'distance'],\n",
|
" 'model__weights': ['uniform', 'distance'],\n",
|
||||||
" 'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
|
" 'model__algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
|
||||||
" 'p': [1, 2]\n",
|
" 'model__p': [1, 2]\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"\n",
|
"\n",
|
||||||
"grid_search_knn = GridSearchCV(\n",
|
"grid_search_knn = GridSearchCV(\n",
|
||||||
@ -611,11 +497,9 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"y_pred = model.predict(x_test)\n",
|
"y_pred_lasso = grid_search_lasso.predict(x_test)\n",
|
||||||
"y_pred_forest = model_forest.predict(x_test)\n",
|
"y_pred_forest = grid_search_gb.predict(x_test)\n",
|
||||||
"y_pred_lasso = model_lasso.predict(x_test)\n",
|
"y_pred_neighbors = grid_search_knn.predict(x_test)"
|
||||||
"y_pred_gb = model_gb.predict(x_test)\n",
|
|
||||||
"y_pred_neighbors = model_knn.predict(x_test)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user