AIM-PIbd-32-Borovkov-M-V/Lab_4/lab4.ipynb
2024-12-24 22:38:45 +04:00

1500 lines
76 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T15:40:10.929503Z",
"start_time": "2024-12-24T15:40:08.889350Z"
}
},
"cell_type": "code",
"source": [
"from sklearn.utils import resample\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn import metrics\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier, GradientBoostingClassifier\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.linear_model import SGDClassifier, SGDRegressor\n",
"from sklearn.metrics import (\n",
" precision_score, recall_score, accuracy_score, roc_auc_score, f1_score,\n",
" matthews_corrcoef, cohen_kappa_score, confusion_matrix\n",
")\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"import numpy as np\n",
"import featuretools as ft\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"df = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
"\n",
"# Обработка пропущенных значений\n",
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
"\n",
"# Удаление или замена неизвестных категорий\n",
"# Например, замена 'Other' на 'Unknown' в столбце 'gender'\n",
"df['gender'] = df['gender'].replace('Other', 'Unknown')\n",
"\n",
"# Разделяем классы\n",
"stroke_0 = df[df['stroke'] == 0]\n",
"stroke_1 = df[df['stroke'] == 1]\n",
"\n",
"# Увеличиваем выборку для 1\n",
"stroke_1 = resample(stroke_1, replace=True, n_samples=len(stroke_0), random_state=42)\n",
"\n",
"# Объединяем классы\n",
"df = pd.concat([stroke_0, stroke_1])\n",
"\n",
"df.info()"
],
"id": "db3cd8711d083df",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 9722 entries, 249 to 134\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 id 9722 non-null int64 \n",
" 1 gender 9722 non-null object \n",
" 2 age 9722 non-null float64\n",
" 3 hypertension 9722 non-null int64 \n",
" 4 heart_disease 9722 non-null int64 \n",
" 5 ever_married 9722 non-null object \n",
" 6 work_type 9722 non-null object \n",
" 7 Residence_type 9722 non-null object \n",
" 8 avg_glucose_level 9722 non-null float64\n",
" 9 bmi 9722 non-null float64\n",
" 10 smoking_status 9722 non-null object \n",
" 11 stroke 9722 non-null int64 \n",
"dtypes: float64(3), int64(4), object(5)\n",
"memory usage: 987.4+ KB\n"
]
}
],
"execution_count": 1
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"Классификация: Предсказать вероятность инсульта на основе данных пациента.\n",
"\n",
"Регрессия: Предсказать уровень глюкозы в крови на основе данных пациента."
],
"id": "d5f640ac158b69c5"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T15:40:11.000561Z",
"start_time": "2024-12-24T15:40:10.979999Z"
}
},
"cell_type": "code",
"source": [
"# Определение целевых переменных\n",
"X = df.drop('stroke', axis=1)\n",
"y_class = df['stroke'] # Задача классификации\n",
"y_reg = df['avg_glucose_level'] # Задача регрессии\n",
"\n",
"# Преобразование категориальных переменных\n",
"categorical_features = ['gender', 'ever_married', 'smoking_status']\n",
"numerical_features = ['age', 'avg_glucose_level', 'bmi']\n",
"\n",
"# Создание ColumnTransformer с обработкой неизвестных категорий\n",
"preprocessor = ColumnTransformer(\n",
" transformers=[\n",
" ('num', StandardScaler(), numerical_features),\n",
" ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)]) # Используем handle_unknown='ignore'\n",
"\n",
"# Разделение данных на обучающую и тестовую выборки\n",
"X_train, X_test, y_class_train, y_class_test, y_reg_train, y_reg_test = train_test_split(X, y_class, y_reg, test_size=0.2, random_state=42) \n",
"\n",
"def estimate_bias_variance(model, X, y):\n",
" predictions = np.array([model.fit(X, y).predict(X) for _ in range(1000)])\n",
" bias = np.mean((y - np.mean(predictions, axis=0)) ** 2)\n",
" variance = np.mean(np.var(predictions, axis=0))\n",
" return bias, variance"
],
"id": "55c612d9a4b90a55",
"outputs": [],
"execution_count": 2
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Классификация",
"id": "273d7304e338f532"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T15:40:50.317809Z",
"start_time": "2024-12-24T15:40:11.011074Z"
}
},
"cell_type": "code",
"source": [
"# Задача классификации\n",
"class_pipeline_rf = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('classifier', RandomForestClassifier(random_state=42))])\n",
"\n",
"class_pipeline_sgd = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('classifier', SGDClassifier(loss='log_loss', penalty='l2', random_state=42, max_iter=2000))]) \n",
"\n",
"# Настройка гиперпараметров\n",
"param_grid_class_rf = {\n",
" 'classifier__n_estimators': [100, 200],\n",
" 'classifier__max_depth': [None, 10, 20]}\n",
"\n",
"param_grid_class_sgd = {\n",
" 'classifier__alpha': [0.0001, 0.001, 0.01],\n",
" 'classifier__learning_rate': ['constant', 'adaptive'],\n",
" 'classifier__eta0': [0.01, 0.1]}\n",
"\n",
"# Поиск гиперпараметров\n",
"grid_search_class_rf = GridSearchCV(class_pipeline_rf, param_grid_class_rf, cv=5, scoring='accuracy')\n",
"grid_search_class_rf.fit(X_train, y_class_train)\n",
"\n",
"grid_search_class_sgd = GridSearchCV(class_pipeline_sgd, param_grid_class_sgd, cv=5, scoring='accuracy')\n",
"grid_search_class_sgd.fit(X_train, y_class_train)"
],
"id": "956c94b392572508",
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[('preprocessor',\n",
" ColumnTransformer(transformers=[('num',\n",
" StandardScaler(),\n",
" ['age',\n",
" 'avg_glucose_level',\n",
" 'bmi']),\n",
" ('cat',\n",
" OneHotEncoder(handle_unknown='ignore'),\n",
" ['gender',\n",
" 'ever_married',\n",
" 'smoking_status'])])),\n",
" ('classifier',\n",
" SGDClassifier(loss='log_loss',\n",
" max_iter=2000,\n",
" random_state=42))]),\n",
" param_grid={'classifier__alpha': [0.0001, 0.001, 0.01],\n",
" 'classifier__eta0': [0.01, 0.1],\n",
" 'classifier__learning_rate': ['constant', 'adaptive']},\n",
" scoring='accuracy')"
],
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[(&#x27;preprocessor&#x27;,\n",
" ColumnTransformer(transformers=[(&#x27;num&#x27;,\n",
" StandardScaler(),\n",
" [&#x27;age&#x27;,\n",
" &#x27;avg_glucose_level&#x27;,\n",
" &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;,\n",
" OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;,\n",
" &#x27;ever_married&#x27;,\n",
" &#x27;smoking_status&#x27;])])),\n",
" (&#x27;classifier&#x27;,\n",
" SGDClassifier(loss=&#x27;log_loss&#x27;,\n",
" max_iter=2000,\n",
" random_state=42))]),\n",
" param_grid={&#x27;classifier__alpha&#x27;: [0.0001, 0.001, 0.01],\n",
" &#x27;classifier__eta0&#x27;: [0.01, 0.1],\n",
" &#x27;classifier__learning_rate&#x27;: [&#x27;constant&#x27;, &#x27;adaptive&#x27;]},\n",
" scoring=&#x27;accuracy&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;GridSearchCV<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.GridSearchCV.html\">?<span>Documentation for GridSearchCV</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[(&#x27;preprocessor&#x27;,\n",
" ColumnTransformer(transformers=[(&#x27;num&#x27;,\n",
" StandardScaler(),\n",
" [&#x27;age&#x27;,\n",
" &#x27;avg_glucose_level&#x27;,\n",
" &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;,\n",
" OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;,\n",
" &#x27;ever_married&#x27;,\n",
" &#x27;smoking_status&#x27;])])),\n",
" (&#x27;classifier&#x27;,\n",
" SGDClassifier(loss=&#x27;log_loss&#x27;,\n",
" max_iter=2000,\n",
" random_state=42))]),\n",
" param_grid={&#x27;classifier__alpha&#x27;: [0.0001, 0.001, 0.01],\n",
" &#x27;classifier__eta0&#x27;: [0.01, 0.1],\n",
" &#x27;classifier__learning_rate&#x27;: [&#x27;constant&#x27;, &#x27;adaptive&#x27;]},\n",
" scoring=&#x27;accuracy&#x27;)</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">best_estimator_: Pipeline</label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[(&#x27;preprocessor&#x27;,\n",
" ColumnTransformer(transformers=[(&#x27;num&#x27;, StandardScaler(),\n",
" [&#x27;age&#x27;, &#x27;avg_glucose_level&#x27;,\n",
" &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;,\n",
" OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;, &#x27;ever_married&#x27;,\n",
" &#x27;smoking_status&#x27;])])),\n",
" (&#x27;classifier&#x27;,\n",
" SGDClassifier(alpha=0.001, eta0=0.01, learning_rate=&#x27;constant&#x27;,\n",
" loss=&#x27;log_loss&#x27;, max_iter=2000,\n",
" random_state=42))])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;preprocessor: ColumnTransformer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html\">?<span>Documentation for preprocessor: ColumnTransformer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>ColumnTransformer(transformers=[(&#x27;num&#x27;, StandardScaler(),\n",
" [&#x27;age&#x27;, &#x27;avg_glucose_level&#x27;, &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;, OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;, &#x27;ever_married&#x27;, &#x27;smoking_status&#x27;])])</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">num</label><div class=\"sk-toggleable__content fitted\"><pre>[&#x27;age&#x27;, &#x27;avg_glucose_level&#x27;, &#x27;bmi&#x27;]</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;StandardScaler<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.StandardScaler.html\">?<span>Documentation for StandardScaler</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>StandardScaler()</pre></div> </div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">cat</label><div class=\"sk-toggleable__content fitted\"><pre>[&#x27;gender&#x27;, &#x27;ever_married&#x27;, &#x27;smoking_status&#x27;]</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" ><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;OneHotEncoder<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.OneHotEncoder.html\">?<span>Documentation for OneHotEncoder</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>OneHotEncoder(handle_unknown=&#x27;ignore&#x27;)</pre></div> </div></div></div></div></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" ><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;SGDClassifier<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.linear_model.SGDClassifier.html\">?<span>Documentation for SGDClassifier</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>SGDClassifier(alpha=0.001, eta0=0.01, learning_rate=&#x27;constant&#x27;, loss=&#x27;log_loss&#x27;,\n",
" max_iter=2000, random_state=42)</pre></div> </div></div></div></div></div></div></div></div></div></div></div>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T15:40:51.508706Z",
"start_time": "2024-12-24T15:40:51.415577Z"
}
},
"cell_type": "code",
"source": [
"# Оценка моделей\n",
"y_class_pred_rf = grid_search_class_rf.predict(X_test)\n",
"y_class_pred_sgd = grid_search_class_sgd.predict(X_test)\n",
"\n",
"print(\"Classification Report for Random Forest:\")\n",
"print(classification_report(y_class_test, y_class_pred_rf))\n",
"print(\"Confusion Matrix for Random Forest:\")\n",
"print(confusion_matrix(y_class_test, y_class_pred_rf))\n",
"\n",
"print(\"Classification Report for SGD:\")\n",
"print(classification_report(y_class_test, y_class_pred_sgd))\n",
"print(\"Confusion Matrix for SGD:\")\n",
"print(confusion_matrix(y_class_test, y_class_pred_sgd))"
],
"id": "20d5d13f359c1c48",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification Report for Random Forest:\n",
" precision recall f1-score support\n",
"\n",
" 0 1.00 0.97 0.98 993\n",
" 1 0.97 1.00 0.98 952\n",
"\n",
" accuracy 0.98 1945\n",
" macro avg 0.98 0.98 0.98 1945\n",
"weighted avg 0.99 0.98 0.98 1945\n",
"\n",
"Confusion Matrix for Random Forest:\n",
"[[963 30]\n",
" [ 0 952]]\n",
"Classification Report for SGD:\n",
" precision recall f1-score support\n",
"\n",
" 0 0.84 0.73 0.78 993\n",
" 1 0.75 0.85 0.80 952\n",
"\n",
" accuracy 0.79 1945\n",
" macro avg 0.80 0.79 0.79 1945\n",
"weighted avg 0.80 0.79 0.79 1945\n",
"\n",
"Confusion Matrix for SGD:\n",
"[[728 265]\n",
" [141 811]]\n"
]
}
],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T15:58:49.129025Z",
"start_time": "2024-12-24T15:40:51.540498Z"
}
},
"cell_type": "code",
"source": [
"# Оценка смещения и дисперсии\n",
"bias_class_rf, variance_class_rf = estimate_bias_variance(grid_search_class_rf.best_estimator_, X_train, y_class_train)\n",
"bias_class_sgd, variance_class_sgd = estimate_bias_variance(grid_search_class_sgd.best_estimator_, X_train, y_class_train)"
],
"id": "5a9e5c5b7073a391",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classification Bias (Random Forest): 0.0\n",
"Classification Variance (Random Forest): 0.0\n",
"Classification Bias (SGD): 0.229008615147229\n",
"Classification Variance (SGD): 0.0\n"
]
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T18:05:29.863734Z",
"start_time": "2024-12-24T18:05:29.842855Z"
}
},
"cell_type": "code",
"source": [
"from sklearn.utils import resample\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.preprocessing import LabelEncoder\n",
"from sklearn import metrics\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.linear_model import LinearRegression, LogisticRegression\n",
"from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier, GradientBoostingClassifier\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.linear_model import SGDClassifier, SGDRegressor\n",
"from sklearn.metrics import (\n",
" precision_score, recall_score, accuracy_score, roc_auc_score, f1_score,\n",
" matthews_corrcoef, cohen_kappa_score, confusion_matrix\n",
")\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"import numpy as np\n",
"import featuretools as ft\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"df = pd.read_csv(\"healthcare-dataset-stroke-data.csv\")\n",
"\n",
"# Обработка пропущенных значений\n",
"df[\"bmi\"] = df[\"bmi\"].fillna(df[\"bmi\"].median())\n",
"\n",
"# Удаление или замена неизвестных категорий\n",
"# Например, замена 'Other' на 'Unknown' в столбце 'gender'\n",
"df['gender'] = df['gender'].replace('Other', 'Unknown')\n",
"\n",
"# Разделяем классы\n",
"stroke_0 = df[df['stroke'] == 0]\n",
"stroke_1 = df[df['stroke'] == 1]\n",
"\n",
"# Увеличиваем выборку для 1\n",
"stroke_1 = resample(stroke_1, replace=True, n_samples=len(stroke_0), random_state=42)\n",
"\n",
"# Объединяем классы\n",
"df = pd.concat([stroke_0, stroke_1])\n",
"\n",
"df = df.sample(frac=0.4)\n",
"\n",
"df.info()"
],
"id": "d2394bb97604d2a1",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Index: 3889 entries, 4615 to 2797\n",
"Data columns (total 12 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 id 3889 non-null int64 \n",
" 1 gender 3889 non-null object \n",
" 2 age 3889 non-null float64\n",
" 3 hypertension 3889 non-null int64 \n",
" 4 heart_disease 3889 non-null int64 \n",
" 5 ever_married 3889 non-null object \n",
" 6 work_type 3889 non-null object \n",
" 7 Residence_type 3889 non-null object \n",
" 8 avg_glucose_level 3889 non-null float64\n",
" 9 bmi 3889 non-null float64\n",
" 10 smoking_status 3889 non-null object \n",
" 11 stroke 3889 non-null int64 \n",
"dtypes: float64(3), int64(4), object(5)\n",
"memory usage: 395.0+ KB\n"
]
}
],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T18:05:32.603683Z",
"start_time": "2024-12-24T18:05:32.596873Z"
}
},
"cell_type": "code",
"source": [
"# Определение целевых переменных\n",
"X = df.drop('stroke', axis=1)\n",
"y_class = df['stroke'] # Задача классификации\n",
"y_reg = df['avg_glucose_level'] # Задача регрессии\n",
"\n",
"# Преобразование категориальных переменных\n",
"categorical_features = ['gender', 'ever_married', 'smoking_status']\n",
"numerical_features = ['age', 'avg_glucose_level', 'bmi']\n",
"\n",
"# Создание ColumnTransformer с обработкой неизвестных категорий\n",
"preprocessor = ColumnTransformer(\n",
" transformers=[\n",
" ('num', StandardScaler(), numerical_features),\n",
" ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)]) # Используем handle_unknown='ignore'\n",
"\n",
"# Разделение данных на обучающую и тестовую выборки\n",
"X_train, X_test, y_class_train, y_class_test, y_reg_train, y_reg_test = train_test_split(X, y_class, y_reg, test_size=0.2, random_state=42) \n",
"\n",
"def estimate_bias_variance(model, X, y):\n",
" predictions = np.array([model.fit(X, y).predict(X) for _ in range(1000)])\n",
" bias = np.mean((y - np.mean(predictions, axis=0)) ** 2)\n",
" variance = np.mean(np.var(predictions, axis=0))\n",
" return bias, variance"
],
"id": "45df888b49839959",
"outputs": [],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T18:06:01.805307Z",
"start_time": "2024-12-24T18:05:38.651793Z"
}
},
"cell_type": "code",
"source": [
"# Задача регрессии\n",
"reg_pipeline_rf = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('regressor', RandomForestRegressor(random_state=42))])\n",
"\n",
"reg_pipeline_sgd = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('regressor', SGDRegressor(loss='squared_error', penalty='l2', random_state=42, max_iter=2000))])\n",
"\n",
"# Настройка гиперпараметров для регрессии\n",
"param_grid_reg_rf = {\n",
" 'regressor__n_estimators': [100, 200],\n",
" 'regressor__max_depth': [None, 10, 20]}\n",
"\n",
"param_grid_reg_sgd = {\n",
" 'regressor__alpha': [0.0001, 0.001, 0.01],\n",
" 'regressor__learning_rate': ['constant', 'adaptive'],\n",
" 'regressor__eta0': [0.01, 0.1]}\n",
"\n",
"# Поиск гиперпараметров\n",
"grid_search_reg_rf = GridSearchCV(reg_pipeline_rf, param_grid_reg_rf, cv=5, scoring='r2')\n",
"grid_search_reg_rf.fit(X_train, y_reg_train)\n",
"\n",
"grid_search_reg_sgd = GridSearchCV(reg_pipeline_sgd, param_grid_reg_sgd, cv=5, scoring='r2')\n",
"grid_search_reg_sgd.fit(X_train, y_reg_train)"
],
"id": "ee19135bbfc42564",
"outputs": [
{
"data": {
"text/plain": [
"GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[('preprocessor',\n",
" ColumnTransformer(transformers=[('num',\n",
" StandardScaler(),\n",
" ['age',\n",
" 'avg_glucose_level',\n",
" 'bmi']),\n",
" ('cat',\n",
" OneHotEncoder(handle_unknown='ignore'),\n",
" ['gender',\n",
" 'ever_married',\n",
" 'smoking_status'])])),\n",
" ('regressor',\n",
" SGDRegressor(max_iter=2000,\n",
" random_state=42))]),\n",
" param_grid={'regressor__alpha': [0.0001, 0.001, 0.01],\n",
" 'regressor__eta0': [0.01, 0.1],\n",
" 'regressor__learning_rate': ['constant', 'adaptive']},\n",
" scoring='r2')"
],
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[(&#x27;preprocessor&#x27;,\n",
" ColumnTransformer(transformers=[(&#x27;num&#x27;,\n",
" StandardScaler(),\n",
" [&#x27;age&#x27;,\n",
" &#x27;avg_glucose_level&#x27;,\n",
" &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;,\n",
" OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;,\n",
" &#x27;ever_married&#x27;,\n",
" &#x27;smoking_status&#x27;])])),\n",
" (&#x27;regressor&#x27;,\n",
" SGDRegressor(max_iter=2000,\n",
" random_state=42))]),\n",
" param_grid={&#x27;regressor__alpha&#x27;: [0.0001, 0.001, 0.01],\n",
" &#x27;regressor__eta0&#x27;: [0.01, 0.1],\n",
" &#x27;regressor__learning_rate&#x27;: [&#x27;constant&#x27;, &#x27;adaptive&#x27;]},\n",
" scoring=&#x27;r2&#x27;)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;GridSearchCV<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.model_selection.GridSearchCV.html\">?<span>Documentation for GridSearchCV</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>GridSearchCV(cv=5,\n",
" estimator=Pipeline(steps=[(&#x27;preprocessor&#x27;,\n",
" ColumnTransformer(transformers=[(&#x27;num&#x27;,\n",
" StandardScaler(),\n",
" [&#x27;age&#x27;,\n",
" &#x27;avg_glucose_level&#x27;,\n",
" &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;,\n",
" OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;,\n",
" &#x27;ever_married&#x27;,\n",
" &#x27;smoking_status&#x27;])])),\n",
" (&#x27;regressor&#x27;,\n",
" SGDRegressor(max_iter=2000,\n",
" random_state=42))]),\n",
" param_grid={&#x27;regressor__alpha&#x27;: [0.0001, 0.001, 0.01],\n",
" &#x27;regressor__eta0&#x27;: [0.01, 0.1],\n",
" &#x27;regressor__learning_rate&#x27;: [&#x27;constant&#x27;, &#x27;adaptive&#x27;]},\n",
" scoring=&#x27;r2&#x27;)</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">best_estimator_: Pipeline</label><div class=\"sk-toggleable__content fitted\"><pre>Pipeline(steps=[(&#x27;preprocessor&#x27;,\n",
" ColumnTransformer(transformers=[(&#x27;num&#x27;, StandardScaler(),\n",
" [&#x27;age&#x27;, &#x27;avg_glucose_level&#x27;,\n",
" &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;,\n",
" OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;, &#x27;ever_married&#x27;,\n",
" &#x27;smoking_status&#x27;])])),\n",
" (&#x27;regressor&#x27;,\n",
" SGDRegressor(eta0=0.1, learning_rate=&#x27;adaptive&#x27;, max_iter=2000,\n",
" random_state=42))])</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-serial\"><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;preprocessor: ColumnTransformer<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html\">?<span>Documentation for preprocessor: ColumnTransformer</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>ColumnTransformer(transformers=[(&#x27;num&#x27;, StandardScaler(),\n",
" [&#x27;age&#x27;, &#x27;avg_glucose_level&#x27;, &#x27;bmi&#x27;]),\n",
" (&#x27;cat&#x27;, OneHotEncoder(handle_unknown=&#x27;ignore&#x27;),\n",
" [&#x27;gender&#x27;, &#x27;ever_married&#x27;, &#x27;smoking_status&#x27;])])</pre></div> </div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" ><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">num</label><div class=\"sk-toggleable__content fitted\"><pre>[&#x27;age&#x27;, &#x27;avg_glucose_level&#x27;, &#x27;bmi&#x27;]</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-5\" type=\"checkbox\" ><label for=\"sk-estimator-id-5\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;StandardScaler<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.StandardScaler.html\">?<span>Documentation for StandardScaler</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>StandardScaler()</pre></div> </div></div></div></div></div><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-6\" type=\"checkbox\" ><label for=\"sk-estimator-id-6\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">cat</label><div class=\"sk-toggleable__content fitted\"><pre>[&#x27;gender&#x27;, &#x27;ever_married&#x27;, &#x27;smoking_status&#x27;]</pre></div> </div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-7\" type=\"checkbox\" ><label for=\"sk-estimator-id-7\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;OneHotEncoder<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.preprocessing.OneHotEncoder.html\">?<span>Documentation for OneHotEncoder</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>OneHotEncoder(handle_unknown=&#x27;ignore&#x27;)</pre></div> </div></div></div></div></div></div></div><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" ><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;SGDRegressor<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.5/modules/generated/sklearn.linear_model.SGDRegressor.html\">?<span>Documentation for SGDRegressor</span></a></label><div class=\"sk-toggleable__content fitted\"><pre>SGDRegressor(eta0=0.1, learning_rate=&#x27;adaptive&#x27;, max_iter=2000, random_state=42)</pre></div> </div></div></div></div></div></div></div></div></div></div></div>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T18:06:06.657058Z",
"start_time": "2024-12-24T18:06:06.621503Z"
}
},
"cell_type": "code",
"source": [
"# Оценка моделей\n",
"y_reg_pred_rf = grid_search_reg_rf.predict(X_test)\n",
"y_reg_pred_sgd = grid_search_reg_sgd.predict(X_test)\n",
"\n",
"print(\"Regression Metrics for Random Forest:\")\n",
"print(\"Mean Squared Error:\", mean_squared_error(y_reg_test, y_reg_pred_rf))\n",
"print(\"R2 Score:\", r2_score(y_reg_test, y_reg_pred_rf))\n",
"\n",
"print(\"Regression Metrics for SGD:\")\n",
"print(\"Mean Squared Error:\", mean_squared_error(y_reg_test, y_reg_pred_sgd))\n",
"print(\"R2 Score:\", r2_score(y_reg_test, y_reg_pred_sgd))"
],
"id": "65f8397fab8911ba",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Regression Metrics for Random Forest:\n",
"Mean Squared Error: 0.020541479842551013\n",
"R2 Score: 0.9999932939237861\n",
"Regression Metrics for SGD:\n",
"Mean Squared Error: 4.648382893338219e-05\n",
"R2 Score: 0.9999999848246522\n"
]
}
],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-24T18:28:20.340519Z",
"start_time": "2024-12-24T18:06:12.807767Z"
}
},
"cell_type": "code",
"source": [
"# Оценка смещения и дисперсии\n",
"bias_reg_rf, variance_reg_rf = estimate_bias_variance(grid_search_reg_rf.best_estimator_, X_train, y_reg_train)\n",
"bias_reg_sgd, variance_reg_sgd = estimate_bias_variance(grid_search_reg_sgd.best_estimator_, X_train, y_reg_train)\n",
"\n",
"print(\"Regression Bias (Random Forest):\", bias_reg_rf)\n",
"print(\"Regression Variance (Random Forest):\", variance_reg_rf)\n",
"print(\"Regression Bias (SGD):\", bias_reg_sgd)\n",
"print(\"Regression Variance (SGD):\", variance_reg_sgd)"
],
"id": "cccd002e6275411f",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Regression Bias (Random Forest): 0.0022602929210874885\n",
"Regression Variance (Random Forest): 3.608883047891326e-24\n",
"Regression Bias (SGD): 4.682701837803326e-05\n",
"Regression Variance (SGD): 3.2443460449162085e-24\n"
]
}
],
"execution_count": 9
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}