1500 lines
76 KiB
Plaintext
1500 lines
76 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T15:40:10.929503Z",
|
||
"start_time": "2024-12-24T15:40:08.889350Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"from sklearn.utils import resample\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"from sklearn.preprocessing import LabelEncoder\n",
|
||
"from sklearn import metrics\n",
|
||
"from imblearn.over_sampling import RandomOverSampler\n",
|
||
"from imblearn.under_sampling import RandomUnderSampler\n",
|
||
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
|
||
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier, GradientBoostingClassifier\n",
|
||
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
||
"from sklearn.linear_model import SGDClassifier, SGDRegressor\n",
|
||
"from sklearn.metrics import (\n",
|
||
" precision_score, recall_score, accuracy_score, roc_auc_score, f1_score,\n",
|
||
" matthews_corrcoef, cohen_kappa_score, confusion_matrix\n",
|
||
")\n",
|
||
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
|
||
"import numpy as np\n",
|
||
"import featuretools as ft\n",
|
||
"from sklearn.metrics import accuracy_score, classification_report\n",
|
||
"\n",
|
||
"df = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
|
||
"\n",
|
||
"# Обработка пропущенных значений\n",
|
||
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
|
||
"\n",
|
||
"# Удаление или замена неизвестных категорий\n",
|
||
"# Например, замена 'Other' на 'Unknown' в столбце 'gender'\n",
|
||
"df['gender'] = df['gender'].replace('Other', 'Unknown')\n",
|
||
"\n",
|
||
"# Разделяем классы\n",
|
||
"stroke_0 = df[df['stroke'] == 0]\n",
|
||
"stroke_1 = df[df['stroke'] == 1]\n",
|
||
"\n",
|
||
"# Увеличиваем выборку для 1\n",
|
||
"stroke_1 = resample(stroke_1, replace=True, n_samples=len(stroke_0), random_state=42)\n",
|
||
"\n",
|
||
"# Объединяем классы\n",
|
||
"df = pd.concat([stroke_0, stroke_1])\n",
|
||
"\n",
|
||
"df.info()"
|
||
],
|
||
"id": "db3cd8711d083df",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"Index: 9722 entries, 249 to 134\n",
|
||
"Data columns (total 12 columns):\n",
|
||
" # Column Non-Null Count Dtype \n",
|
||
"--- ------ -------------- ----- \n",
|
||
" 0 id 9722 non-null int64 \n",
|
||
" 1 gender 9722 non-null object \n",
|
||
" 2 age 9722 non-null float64\n",
|
||
" 3 hypertension 9722 non-null int64 \n",
|
||
" 4 heart_disease 9722 non-null int64 \n",
|
||
" 5 ever_married 9722 non-null object \n",
|
||
" 6 work_type 9722 non-null object \n",
|
||
" 7 Residence_type 9722 non-null object \n",
|
||
" 8 avg_glucose_level 9722 non-null float64\n",
|
||
" 9 bmi 9722 non-null float64\n",
|
||
" 10 smoking_status 9722 non-null object \n",
|
||
" 11 stroke 9722 non-null int64 \n",
|
||
"dtypes: float64(3), int64(4), object(5)\n",
|
||
"memory usage: 987.4+ KB\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"Классификация: Предсказать вероятность инсульта на основе данных пациента.\n",
|
||
"\n",
|
||
"Регрессия: Предсказать уровень глюкозы в крови на основе данных пациента."
|
||
],
|
||
"id": "d5f640ac158b69c5"
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T15:40:11.000561Z",
|
||
"start_time": "2024-12-24T15:40:10.979999Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Определение целевых переменных\n",
|
||
"X = df.drop('stroke', axis=1)\n",
|
||
"y_class = df['stroke'] # Задача классификации\n",
|
||
"y_reg = df['avg_glucose_level'] # Задача регрессии\n",
|
||
"\n",
|
||
"# Преобразование категориальных переменных\n",
|
||
"categorical_features = ['gender', 'ever_married', 'smoking_status']\n",
|
||
"numerical_features = ['age', 'avg_glucose_level', 'bmi']\n",
|
||
"\n",
|
||
"# Создание ColumnTransformer с обработкой неизвестных категорий\n",
|
||
"preprocessor = ColumnTransformer(\n",
|
||
" transformers=[\n",
|
||
" ('num', StandardScaler(), numerical_features),\n",
|
||
" ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)]) # Используем handle_unknown='ignore'\n",
|
||
"\n",
|
||
"# Разделение данных на обучающую и тестовую выборки\n",
|
||
"X_train, X_test, y_class_train, y_class_test, y_reg_train, y_reg_test = train_test_split(X, y_class, y_reg, test_size=0.2, random_state=42) \n",
|
||
"\n",
|
||
"def estimate_bias_variance(model, X, y):\n",
|
||
" predictions = np.array([model.fit(X, y).predict(X) for _ in range(1000)])\n",
|
||
" bias = np.mean((y - np.mean(predictions, axis=0)) ** 2)\n",
|
||
" variance = np.mean(np.var(predictions, axis=0))\n",
|
||
" return bias, variance"
|
||
],
|
||
"id": "55c612d9a4b90a55",
|
||
"outputs": [],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "markdown",
|
||
"source": "Классификация",
|
||
"id": "273d7304e338f532"
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T15:40:50.317809Z",
|
||
"start_time": "2024-12-24T15:40:11.011074Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Задача классификации\n",
|
||
"class_pipeline_rf = Pipeline(steps=[\n",
|
||
" ('preprocessor', preprocessor),\n",
|
||
" ('classifier', RandomForestClassifier(random_state=42))])\n",
|
||
"\n",
|
||
"class_pipeline_sgd = Pipeline(steps=[\n",
|
||
" ('preprocessor', preprocessor),\n",
|
||
" ('classifier', SGDClassifier(loss='log_loss', penalty='l2', random_state=42, max_iter=2000))]) \n",
|
||
"\n",
|
||
"# Настройка гиперпараметров\n",
|
||
"param_grid_class_rf = {\n",
|
||
" 'classifier__n_estimators': [100, 200],\n",
|
||
" 'classifier__max_depth': [None, 10, 20]}\n",
|
||
"\n",
|
||
"param_grid_class_sgd = {\n",
|
||
" 'classifier__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'classifier__learning_rate': ['constant', 'adaptive'],\n",
|
||
" 'classifier__eta0': [0.01, 0.1]}\n",
|
||
"\n",
|
||
"# Поиск гиперпараметров\n",
|
||
"grid_search_class_rf = GridSearchCV(class_pipeline_rf, param_grid_class_rf, cv=5, scoring='accuracy')\n",
|
||
"grid_search_class_rf.fit(X_train, y_class_train)\n",
|
||
"\n",
|
||
"grid_search_class_sgd = GridSearchCV(class_pipeline_sgd, param_grid_class_sgd, cv=5, scoring='accuracy')\n",
|
||
"grid_search_class_sgd.fit(X_train, y_class_train)"
|
||
],
|
||
"id": "956c94b392572508",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"GridSearchCV(cv=5,\n",
|
||
" estimator=Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num',\n",
|
||
" StandardScaler(),\n",
|
||
" ['age',\n",
|
||
" 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender',\n",
|
||
" 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('classifier',\n",
|
||
" SGDClassifier(loss='log_loss',\n",
|
||
" max_iter=2000,\n",
|
||
" random_state=42))]),\n",
|
||
" param_grid={'classifier__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'classifier__eta0': [0.01, 0.1],\n",
|
||
" 'classifier__learning_rate': ['constant', 'adaptive']},\n",
|
||
" scoring='accuracy')"
|
||
],
|
||
"text/html": [
|
||
"<style>#sk-container-id-1 {\n",
|
||
" /* Definition of color scheme common for light and dark mode */\n",
|
||
" --sklearn-color-text: black;\n",
|
||
" --sklearn-color-line: gray;\n",
|
||
" /* Definition of color scheme for unfitted estimators */\n",
|
||
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
|
||
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
|
||
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
|
||
" --sklearn-color-unfitted-level-3: chocolate;\n",
|
||
" /* Definition of color scheme for fitted estimators */\n",
|
||
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
|
||
" --sklearn-color-fitted-level-1: #d4ebff;\n",
|
||
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
|
||
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
|
||
"\n",
|
||
" /* Specific color for light theme */\n",
|
||
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
|
||
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
" --sklearn-color-icon: #696969;\n",
|
||
"\n",
|
||
" @media (prefers-color-scheme: dark) {\n",
|
||
" /* Redefinition of color scheme for dark theme */\n",
|
||
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
|
||
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
" --sklearn-color-icon: #878787;\n",
|
||
" }\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 pre {\n",
|
||
" padding: 0;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 input.sk-hidden--visually {\n",
|
||
" border: 0;\n",
|
||
" clip: rect(1px 1px 1px 1px);\n",
|
||
" clip: rect(1px, 1px, 1px, 1px);\n",
|
||
" height: 1px;\n",
|
||
" margin: -1px;\n",
|
||
" overflow: hidden;\n",
|
||
" padding: 0;\n",
|
||
" position: absolute;\n",
|
||
" width: 1px;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
|
||
" border: 1px dashed var(--sklearn-color-line);\n",
|
||
" margin: 0 0.4em 0.5em 0.4em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" padding-bottom: 0.4em;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-container {\n",
|
||
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
|
||
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
|
||
" so we also need the `!important` here to be able to override the\n",
|
||
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
|
||
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
|
||
" display: inline-block !important;\n",
|
||
" position: relative;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
|
||
" display: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"div.sk-parallel-item,\n",
|
||
"div.sk-serial,\n",
|
||
"div.sk-item {\n",
|
||
" /* draw centered vertical line to link estimators */\n",
|
||
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
|
||
" background-size: 2px 100%;\n",
|
||
" background-repeat: no-repeat;\n",
|
||
" background-position: center center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Parallel-specific style estimator block */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item::after {\n",
|
||
" content: \"\";\n",
|
||
" width: 100%;\n",
|
||
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
|
||
" flex-grow: 1;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel {\n",
|
||
" display: flex;\n",
|
||
" align-items: stretch;\n",
|
||
" justify-content: center;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" position: relative;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item {\n",
|
||
" display: flex;\n",
|
||
" flex-direction: column;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
|
||
" align-self: flex-end;\n",
|
||
" width: 50%;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
|
||
" align-self: flex-start;\n",
|
||
" width: 50%;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
|
||
" width: 0;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Serial-specific style estimator block */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-serial {\n",
|
||
" display: flex;\n",
|
||
" flex-direction: column;\n",
|
||
" align-items: center;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" padding-right: 1em;\n",
|
||
" padding-left: 1em;\n",
|
||
"}\n",
|
||
"\n",
|
||
"\n",
|
||
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
|
||
"clickable and can be expanded/collapsed.\n",
|
||
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
|
||
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
|
||
"*/\n",
|
||
"\n",
|
||
"/* Pipeline and ColumnTransformer style (default) */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable {\n",
|
||
" /* Default theme specific background. It is overwritten whether we have a\n",
|
||
" specific estimator or a Pipeline/ColumnTransformer */\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Toggleable label */\n",
|
||
"#sk-container-id-1 label.sk-toggleable__label {\n",
|
||
" cursor: pointer;\n",
|
||
" display: block;\n",
|
||
" width: 100%;\n",
|
||
" margin-bottom: 0;\n",
|
||
" padding: 0.5em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" text-align: center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
|
||
" /* Arrow on the left of the label */\n",
|
||
" content: \"▸\";\n",
|
||
" float: left;\n",
|
||
" margin-right: 0.25em;\n",
|
||
" color: var(--sklearn-color-icon);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Toggleable content - dropdown */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content {\n",
|
||
" max-height: 0;\n",
|
||
" max-width: 0;\n",
|
||
" overflow: hidden;\n",
|
||
" text-align: left;\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
|
||
" margin: 0.2em;\n",
|
||
" border-radius: 0.25em;\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
|
||
" /* Expand drop-down */\n",
|
||
" max-height: 200px;\n",
|
||
" max-width: 100%;\n",
|
||
" overflow: auto;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
|
||
" content: \"▾\";\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Pipeline/ColumnTransformer-specific style */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator-specific style */\n",
|
||
"\n",
|
||
"/* Colorize estimator box */\n",
|
||
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
|
||
"#sk-container-id-1 div.sk-label label {\n",
|
||
" /* The background is the default theme color */\n",
|
||
" color: var(--sklearn-color-text-on-default-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover, darken the color of the background */\n",
|
||
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Label box, darken color on hover, fitted */\n",
|
||
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator label */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label label {\n",
|
||
" font-family: monospace;\n",
|
||
" font-weight: bold;\n",
|
||
" display: inline-block;\n",
|
||
" line-height: 1.2em;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label-container {\n",
|
||
" text-align: center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator-specific */\n",
|
||
"#sk-container-id-1 div.sk-estimator {\n",
|
||
" font-family: monospace;\n",
|
||
" border: 1px dotted var(--sklearn-color-border-box);\n",
|
||
" border-radius: 0.25em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" margin-bottom: 0.5em;\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-estimator.fitted {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* on hover */\n",
|
||
"#sk-container-id-1 div.sk-estimator:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
|
||
"\n",
|
||
"/* Common style for \"i\" and \"?\" */\n",
|
||
"\n",
|
||
".sk-estimator-doc-link,\n",
|
||
"a:link.sk-estimator-doc-link,\n",
|
||
"a:visited.sk-estimator-doc-link {\n",
|
||
" float: right;\n",
|
||
" font-size: smaller;\n",
|
||
" line-height: 1em;\n",
|
||
" font-family: monospace;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" border-radius: 1em;\n",
|
||
" height: 1em;\n",
|
||
" width: 1em;\n",
|
||
" text-decoration: none !important;\n",
|
||
" margin-left: 1ex;\n",
|
||
" /* unfitted */\n",
|
||
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link.fitted,\n",
|
||
"a:link.sk-estimator-doc-link.fitted,\n",
|
||
"a:visited.sk-estimator-doc-link.fitted {\n",
|
||
" /* fitted */\n",
|
||
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover */\n",
|
||
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
|
||
".sk-estimator-doc-link:hover,\n",
|
||
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
|
||
".sk-estimator-doc-link:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
".sk-estimator-doc-link.fitted:hover,\n",
|
||
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
".sk-estimator-doc-link.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Span, style for the box shown on hovering the info icon */\n",
|
||
".sk-estimator-doc-link span {\n",
|
||
" display: none;\n",
|
||
" z-index: 9999;\n",
|
||
" position: relative;\n",
|
||
" font-weight: normal;\n",
|
||
" right: .2ex;\n",
|
||
" padding: .5ex;\n",
|
||
" margin: .5ex;\n",
|
||
" width: min-content;\n",
|
||
" min-width: 20ex;\n",
|
||
" max-width: 50ex;\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" box-shadow: 2pt 2pt 4pt #999;\n",
|
||
" /* unfitted */\n",
|
||
" background: var(--sklearn-color-unfitted-level-0);\n",
|
||
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link.fitted span {\n",
|
||
" /* fitted */\n",
|
||
" background: var(--sklearn-color-fitted-level-0);\n",
|
||
" border: var(--sklearn-color-fitted-level-3);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link:hover span {\n",
|
||
" display: block;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
|
||
"\n",
|
||
"#sk-container-id-1 a.estimator_doc_link {\n",
|
||
" float: right;\n",
|
||
" font-size: 1rem;\n",
|
||
" line-height: 1em;\n",
|
||
" font-family: monospace;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" border-radius: 1rem;\n",
|
||
" height: 1rem;\n",
|
||
" width: 1rem;\n",
|
||
" text-decoration: none;\n",
|
||
" /* unfitted */\n",
|
||
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
|
||
" /* fitted */\n",
|
||
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover */\n",
|
||
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
"}\n",
|
||
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=5,\n",
|
||
" estimator=Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num',\n",
|
||
" StandardScaler(),\n",
|
||
" ['age',\n",
|
||
" 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender',\n",
|
||
" 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('classifier',\n",
|
||
" SGDClassifier(loss='log_loss',\n",
|
||
" max_iter=2000,\n",
|
||
" random_state=42))]),\n",
|
||
" param_grid={'classifier__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'classifier__eta0': [0.01, 0.1],\n",
|
||
" 'classifier__learning_rate': ['constant', 'adaptive']},\n",
|
||
" scoring='accuracy')</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> GridSearchCV<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.GridSearchCV.html\">?<span>Documentation for GridSearchCV</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>GridSearchCV(cv=5,\n",
|
||
" estimator=Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num',\n",
|
||
" StandardScaler(),\n",
|
||
" ['age',\n",
|
||
" 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender',\n",
|
||
" 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('classifier',\n",
|
||
" SGDClassifier(loss='log_loss',\n",
|
||
" max_iter=2000,\n",
|
||
" random_state=42))]),\n",
|
||
" param_grid={'classifier__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'classifier__eta0': [0.01, 0.1],\n",
|
||
" 'classifier__learning_rate': ['constant', 'adaptive']},\n",
|
||
" scoring='accuracy')</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">best_estimator_: Pipeline</label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num', StandardScaler(),\n",
|
||
" ['age', 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender', 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('classifier',\n",
|
||
" SGDClassifier(alpha=0.001, eta0=0.01, learning_rate='constant',\n",
|
||
" loss='log_loss', max_iter=2000,\n",
|
||
" random_state=42))])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> preprocessor: ColumnTransformer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html\">?<span>Documentation for preprocessor: ColumnTransformer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>ColumnTransformer(transformers=[('num', StandardScaler(),\n",
|
||
" ['age', 'avg_glucose_level', 'bmi']),\n",
|
||
" ('cat', OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender', 'ever_married', 'smoking_status'])])</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">num</label><div class=\"sk-toggleable__content fitted\"><pre>['age', 'avg_glucose_level', 'bmi']</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> StandardScaler<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.StandardScaler.html\">?<span>Documentation for StandardScaler</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>StandardScaler()</pre></div> </div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">cat</label><div class=\"sk-toggleable__content fitted\"><pre>['gender', 'ever_married', 'smoking_status']</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" ><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow 
fitted\"> OneHotEncoder<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.OneHotEncoder.html\">?<span>Documentation for OneHotEncoder</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>OneHotEncoder(handle_unknown='ignore')</pre></div> </div></div></div></div></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" ><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> SGDClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.linear_model.SGDClassifier.html\">?<span>Documentation for SGDClassifier</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>SGDClassifier(alpha=0.001, eta0=0.01, learning_rate='constant', loss='log_loss',\n",
|
||
" max_iter=2000, random_state=42)</pre></div> </div></div></div></div></div></div></div></div></div></div></div>"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 3
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T15:40:51.508706Z",
|
||
"start_time": "2024-12-24T15:40:51.415577Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Оценка моделей\n",
|
||
"y_class_pred_rf = grid_search_class_rf.predict(X_test)\n",
|
||
"y_class_pred_sgd = grid_search_class_sgd.predict(X_test)\n",
|
||
"\n",
|
||
"print(\"Classification Report for Random Forest:\")\n",
|
||
"print(classification_report(y_class_test, y_class_pred_rf))\n",
|
||
"print(\"Confusion Matrix for Random Forest:\")\n",
|
||
"print(confusion_matrix(y_class_test, y_class_pred_rf))\n",
|
||
"\n",
|
||
"print(\"Classification Report for SGD:\")\n",
|
||
"print(classification_report(y_class_test, y_class_pred_sgd))\n",
|
||
"print(\"Confusion Matrix for SGD:\")\n",
|
||
"print(confusion_matrix(y_class_test, y_class_pred_sgd))"
|
||
],
|
||
"id": "20d5d13f359c1c48",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Classification Report for Random Forest:\n",
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0 1.00 0.97 0.98 993\n",
|
||
" 1 0.97 1.00 0.98 952\n",
|
||
"\n",
|
||
" accuracy 0.98 1945\n",
|
||
" macro avg 0.98 0.98 0.98 1945\n",
|
||
"weighted avg 0.99 0.98 0.98 1945\n",
|
||
"\n",
|
||
"Confusion Matrix for Random Forest:\n",
|
||
"[[963 30]\n",
|
||
" [ 0 952]]\n",
|
||
"Classification Report for SGD:\n",
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0 0.84 0.73 0.78 993\n",
|
||
" 1 0.75 0.85 0.80 952\n",
|
||
"\n",
|
||
" accuracy 0.79 1945\n",
|
||
" macro avg 0.80 0.79 0.79 1945\n",
|
||
"weighted avg 0.80 0.79 0.79 1945\n",
|
||
"\n",
|
||
"Confusion Matrix for SGD:\n",
|
||
"[[728 265]\n",
|
||
" [141 811]]\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 4
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T15:58:49.129025Z",
|
||
"start_time": "2024-12-24T15:40:51.540498Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Оценка смещения и дисперсии\n",
|
||
"bias_class_rf, variance_class_rf = estimate_bias_variance(grid_search_class_rf.best_estimator_, X_train, y_class_train)\n",
|
||
"bias_class_sgd, variance_class_sgd = estimate_bias_variance(grid_search_class_sgd.best_estimator_, X_train, y_class_train)"
|
||
],
|
||
"id": "5a9e5c5b7073a391",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Classification Bias (Random Forest): 0.0\n",
|
||
"Classification Variance (Random Forest): 0.0\n",
|
||
"Classification Bias (SGD): 0.229008615147229\n",
|
||
"Classification Variance (SGD): 0.0\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 5
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T18:05:29.863734Z",
|
||
"start_time": "2024-12-24T18:05:29.842855Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"from sklearn.utils import resample\n",
|
||
"import pandas as pd\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"from sklearn.preprocessing import LabelEncoder\n",
|
||
"from sklearn import metrics\n",
|
||
"from imblearn.over_sampling import RandomOverSampler\n",
|
||
"from imblearn.under_sampling import RandomUnderSampler\n",
|
||
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
|
||
"from sklearn.metrics import ConfusionMatrixDisplay\n",
|
||
"from sklearn.compose import ColumnTransformer\n",
|
||
"from sklearn.pipeline import Pipeline\n",
|
||
"from sklearn.impute import SimpleImputer\n",
|
||
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
|
||
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier, GradientBoostingClassifier\n",
|
||
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
|
||
"from sklearn.linear_model import SGDClassifier, SGDRegressor\n",
|
||
"from sklearn.metrics import (\n",
|
||
" precision_score, recall_score, accuracy_score, roc_auc_score, f1_score,\n",
|
||
" matthews_corrcoef, cohen_kappa_score, confusion_matrix\n",
|
||
")\n",
|
||
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
|
||
"import numpy as np\n",
|
||
"import featuretools as ft\n",
|
||
"from sklearn.metrics import accuracy_score, classification_report\n",
|
||
"\n",
|
||
"df = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
|
||
"\n",
|
||
"# Обработка пропущенных значений\n",
|
||
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
|
||
"\n",
|
||
"# Удаление или замена неизвестных категорий\n",
|
||
"# Например, замена 'Other' на 'Unknown' в столбце 'gender'\n",
|
||
"df['gender'] = df['gender'].replace('Other', 'Unknown')\n",
|
||
"\n",
|
||
"# Разделяем классы\n",
|
||
"stroke_0 = df[df['stroke'] == 0]\n",
|
||
"stroke_1 = df[df['stroke'] == 1]\n",
|
||
"\n",
|
||
"# Увеличиваем выборку для 1\n",
|
||
"stroke_1 = resample(stroke_1, replace=True, n_samples=len(stroke_0), random_state=42)\n",
|
||
"\n",
|
||
"# Объединяем классы\n",
|
||
"df = pd.concat([stroke_0, stroke_1])\n",
|
||
"\n",
|
||
"df = df.sample(frac=0.4)\n",
|
||
"\n",
|
||
"df.info()"
|
||
],
|
||
"id": "d2394bb97604d2a1",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"Index: 3889 entries, 4615 to 2797\n",
|
||
"Data columns (total 12 columns):\n",
|
||
" # Column Non-Null Count Dtype \n",
|
||
"--- ------ -------------- ----- \n",
|
||
" 0 id 3889 non-null int64 \n",
|
||
" 1 gender 3889 non-null object \n",
|
||
" 2 age 3889 non-null float64\n",
|
||
" 3 hypertension 3889 non-null int64 \n",
|
||
" 4 heart_disease 3889 non-null int64 \n",
|
||
" 5 ever_married 3889 non-null object \n",
|
||
" 6 work_type 3889 non-null object \n",
|
||
" 7 Residence_type 3889 non-null object \n",
|
||
" 8 avg_glucose_level 3889 non-null float64\n",
|
||
" 9 bmi 3889 non-null float64\n",
|
||
" 10 smoking_status 3889 non-null object \n",
|
||
" 11 stroke 3889 non-null int64 \n",
|
||
"dtypes: float64(3), int64(4), object(5)\n",
|
||
"memory usage: 395.0+ KB\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 5
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T18:05:32.603683Z",
|
||
"start_time": "2024-12-24T18:05:32.596873Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Определение целевых переменных\n",
|
||
"X = df.drop('stroke', axis=1)\n",
|
||
"y_class = df['stroke'] # Задача классификации\n",
|
||
"y_reg = df['avg_glucose_level'] # Задача регрессии\n",
|
||
"\n",
|
||
"# Преобразование категориальных переменных\n",
|
||
"categorical_features = ['gender', 'ever_married', 'smoking_status']\n",
|
||
"numerical_features = ['age', 'avg_glucose_level', 'bmi']\n",
|
||
"\n",
|
||
"# Создание ColumnTransformer с обработкой неизвестных категорий\n",
|
||
"preprocessor = ColumnTransformer(\n",
|
||
" transformers=[\n",
|
||
" ('num', StandardScaler(), numerical_features),\n",
|
||
" ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)]) # Используем handle_unknown='ignore'\n",
|
||
"\n",
|
||
"# Разделение данных на обучающую и тестовую выборки\n",
|
||
"X_train, X_test, y_class_train, y_class_test, y_reg_train, y_reg_test = train_test_split(X, y_class, y_reg, test_size=0.2, random_state=42) \n",
|
||
"\n",
|
||
"def estimate_bias_variance(model, X, y):\n",
|
||
" predictions = np.array([model.fit(X, y).predict(X) for _ in range(1000)])\n",
|
||
" bias = np.mean((y - np.mean(predictions, axis=0)) ** 2)\n",
|
||
" variance = np.mean(np.var(predictions, axis=0))\n",
|
||
" return bias, variance"
|
||
],
|
||
"id": "45df888b49839959",
|
||
"outputs": [],
|
||
"execution_count": 6
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T18:06:01.805307Z",
|
||
"start_time": "2024-12-24T18:05:38.651793Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Задача регрессии\n",
|
||
"reg_pipeline_rf = Pipeline(steps=[\n",
|
||
" ('preprocessor', preprocessor),\n",
|
||
" ('regressor', RandomForestRegressor(random_state=42))])\n",
|
||
"\n",
|
||
"reg_pipeline_sgd = Pipeline(steps=[\n",
|
||
" ('preprocessor', preprocessor),\n",
|
||
" ('regressor', SGDRegressor(loss='squared_error', penalty='l2', random_state=42, max_iter=2000))])\n",
|
||
"\n",
|
||
"# Настройка гиперпараметров для регрессии\n",
|
||
"param_grid_reg_rf = {\n",
|
||
" 'regressor__n_estimators': [100, 200],\n",
|
||
" 'regressor__max_depth': [None, 10, 20]}\n",
|
||
"\n",
|
||
"param_grid_reg_sgd = {\n",
|
||
" 'regressor__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'regressor__learning_rate': ['constant', 'adaptive'],\n",
|
||
" 'regressor__eta0': [0.01, 0.1]}\n",
|
||
"\n",
|
||
"# Поиск гиперпараметров\n",
|
||
"grid_search_reg_rf = GridSearchCV(reg_pipeline_rf, param_grid_reg_rf, cv=5, scoring='r2')\n",
|
||
"grid_search_reg_rf.fit(X_train, y_reg_train)\n",
|
||
"\n",
|
||
"grid_search_reg_sgd = GridSearchCV(reg_pipeline_sgd, param_grid_reg_sgd, cv=5, scoring='r2')\n",
|
||
"grid_search_reg_sgd.fit(X_train, y_reg_train)"
|
||
],
|
||
"id": "ee19135bbfc42564",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"GridSearchCV(cv=5,\n",
|
||
" estimator=Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num',\n",
|
||
" StandardScaler(),\n",
|
||
" ['age',\n",
|
||
" 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender',\n",
|
||
" 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('regressor',\n",
|
||
" SGDRegressor(max_iter=2000,\n",
|
||
" random_state=42))]),\n",
|
||
" param_grid={'regressor__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'regressor__eta0': [0.01, 0.1],\n",
|
||
" 'regressor__learning_rate': ['constant', 'adaptive']},\n",
|
||
" scoring='r2')"
|
||
],
|
||
"text/html": [
|
||
"<style>#sk-container-id-1 {\n",
|
||
" /* Definition of color scheme common for light and dark mode */\n",
|
||
" --sklearn-color-text: black;\n",
|
||
" --sklearn-color-line: gray;\n",
|
||
" /* Definition of color scheme for unfitted estimators */\n",
|
||
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
|
||
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
|
||
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
|
||
" --sklearn-color-unfitted-level-3: chocolate;\n",
|
||
" /* Definition of color scheme for fitted estimators */\n",
|
||
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
|
||
" --sklearn-color-fitted-level-1: #d4ebff;\n",
|
||
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
|
||
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
|
||
"\n",
|
||
" /* Specific color for light theme */\n",
|
||
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
|
||
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
|
||
" --sklearn-color-icon: #696969;\n",
|
||
"\n",
|
||
" @media (prefers-color-scheme: dark) {\n",
|
||
" /* Redefinition of color scheme for dark theme */\n",
|
||
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
|
||
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
|
||
" --sklearn-color-icon: #878787;\n",
|
||
" }\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 pre {\n",
|
||
" padding: 0;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 input.sk-hidden--visually {\n",
|
||
" border: 0;\n",
|
||
" clip: rect(1px 1px 1px 1px);\n",
|
||
" clip: rect(1px, 1px, 1px, 1px);\n",
|
||
" height: 1px;\n",
|
||
" margin: -1px;\n",
|
||
" overflow: hidden;\n",
|
||
" padding: 0;\n",
|
||
" position: absolute;\n",
|
||
" width: 1px;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
|
||
" border: 1px dashed var(--sklearn-color-line);\n",
|
||
" margin: 0 0.4em 0.5em 0.4em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" padding-bottom: 0.4em;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-container {\n",
|
||
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
|
||
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
|
||
" so we also need the `!important` here to be able to override the\n",
|
||
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
|
||
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
|
||
" display: inline-block !important;\n",
|
||
" position: relative;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
|
||
" display: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"div.sk-parallel-item,\n",
|
||
"div.sk-serial,\n",
|
||
"div.sk-item {\n",
|
||
" /* draw centered vertical line to link estimators */\n",
|
||
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
|
||
" background-size: 2px 100%;\n",
|
||
" background-repeat: no-repeat;\n",
|
||
" background-position: center center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Parallel-specific style estimator block */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item::after {\n",
|
||
" content: \"\";\n",
|
||
" width: 100%;\n",
|
||
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
|
||
" flex-grow: 1;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel {\n",
|
||
" display: flex;\n",
|
||
" align-items: stretch;\n",
|
||
" justify-content: center;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" position: relative;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item {\n",
|
||
" display: flex;\n",
|
||
" flex-direction: column;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
|
||
" align-self: flex-end;\n",
|
||
" width: 50%;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
|
||
" align-self: flex-start;\n",
|
||
" width: 50%;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
|
||
" width: 0;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Serial-specific style estimator block */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-serial {\n",
|
||
" display: flex;\n",
|
||
" flex-direction: column;\n",
|
||
" align-items: center;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" padding-right: 1em;\n",
|
||
" padding-left: 1em;\n",
|
||
"}\n",
|
||
"\n",
|
||
"\n",
|
||
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
|
||
"clickable and can be expanded/collapsed.\n",
|
||
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
|
||
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
|
||
"*/\n",
|
||
"\n",
|
||
"/* Pipeline and ColumnTransformer style (default) */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable {\n",
|
||
" /* Default theme specific background. It is overwritten whether we have a\n",
|
||
" specific estimator or a Pipeline/ColumnTransformer */\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Toggleable label */\n",
|
||
"#sk-container-id-1 label.sk-toggleable__label {\n",
|
||
" cursor: pointer;\n",
|
||
" display: block;\n",
|
||
" width: 100%;\n",
|
||
" margin-bottom: 0;\n",
|
||
" padding: 0.5em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" text-align: center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
|
||
" /* Arrow on the left of the label */\n",
|
||
" content: \"▸\";\n",
|
||
" float: left;\n",
|
||
" margin-right: 0.25em;\n",
|
||
" color: var(--sklearn-color-icon);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Toggleable content - dropdown */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content {\n",
|
||
" max-height: 0;\n",
|
||
" max-width: 0;\n",
|
||
" overflow: hidden;\n",
|
||
" text-align: left;\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
|
||
" margin: 0.2em;\n",
|
||
" border-radius: 0.25em;\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
|
||
" /* Expand drop-down */\n",
|
||
" max-height: 200px;\n",
|
||
" max-width: 100%;\n",
|
||
" overflow: auto;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
|
||
" content: \"▾\";\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Pipeline/ColumnTransformer-specific style */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator-specific style */\n",
|
||
"\n",
|
||
"/* Colorize estimator box */\n",
|
||
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
|
||
"#sk-container-id-1 div.sk-label label {\n",
|
||
" /* The background is the default theme color */\n",
|
||
" color: var(--sklearn-color-text-on-default-background);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover, darken the color of the background */\n",
|
||
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Label box, darken color on hover, fitted */\n",
|
||
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator label */\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label label {\n",
|
||
" font-family: monospace;\n",
|
||
" font-weight: bold;\n",
|
||
" display: inline-block;\n",
|
||
" line-height: 1.2em;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-label-container {\n",
|
||
" text-align: center;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Estimator-specific */\n",
|
||
"#sk-container-id-1 div.sk-estimator {\n",
|
||
" font-family: monospace;\n",
|
||
" border: 1px dotted var(--sklearn-color-border-box);\n",
|
||
" border-radius: 0.25em;\n",
|
||
" box-sizing: border-box;\n",
|
||
" margin-bottom: 0.5em;\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-estimator.fitted {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-0);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* on hover */\n",
|
||
"#sk-container-id-1 div.sk-estimator:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-2);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
|
||
"\n",
|
||
"/* Common style for \"i\" and \"?\" */\n",
|
||
"\n",
|
||
".sk-estimator-doc-link,\n",
|
||
"a:link.sk-estimator-doc-link,\n",
|
||
"a:visited.sk-estimator-doc-link {\n",
|
||
" float: right;\n",
|
||
" font-size: smaller;\n",
|
||
" line-height: 1em;\n",
|
||
" font-family: monospace;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" border-radius: 1em;\n",
|
||
" height: 1em;\n",
|
||
" width: 1em;\n",
|
||
" text-decoration: none !important;\n",
|
||
" margin-left: 1ex;\n",
|
||
" /* unfitted */\n",
|
||
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link.fitted,\n",
|
||
"a:link.sk-estimator-doc-link.fitted,\n",
|
||
"a:visited.sk-estimator-doc-link.fitted {\n",
|
||
" /* fitted */\n",
|
||
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover */\n",
|
||
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
|
||
".sk-estimator-doc-link:hover,\n",
|
||
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
|
||
".sk-estimator-doc-link:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
".sk-estimator-doc-link.fitted:hover,\n",
|
||
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
|
||
".sk-estimator-doc-link.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* Span, style for the box shown on hovering the info icon */\n",
|
||
".sk-estimator-doc-link span {\n",
|
||
" display: none;\n",
|
||
" z-index: 9999;\n",
|
||
" position: relative;\n",
|
||
" font-weight: normal;\n",
|
||
" right: .2ex;\n",
|
||
" padding: .5ex;\n",
|
||
" margin: .5ex;\n",
|
||
" width: min-content;\n",
|
||
" min-width: 20ex;\n",
|
||
" max-width: 50ex;\n",
|
||
" color: var(--sklearn-color-text);\n",
|
||
" box-shadow: 2pt 2pt 4pt #999;\n",
|
||
" /* unfitted */\n",
|
||
" background: var(--sklearn-color-unfitted-level-0);\n",
|
||
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link.fitted span {\n",
|
||
" /* fitted */\n",
|
||
" background: var(--sklearn-color-fitted-level-0);\n",
|
||
" border: var(--sklearn-color-fitted-level-3);\n",
|
||
"}\n",
|
||
"\n",
|
||
".sk-estimator-doc-link:hover span {\n",
|
||
" display: block;\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
|
||
"\n",
|
||
"#sk-container-id-1 a.estimator_doc_link {\n",
|
||
" float: right;\n",
|
||
" font-size: 1rem;\n",
|
||
" line-height: 1em;\n",
|
||
" font-family: monospace;\n",
|
||
" background-color: var(--sklearn-color-background);\n",
|
||
" border-radius: 1rem;\n",
|
||
" height: 1rem;\n",
|
||
" width: 1rem;\n",
|
||
" text-decoration: none;\n",
|
||
" /* unfitted */\n",
|
||
" color: var(--sklearn-color-unfitted-level-1);\n",
|
||
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
|
||
" /* fitted */\n",
|
||
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
|
||
" color: var(--sklearn-color-fitted-level-1);\n",
|
||
"}\n",
|
||
"\n",
|
||
"/* On hover */\n",
|
||
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
|
||
" /* unfitted */\n",
|
||
" background-color: var(--sklearn-color-unfitted-level-3);\n",
|
||
" color: var(--sklearn-color-background);\n",
|
||
" text-decoration: none;\n",
|
||
"}\n",
|
||
"\n",
|
||
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
|
||
" /* fitted */\n",
|
||
" background-color: var(--sklearn-color-fitted-level-3);\n",
|
||
"}\n",
|
||
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=5,\n",
|
||
" estimator=Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num',\n",
|
||
" StandardScaler(),\n",
|
||
" ['age',\n",
|
||
" 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender',\n",
|
||
" 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('regressor',\n",
|
||
" SGDRegressor(max_iter=2000,\n",
|
||
" random_state=42))]),\n",
|
||
" param_grid={'regressor__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'regressor__eta0': [0.01, 0.1],\n",
|
||
" 'regressor__learning_rate': ['constant', 'adaptive']},\n",
|
||
" scoring='r2')</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> GridSearchCV<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.GridSearchCV.html\">?<span>Documentation for GridSearchCV</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>GridSearchCV(cv=5,\n",
|
||
" estimator=Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num',\n",
|
||
" StandardScaler(),\n",
|
||
" ['age',\n",
|
||
" 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender',\n",
|
||
" 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('regressor',\n",
|
||
" SGDRegressor(max_iter=2000,\n",
|
||
" random_state=42))]),\n",
|
||
" param_grid={'regressor__alpha': [0.0001, 0.001, 0.01],\n",
|
||
" 'regressor__eta0': [0.01, 0.1],\n",
|
||
" 'regressor__learning_rate': ['constant', 'adaptive']},\n",
|
||
" scoring='r2')</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">best_estimator_: Pipeline</label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[('preprocessor',\n",
|
||
" ColumnTransformer(transformers=[('num', StandardScaler(),\n",
|
||
" ['age', 'avg_glucose_level',\n",
|
||
" 'bmi']),\n",
|
||
" ('cat',\n",
|
||
" OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender', 'ever_married',\n",
|
||
" 'smoking_status'])])),\n",
|
||
" ('regressor',\n",
|
||
" SGDRegressor(eta0=0.1, learning_rate='adaptive', max_iter=2000,\n",
|
||
" random_state=42))])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> preprocessor: ColumnTransformer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html\">?<span>Documentation for preprocessor: ColumnTransformer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>ColumnTransformer(transformers=[('num', StandardScaler(),\n",
|
||
" ['age', 'avg_glucose_level', 'bmi']),\n",
|
||
" ('cat', OneHotEncoder(handle_unknown='ignore'),\n",
|
||
" ['gender', 'ever_married', 'smoking_status'])])</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">num</label><div class=\"sk-toggleable__content fitted\"><pre>['age', 'avg_glucose_level', 'bmi']</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> StandardScaler<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.StandardScaler.html\">?<span>Documentation for StandardScaler</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>StandardScaler()</pre></div> </div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">cat</label><div class=\"sk-toggleable__content fitted\"><pre>['gender', 'ever_married', 'smoking_status']</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" ><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow 
fitted\"> OneHotEncoder<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.OneHotEncoder.html\">?<span>Documentation for OneHotEncoder</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>OneHotEncoder(handle_unknown='ignore')</pre></div> </div></div></div></div></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" ><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> SGDRegressor<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.linear_model.SGDRegressor.html\">?<span>Documentation for SGDRegressor</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>SGDRegressor(eta0=0.1, learning_rate='adaptive', max_iter=2000, random_state=42)</pre></div> </div></div></div></div></div></div></div></div></div></div></div>"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 7
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T18:06:06.657058Z",
|
||
"start_time": "2024-12-24T18:06:06.621503Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
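||
"# Evaluate the tuned regressors on the held-out test set. NOTE(review): R2 ~ 1.0 in the saved output may indicate target leakage - verify.\n",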
|
||
"y_reg_pred_rf = grid_search_reg_rf.predict(X_test)\n",
|
||
"y_reg_pred_sgd = grid_search_reg_sgd.predict(X_test)\n",
|
||
"\n",
|
||
"print(\"Regression Metrics for Random Forest:\")\n",
|
||
"print(\"Mean Squared Error:\", mean_squared_error(y_reg_test, y_reg_pred_rf))\n",
|
||
"print(\"R2 Score:\", r2_score(y_reg_test, y_reg_pred_rf))\n",
|
||
"\n",
|
||
"print(\"Regression Metrics for SGD:\")\n",
|
||
"print(\"Mean Squared Error:\", mean_squared_error(y_reg_test, y_reg_pred_sgd))\n",
|
||
"print(\"R2 Score:\", r2_score(y_reg_test, y_reg_pred_sgd))"
|
||
],
|
||
"id": "65f8397fab8911ba",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Regression Metrics for Random Forest:\n",
|
||
"Mean Squared Error: 0.020541479842551013\n",
|
||
"R2 Score: 0.9999932939237861\n",
|
||
"Regression Metrics for SGD:\n",
|
||
"Mean Squared Error: 4.648382893338219e-05\n",
|
||
"R2 Score: 0.9999999848246522\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 8
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2024-12-24T18:28:20.340519Z",
|
||
"start_time": "2024-12-24T18:06:12.807767Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
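||
"# Estimate bias and variance of the best regressors on X_train / y_reg_train. NOTE(review): saved variance ~1e-24 looks degenerate - check estimate_bias_variance.\n",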
|
||
"bias_reg_rf, variance_reg_rf = estimate_bias_variance(grid_search_reg_rf.best_estimator_, X_train, y_reg_train)\n",
|
||
"bias_reg_sgd, variance_reg_sgd = estimate_bias_variance(grid_search_reg_sgd.best_estimator_, X_train, y_reg_train)\n",
|
||
"\n",
|
||
"print(\"Regression Bias (Random Forest):\", bias_reg_rf)\n",
|
||
"print(\"Regression Variance (Random Forest):\", variance_reg_rf)\n",
|
||
"print(\"Regression Bias (SGD):\", bias_reg_sgd)\n",
|
||
"print(\"Regression Variance (SGD):\", variance_reg_sgd)"
|
||
],
|
||
"id": "cccd002e6275411f",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Regression Bias (Random Forest): 0.0022602929210874885\n",
|
||
"Regression Variance (Random Forest): 3.608883047891326e-24\n",
|
||
"Regression Bias (SGD): 4.682701837803326e-05\n",
|
||
"Regression Variance (SGD): 3.2443460449162085e-24\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 9
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|