diff --git a/lab_4/lab4.ipynb b/lab_4/lab4.ipynb
new file mode 100644
index 0000000..733cd47
--- /dev/null
+++ b/lab_4/lab4.ipynb
@@ -0,0 +1,2391 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Лабораторная 4\n",
+ "\n",
+ "Датасет: Информация об онлайн-обучении учеников\n",
+ "\n",
+ "Бизнес-цель 1: Улучшение доступа к онлайн-образованию для учеников с низким уровнем финансового обеспечения."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 104,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Index(['Education Level', 'Institution Type', 'Gender', 'Age', 'Device',\n",
+ " 'IT Student', 'Location', 'Financial Condition', 'Internet Type',\n",
+ " 'Network Type', 'Flexibility Level'],\n",
+ " dtype='object')\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "from typing import Tuple\n",
+ "from pandas import DataFrame\n",
+ "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree, metrics, set_config\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.compose import ColumnTransformer\n",
+ "from sklearn.discriminant_analysis import StandardScaler\n",
+ "from sklearn.impute import SimpleImputer\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.preprocessing import OneHotEncoder\n",
+ "from sklearn.compose import ColumnTransformer\n",
+ "from sklearn.discriminant_analysis import StandardScaler\n",
+ "from sklearn.impute import SimpleImputer\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.preprocessing import OneHotEncoder\n",
+ "from sklearn.metrics import ConfusionMatrixDisplay\n",
+ "from sklearn.model_selection import GridSearchCV\n",
+ "\n",
+ "# Make every sklearn transformer return pandas DataFrames so column names survive.\n",
+ "set_config(transform_output=\"pandas\")\n",
+ "\n",
+ "# Forward slashes resolve on Windows, Linux and macOS alike; the previous\n",
+ "# backslash-separated path could only be opened on Windows.\n",
+ "df = pd.read_csv(\"../static/csv/students_adaptability_level_online_education.csv\")\n",
+ "print(df.columns)\n",
+ "\n",
+ "# Encode the flexibility labels as integers: Low -> 0, Moderate -> 1, High -> 2.\n",
+ "map_flexibility_to_int = {'Low': 0, 'Moderate': 1, 'High': 2}\n",
+ "\n",
+ "df['Flexibility Level'] = df['Flexibility Level'].map(map_flexibility_to_int).astype('int32')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Предварительно создадим колонку для работы с ней (ключевой фактор)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 105,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Each factor contributes points to an access-difficulty score; a higher total\n",
+ "# means harder access to online education.\n",
+ "fincond_mapping = {'Poor': 2, 'Mid': 1, 'Rich': 0}\n",
+ "internet_type_mapping = {'Mobile Data': 1, 'Wifi': 0}\n",
+ "# 'Tab' occurs in the Device column. Without an entry here .map() returns NaN,\n",
+ "# the summed score becomes NaN and `NaN >= 3` is False, so such rows were\n",
+ "# silently labelled 0 whatever their other factors were.\n",
+ "# NOTE(review): a tablet is scored like a phone here - confirm the intended weight.\n",
+ "device_mapping = {'Mobile': 1, 'Tab': 1, 'Computer': 0}\n",
+ "network_type = {'2G': 2, '3G': 1, '4G': 0}\n",
+ "\n",
+ "# Keep the partial scores in a scratch frame so `df` never carries helper columns.\n",
+ "score_parts = pd.DataFrame(\n",
+ "    {\n",
+ "        'Financial Score': df['Financial Condition'].map(fincond_mapping),\n",
+ "        'Internet Score': df['Internet Type'].map(internet_type_mapping),\n",
+ "        'Device Score': df['Device'].map(device_mapping),\n",
+ "        'Network Score': df['Network Type'].map(network_type),\n",
+ "    }\n",
+ ")\n",
+ "\n",
+ "# Fail loudly on a category missing from the mappings instead of mislabelling rows.\n",
+ "unmapped = score_parts.isna().any()\n",
+ "if unmapped.any():\n",
+ "    raise ValueError(f\"Unmapped categories in: {list(unmapped[unmapped].index)}\")\n",
+ "\n",
+ "# A total of 3 or more points marks the student as having difficult access (1).\n",
+ "df['Access Difficulty'] = (score_parts.sum(axis=1) >= 3).astype(int)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Формируем выборки"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 106,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'X_train'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Education Level | \n",
+ " Institution Type | \n",
+ " Gender | \n",
+ " Age | \n",
+ " Device | \n",
+ " IT Student | \n",
+ " Location | \n",
+ " Financial Condition | \n",
+ " Internet Type | \n",
+ " Network Type | \n",
+ " Flexibility Level | \n",
+ " Access Difficulty | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 649 | \n",
+ " School | \n",
+ " Public | \n",
+ " Male | \n",
+ " 18 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 637 | \n",
+ " School | \n",
+ " Private | \n",
+ " Female | \n",
+ " 9 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Mobile Data | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 68 | \n",
+ " School | \n",
+ " Public | \n",
+ " Female | \n",
+ " 11 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 276 | \n",
+ " University | \n",
+ " Private | \n",
+ " Female | \n",
+ " 18 | \n",
+ " Mobile | \n",
+ " Yes | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Mobile Data | \n",
+ " 3G | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 547 | \n",
+ " School | \n",
+ " Public | \n",
+ " Male | \n",
+ " 11 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 1097 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Mobile | \n",
+ " Yes | \n",
+ " Town | \n",
+ " Rich | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 854 | \n",
+ " School | \n",
+ " Public | \n",
+ " Female | \n",
+ " 18 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Mobile Data | \n",
+ " 4G | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 756 | \n",
+ " University | \n",
+ " Public | \n",
+ " Male | \n",
+ " 18 | \n",
+ " Computer | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 3G | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 133 | \n",
+ " College | \n",
+ " Public | \n",
+ " Male | \n",
+ " 18 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Poor | \n",
+ " Mobile Data | \n",
+ " 4G | \n",
+ " 0 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 53 | \n",
+ " University | \n",
+ " Public | \n",
+ " Male | \n",
+ " 27 | \n",
+ " Mobile | \n",
+ " Yes | \n",
+ " Rural | \n",
+ " Poor | \n",
+ " Mobile Data | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
964 rows × 12 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Education Level Institution Type Gender Age Device IT Student \\\n",
+ "649 School Public Male 18 Mobile No \n",
+ "637 School Private Female 9 Mobile No \n",
+ "68 School Public Female 11 Mobile No \n",
+ "276 University Private Female 18 Mobile Yes \n",
+ "547 School Public Male 11 Mobile No \n",
+ "... ... ... ... ... ... ... \n",
+ "1097 University Private Male 23 Mobile Yes \n",
+ "854 School Public Female 18 Mobile No \n",
+ "756 University Public Male 18 Computer No \n",
+ "133 College Public Male 18 Mobile No \n",
+ "53 University Public Male 27 Mobile Yes \n",
+ "\n",
+ " Location Financial Condition Internet Type Network Type \\\n",
+ "649 Town Mid Wifi 4G \n",
+ "637 Town Mid Mobile Data 4G \n",
+ "68 Town Mid Wifi 4G \n",
+ "276 Town Mid Mobile Data 3G \n",
+ "547 Town Mid Wifi 4G \n",
+ "... ... ... ... ... \n",
+ "1097 Town Rich Wifi 4G \n",
+ "854 Town Mid Mobile Data 4G \n",
+ "756 Town Mid Wifi 3G \n",
+ "133 Town Poor Mobile Data 4G \n",
+ "53 Rural Poor Mobile Data 4G \n",
+ "\n",
+ " Flexibility Level Access Difficulty \n",
+ "649 1 0 \n",
+ "637 1 1 \n",
+ "68 0 0 \n",
+ "276 0 1 \n",
+ "547 1 0 \n",
+ "... ... ... \n",
+ "1097 0 0 \n",
+ "854 0 1 \n",
+ "756 1 0 \n",
+ "133 0 1 \n",
+ "53 1 1 \n",
+ "\n",
+ "[964 rows x 12 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'y_train'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Access Difficulty | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 649 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 637 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 68 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 276 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 547 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 1097 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 854 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 756 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 133 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 53 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
964 rows × 1 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Access Difficulty\n",
+ "649 0\n",
+ "637 1\n",
+ "68 0\n",
+ "276 1\n",
+ "547 0\n",
+ "... ...\n",
+ "1097 0\n",
+ "854 1\n",
+ "756 0\n",
+ "133 1\n",
+ "53 1\n",
+ "\n",
+ "[964 rows x 1 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'X_test'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Education Level | \n",
+ " Institution Type | \n",
+ " Gender | \n",
+ " Age | \n",
+ " Device | \n",
+ " IT Student | \n",
+ " Location | \n",
+ " Financial Condition | \n",
+ " Internet Type | \n",
+ " Network Type | \n",
+ " Flexibility Level | \n",
+ " Access Difficulty | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 265 | \n",
+ " School | \n",
+ " Private | \n",
+ " Female | \n",
+ " 9 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Poor | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 358 | \n",
+ " School | \n",
+ " Private | \n",
+ " Female | \n",
+ " 10 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Mobile Data | \n",
+ " 3G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 316 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Tab | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 907 | \n",
+ " School | \n",
+ " Private | \n",
+ " Female | \n",
+ " 9 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Poor | \n",
+ " Mobile Data | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1042 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Mobile Data | \n",
+ " 3G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 421 | \n",
+ " School | \n",
+ " Private | \n",
+ " Female | \n",
+ " 10 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Mobile Data | \n",
+ " 3G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 936 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Tab | \n",
+ " No | \n",
+ " Town | \n",
+ " Rich | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 2 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 722 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Mobile | \n",
+ " Yes | \n",
+ " Rural | \n",
+ " Poor | \n",
+ " Mobile Data | \n",
+ " 3G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1075 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Computer | \n",
+ " Yes | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 577 | \n",
+ " University | \n",
+ " Private | \n",
+ " Male | \n",
+ " 23 | \n",
+ " Mobile | \n",
+ " Yes | \n",
+ " Town | \n",
+ " Mid | \n",
+ " Wifi | \n",
+ " 4G | \n",
+ " 0 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
241 rows × 12 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Education Level Institution Type Gender Age Device IT Student \\\n",
+ "265 School Private Female 9 Mobile No \n",
+ "358 School Private Female 10 Mobile No \n",
+ "316 University Private Male 23 Tab No \n",
+ "907 School Private Female 9 Mobile No \n",
+ "1042 University Private Male 23 Mobile No \n",
+ "... ... ... ... ... ... ... \n",
+ "421 School Private Female 10 Mobile No \n",
+ "936 University Private Male 23 Tab No \n",
+ "722 University Private Male 23 Mobile Yes \n",
+ "1075 University Private Male 23 Computer Yes \n",
+ "577 University Private Male 23 Mobile Yes \n",
+ "\n",
+ " Location Financial Condition Internet Type Network Type \\\n",
+ "265 Town Poor Wifi 4G \n",
+ "358 Town Mid Mobile Data 3G \n",
+ "316 Town Mid Wifi 4G \n",
+ "907 Town Poor Mobile Data 4G \n",
+ "1042 Town Mid Mobile Data 3G \n",
+ "... ... ... ... ... \n",
+ "421 Town Mid Mobile Data 3G \n",
+ "936 Town Rich Wifi 4G \n",
+ "722 Rural Poor Mobile Data 3G \n",
+ "1075 Town Mid Wifi 4G \n",
+ "577 Town Mid Wifi 4G \n",
+ "\n",
+ " Flexibility Level Access Difficulty \n",
+ "265 1 1 \n",
+ "358 1 1 \n",
+ "316 1 0 \n",
+ "907 1 1 \n",
+ "1042 1 1 \n",
+ "... ... ... \n",
+ "421 1 1 \n",
+ "936 2 0 \n",
+ "722 1 1 \n",
+ "1075 0 0 \n",
+ "577 0 0 \n",
+ "\n",
+ "[241 rows x 12 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'y_test'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Access Difficulty | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 265 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 358 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 316 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 907 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1042 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 421 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 936 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 722 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ " 1075 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 577 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
241 rows × 1 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Access Difficulty\n",
+ "265 1\n",
+ "358 1\n",
+ "316 0\n",
+ "907 1\n",
+ "1042 1\n",
+ "... ...\n",
+ "421 1\n",
+ "936 0\n",
+ "722 1\n",
+ "1075 0\n",
+ "577 0\n",
+ "\n",
+ "[241 rows x 1 columns]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def split_stratified_into_train_val_test(\n",
+ "    df_input,\n",
+ "    stratify_colname=\"y\",\n",
+ "    frac_train=0.6,\n",
+ "    frac_val=0.15,\n",
+ "    frac_test=0.25,\n",
+ "    random_state=None,\n",
+ ") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
+ "    \"\"\"Split a frame into stratified train / validation / test parts.\n",
+ "\n",
+ "    The returned feature frames are row subsets of ``df_input`` and therefore\n",
+ "    still contain ``stratify_colname``; callers have to exclude that column\n",
+ "    from the model inputs themselves.\n",
+ "\n",
+ "    Args:\n",
+ "        df_input: frame to split.\n",
+ "        stratify_colname: column whose class proportions are preserved.\n",
+ "        frac_train, frac_val, frac_test: shares of the rows; must sum to 1.\n",
+ "        random_state: seed forwarded to ``train_test_split``.\n",
+ "\n",
+ "    Returns:\n",
+ "        ``(df_train, df_val, df_test, y_train, y_val, y_test)``. When\n",
+ "        ``frac_val <= 0`` the two validation items are empty DataFrames.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: the fractions do not sum to 1, or the column is missing.\n",
+ "    \"\"\"\n",
+ "    # Compare with a tolerance: sums such as 0.7 + 0.2 + 0.1 are not exactly\n",
+ "    # 1.0 in binary floating point and used to be rejected.\n",
+ "    if abs(frac_train + frac_val + frac_test - 1.0) > 1e-9:\n",
+ "        raise ValueError(\n",
+ "            \"fractions %f, %f, %f do not add up to 1.0\"\n",
+ "            % (frac_train, frac_val, frac_test)\n",
+ "        )\n",
+ "    if stratify_colname not in df_input.columns:\n",
+ "        raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
+ "\n",
+ "    X = df_input\n",
+ "    # Double brackets keep y as a one-column DataFrame.\n",
+ "    y = df_input[[stratify_colname]]\n",
+ "\n",
+ "    # First cut: train versus everything else.\n",
+ "    df_train, df_temp, y_train, y_temp = train_test_split(\n",
+ "        X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
+ "    )\n",
+ "    if frac_val <= 0:\n",
+ "        # No validation part requested: the remainder is the test set.\n",
+ "        assert len(df_input) == len(df_train) + len(df_temp)\n",
+ "        return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
+ "\n",
+ "    # Second cut: divide the remainder between validation and test.\n",
+ "    relative_frac_test = frac_test / (frac_val + frac_test)\n",
+ "    df_val, df_test, y_val, y_test = train_test_split(\n",
+ "        df_temp,\n",
+ "        y_temp,\n",
+ "        stratify=y_temp,\n",
+ "        test_size=relative_frac_test,\n",
+ "        random_state=random_state,\n",
+ "    )\n",
+ "    assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
+ "    return df_train, df_val, df_test, y_train, y_val, y_test\n",
+ "\n",
+ "# 80/20 stratified split with no validation part (X_val / y_val come back empty).\n",
+ "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
+ "    df,\n",
+ "    stratify_colname=\"Access Difficulty\",\n",
+ "    frac_train=0.80,\n",
+ "    frac_val=0,\n",
+ "    frac_test=0.20,\n",
+ "    random_state=9,\n",
+ ")\n",
+ "\n",
+ "# display() renders its arguments one after another, each caption before its frame.\n",
+ "display(\n",
+ "    \"X_train\", X_train,\n",
+ "    \"y_train\", y_train,\n",
+ "    \"X_test\", X_test,\n",
+ "    \"y_test\", y_test,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 107,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Пропущенные значения по столбцам:\n",
+ "Education Level 0\n",
+ "Institution Type 0\n",
+ "Gender 0\n",
+ "Age 0\n",
+ "Device 0\n",
+ "IT Student 0\n",
+ "Location 0\n",
+ "Financial Condition 0\n",
+ "Internet Type 0\n",
+ "Network Type 0\n",
+ "Flexibility Level 0\n",
+ "Access Difficulty 0\n",
+ "dtype: int64\n",
+ "\n",
+ "Статистический обзор данных:\n",
+ " Age Flexibility Level Access Difficulty\n",
+ "count 1205.000000 1205.000000 1205.000000\n",
+ "mean 17.065560 0.684647 0.624896\n",
+ "std 5.830369 0.618221 0.484351\n",
+ "min 9.000000 0.000000 0.000000\n",
+ "25% 11.000000 0.000000 0.000000\n",
+ "50% 18.000000 1.000000 1.000000\n",
+ "75% 23.000000 1.000000 1.000000\n",
+ "max 27.000000 2.000000 1.000000\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Number of missing cells in every column.\n",
+ "null_values = df.isna().sum()\n",
+ "# Descriptive statistics of the numeric columns.\n",
+ "stat_summary = df.describe()\n",
+ "\n",
+ "print(\"Пропущенные значения по столбцам:\", null_values, sep=\"\\n\")\n",
+ "print(\"\\nСтатистический обзор данных:\", stat_summary, sep=\"\\n\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Формируем конвейер для классификации данных и проверяем его работу"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 108,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Access Difficulty | \n",
+ " Institution Type_Public | \n",
+ " Device_Mobile | \n",
+ " Device_Tab | \n",
+ " Location_Town | \n",
+ " Financial Condition_Poor | \n",
+ " Financial Condition_Rich | \n",
+ " Internet Type_Wifi | \n",
+ " Network Type_3G | \n",
+ " Network Type_4G | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 649 | \n",
+ " -1.289567 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 637 | \n",
+ " 0.775454 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 68 | \n",
+ " -1.289567 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 276 | \n",
+ " 0.775454 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 547 | \n",
+ " -1.289567 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 1097 | \n",
+ " -1.289567 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 854 | \n",
+ " 0.775454 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 756 | \n",
+ " -1.289567 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " 133 | \n",
+ " 0.775454 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 53 | \n",
+ " 0.775454 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
964 rows × 10 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n",
+ "649 -1.289567 1.0 1.0 0.0 \n",
+ "637 0.775454 0.0 1.0 0.0 \n",
+ "68 -1.289567 1.0 1.0 0.0 \n",
+ "276 0.775454 0.0 1.0 0.0 \n",
+ "547 -1.289567 1.0 1.0 0.0 \n",
+ "... ... ... ... ... \n",
+ "1097 -1.289567 0.0 1.0 0.0 \n",
+ "854 0.775454 1.0 1.0 0.0 \n",
+ "756 -1.289567 1.0 0.0 0.0 \n",
+ "133 0.775454 1.0 1.0 0.0 \n",
+ "53 0.775454 1.0 1.0 0.0 \n",
+ "\n",
+ " Location_Town Financial Condition_Poor Financial Condition_Rich \\\n",
+ "649 1.0 0.0 0.0 \n",
+ "637 1.0 0.0 0.0 \n",
+ "68 1.0 0.0 0.0 \n",
+ "276 1.0 0.0 0.0 \n",
+ "547 1.0 0.0 0.0 \n",
+ "... ... ... ... \n",
+ "1097 1.0 0.0 1.0 \n",
+ "854 1.0 0.0 0.0 \n",
+ "756 1.0 0.0 0.0 \n",
+ "133 1.0 1.0 0.0 \n",
+ "53 0.0 1.0 0.0 \n",
+ "\n",
+ " Internet Type_Wifi Network Type_3G Network Type_4G \n",
+ "649 1.0 0.0 1.0 \n",
+ "637 0.0 0.0 1.0 \n",
+ "68 1.0 0.0 1.0 \n",
+ "276 0.0 1.0 0.0 \n",
+ "547 1.0 0.0 1.0 \n",
+ "... ... ... ... \n",
+ "1097 1.0 0.0 1.0 \n",
+ "854 0.0 0.0 1.0 \n",
+ "756 1.0 1.0 0.0 \n",
+ "133 0.0 0.0 1.0 \n",
+ "53 0.0 0.0 1.0 \n",
+ "\n",
+ "[964 rows x 10 columns]"
+ ]
+ },
+ "execution_count": 108,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# The target must never reach the models. It used to be selected as a numeric\n",
+ "# feature (it was the first column of the preprocessed frame), so every\n",
+ "# estimator could simply read the answer. Listing it here removes it.\n",
+ "target_column = \"Access Difficulty\"\n",
+ "columns_to_drop = ['Age', 'Education Level', 'Gender', 'IT Student', 'Flexibility Level', target_column]\n",
+ "\n",
+ "# Columns that are not dropped, grouped by dtype. With the current drop list\n",
+ "# no numeric column is left, and ColumnTransformer skips empty selections.\n",
+ "num_columns = [\n",
+ "    column\n",
+ "    for column in df.columns\n",
+ "    if column not in columns_to_drop and df[column].dtype != \"object\"\n",
+ "]\n",
+ "cat_columns = [\n",
+ "    column\n",
+ "    for column in df.columns\n",
+ "    if column not in columns_to_drop and df[column].dtype == \"object\"\n",
+ "]\n",
+ "\n",
+ "# Numeric branch: median imputation followed by standardisation.\n",
+ "num_imputer = SimpleImputer(strategy=\"median\")\n",
+ "num_scaler = StandardScaler()\n",
+ "preprocessing_num = Pipeline(\n",
+ "    [\n",
+ "        (\"imputer\", num_imputer),\n",
+ "        (\"scaler\", num_scaler),\n",
+ "    ]\n",
+ ")\n",
+ "\n",
+ "# Categorical branch: fill gaps with \"unknown\", then one-hot encode without the\n",
+ "# first level of every feature.\n",
+ "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
+ "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
+ "preprocessing_cat = Pipeline(\n",
+ "    [\n",
+ "        (\"imputer\", cat_imputer),\n",
+ "        (\"encoder\", cat_encoder),\n",
+ "    ]\n",
+ ")\n",
+ "\n",
+ "# Columns handled by neither branch pass through unchanged so that the next\n",
+ "# step can remove them by name.\n",
+ "features_preprocessing = ColumnTransformer(\n",
+ "    verbose_feature_names_out=False,\n",
+ "    transformers=[\n",
+ "        (\"prepocessing_num\", preprocessing_num, num_columns),\n",
+ "        (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
+ "    ],\n",
+ "    remainder=\"passthrough\"\n",
+ ")\n",
+ "\n",
+ "# Removes the unused attributes and the target from the feature matrix.\n",
+ "drop_columns = ColumnTransformer(\n",
+ "    verbose_feature_names_out=False,\n",
+ "    transformers=[\n",
+ "        (\"drop_columns\", \"drop\", columns_to_drop),\n",
+ "    ],\n",
+ "    remainder=\"passthrough\",\n",
+ ")\n",
+ "\n",
+ "\n",
+ "pipeline_end = Pipeline(\n",
+ "    [\n",
+ "        (\"features_preprocessing\", features_preprocessing),\n",
+ "        (\"drop_columns\", drop_columns),\n",
+ "    ]\n",
+ ")\n",
+ "\n",
+ "# Smoke test of the preprocessing on the training part.\n",
+ "preprocessing_result = pipeline_end.fit_transform(X_train)\n",
+ "preprocessed_df = pd.DataFrame(\n",
+ "    preprocessing_result,\n",
+ "    columns=pipeline_end.get_feature_names_out(),\n",
+ ")\n",
+ "assert target_column not in preprocessed_df.columns, \"target leaked into the features\"\n",
+ "\n",
+ "# NOTE(review): Financial Condition, Device, Internet Type and Network Type stay\n",
+ "# as features, and the target was derived from exactly these columns, so the\n",
+ "# scores may remain very high even without the direct leak.\n",
+ "preprocessed_df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Формируем набор моделей"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 109,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Candidate classifiers keyed by a short name; the training loop later stores\n",
+ "# the fitted pipeline and the metrics next to each \"model\" entry.\n",
+ "class_models = {\n",
+ "    \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
+ "    # L2-penalised logistic regression with class re-weighting.\n",
+ "    \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
+ "    \"decision_tree\": {\n",
+ "        \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
+ "    },\n",
+ "    \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
+ "    \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
+ "    \"gradient_boosting\": {\n",
+ "        # Seeded like the other estimators so that a re-run reproduces the results.\n",
+ "        \"model\": ensemble.GradientBoostingClassifier(n_estimators=210, random_state=9)\n",
+ "    },\n",
+ "    \"random_forest\": {\n",
+ "        \"model\": ensemble.RandomForestClassifier(\n",
+ "            max_depth=11, class_weight=\"balanced\", random_state=9\n",
+ "        )\n",
+ "    },\n",
+ "    \"mlp\": {\n",
+ "        \"model\": neural_network.MLPClassifier(\n",
+ "            hidden_layer_sizes=(7,),\n",
+ "            max_iter=500,\n",
+ "            early_stopping=True,\n",
+ "            random_state=9,\n",
+ "        )\n",
+ "    },\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Обучаем модели и тестируем их"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 110,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Model: logistic\n",
+ "Model: ridge\n",
+ "Model: decision_tree\n",
+ "Model: knn\n",
+ "Model: naive_bayes\n",
+ "Model: gradient_boosting\n",
+ "Model: random_forest\n",
+ "Model: mlp\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
+ "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
+ " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Fit every candidate inside the shared preprocessing pipeline and collect\n",
+ "# train / test metrics next to the model in `class_models`.\n",
+ "for model_name in class_models.keys():\n",
+ "    print(f\"Model: {model_name}\")\n",
+ "    model = class_models[model_name][\"model\"]\n",
+ "\n",
+ "    model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
+ "    model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
+ "\n",
+ "    y_train_predict = model_pipeline.predict(X_train)\n",
+ "    # Probability of class 1 on the test part, thresholded at 0.5 for hard labels.\n",
+ "    y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
+ "    y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
+ "\n",
+ "    class_models[model_name].update(\n",
+ "        {\n",
+ "            \"pipeline\": model_pipeline,\n",
+ "            \"probs\": y_test_probs,\n",
+ "            \"preds\": y_test_predict,\n",
+ "            # zero_division=0 keeps the value sklearn already reported for a model\n",
+ "            # that predicts no positives, but without UndefinedMetricWarning.\n",
+ "            \"Precision_train\": metrics.precision_score(\n",
+ "                y_train, y_train_predict, zero_division=0\n",
+ "            ),\n",
+ "            \"Precision_test\": metrics.precision_score(\n",
+ "                y_test, y_test_predict, zero_division=0\n",
+ "            ),\n",
+ "            \"Recall_train\": metrics.recall_score(y_train, y_train_predict),\n",
+ "            \"Recall_test\": metrics.recall_score(y_test, y_test_predict),\n",
+ "            \"Accuracy_train\": metrics.accuracy_score(y_train, y_train_predict),\n",
+ "            \"Accuracy_test\": metrics.accuracy_score(y_test, y_test_predict),\n",
+ "            \"ROC_AUC_test\": metrics.roc_auc_score(y_test, y_test_probs),\n",
+ "            # average=None yields one F1 value per class: [class 0, class 1].\n",
+ "            \"F1_train\": metrics.f1_score(y_train, y_train_predict, average=None),\n",
+ "            \"F1_test\": metrics.f1_score(y_test, y_test_predict, average=None),\n",
+ "            \"MCC_test\": metrics.matthews_corrcoef(y_test, y_test_predict),\n",
+ "            \"Cohen_kappa_test\": metrics.cohen_kappa_score(y_test, y_test_predict),\n",
+ "            \"Confusion_matrix\": metrics.confusion_matrix(y_test, y_test_predict),\n",
+ "        }\n",
+ "    )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 111,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Two panels per row. Round the row count up so that an odd number of models\n",
+ "# still gets a panel each (truncating division left the last model without an\n",
+ "# axis and raised IndexError).\n",
+ "n_models = len(class_models)\n",
+ "n_rows = max(1, (n_models + 1) // 2)\n",
+ "_, ax = plt.subplots(n_rows, 2, figsize=(12, 10), sharex=False, sharey=False)\n",
+ "for index, key in enumerate(class_models.keys()):\n",
+ "    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
+ "    disp = ConfusionMatrixDisplay(\n",
+ "        confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n",
+ "    ).plot(ax=ax.flat[index])\n",
+ "    disp.ax_.set_title(key)\n",
+ "\n",
+ "# Hide the spare panel that remains when the number of models is odd.\n",
+ "for spare_ax in ax.flat[n_models:]:\n",
+ "    spare_ax.set_visible(False)\n",
+ "\n",
+ "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Точность, полнота, верность (аккуратность), F-мера"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 112,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Precision_train | \n",
+ " Precision_test | \n",
+ " Recall_train | \n",
+ " Recall_test | \n",
+ " Accuracy_train | \n",
+ " Accuracy_test | \n",
+ " F1_train | \n",
+ " F1_test | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " logistic | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " ridge | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " decision_tree | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " knn | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " naive_bayes | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " gradient_boosting | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " random_forest | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.000000 | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " mlp | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.375519 | \n",
+ " 0.373444 | \n",
+ " [0.5460030165912518, 0.0] | \n",
+ " [0.5438066465256798, 0.0] | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Precision_train Precision_test Recall_train Recall_test \\\n",
+ "logistic 1.0 1.0 1.0 1.0 \n",
+ "ridge 1.0 1.0 1.0 1.0 \n",
+ "decision_tree 1.0 1.0 1.0 1.0 \n",
+ "knn 1.0 1.0 1.0 1.0 \n",
+ "naive_bayes 1.0 1.0 1.0 1.0 \n",
+ "gradient_boosting 1.0 1.0 1.0 1.0 \n",
+ "random_forest 1.0 1.0 1.0 1.0 \n",
+ "mlp 0.0 0.0 0.0 0.0 \n",
+ "\n",
+ " Accuracy_train Accuracy_test F1_train \\\n",
+ "logistic 1.000000 1.000000 [1.0, 1.0] \n",
+ "ridge 1.000000 1.000000 [1.0, 1.0] \n",
+ "decision_tree 1.000000 1.000000 [1.0, 1.0] \n",
+ "knn 1.000000 1.000000 [1.0, 1.0] \n",
+ "naive_bayes 1.000000 1.000000 [1.0, 1.0] \n",
+ "gradient_boosting 1.000000 1.000000 [1.0, 1.0] \n",
+ "random_forest 1.000000 1.000000 [1.0, 1.0] \n",
+ "mlp 0.375519 0.373444 [0.5460030165912518, 0.0] \n",
+ "\n",
+ " F1_test \n",
+ "logistic [1.0, 1.0] \n",
+ "ridge [1.0, 1.0] \n",
+ "decision_tree [1.0, 1.0] \n",
+ "knn [1.0, 1.0] \n",
+ "naive_bayes [1.0, 1.0] \n",
+ "gradient_boosting [1.0, 1.0] \n",
+ "random_forest [1.0, 1.0] \n",
+ "mlp [0.5438066465256798, 0.0] "
+ ]
+ },
+ "execution_count": 112,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Train/test precision, recall, accuracy and F1 for every entry of class_models,\n",
+    "# one row per model name.\n",
+    "class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
+    "    :,\n",
+    "    [\n",
+    "        \"Precision_train\",\n",
+    "        \"Precision_test\",\n",
+    "        \"Recall_train\",\n",
+    "        \"Recall_test\",\n",
+    "        \"Accuracy_train\",\n",
+    "        \"Accuracy_test\",\n",
+    "        \"F1_train\",\n",
+    "        \"F1_test\",\n",
+    "    ],\n",
+    "]\n",
+    "# Displayed with the highest test accuracy first.\n",
+    "class_metrics.sort_values(\"Accuracy_test\", ascending=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Площадь под ROC-кривой (ROC AUC), каппа Коэна, коэффициент корреляции Мэтьюса"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 113,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Accuracy_test | \n",
+ " F1_test | \n",
+ " ROC_AUC_test | \n",
+ " Cohen_kappa_test | \n",
+ " MCC_test | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " logistic | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " ridge | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " decision_tree | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " knn | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " naive_bayes | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " gradient_boosting | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " random_forest | \n",
+ " 1.000000 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.000000 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " mlp | \n",
+ " 0.373444 | \n",
+ " [0.5438066465256798, 0.0] | \n",
+ " 0.068065 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Accuracy_test F1_test ROC_AUC_test \\\n",
+ "logistic 1.000000 [1.0, 1.0] 1.000000 \n",
+ "ridge 1.000000 [1.0, 1.0] 1.000000 \n",
+ "decision_tree 1.000000 [1.0, 1.0] 1.000000 \n",
+ "knn 1.000000 [1.0, 1.0] 1.000000 \n",
+ "naive_bayes 1.000000 [1.0, 1.0] 1.000000 \n",
+ "gradient_boosting 1.000000 [1.0, 1.0] 1.000000 \n",
+ "random_forest 1.000000 [1.0, 1.0] 1.000000 \n",
+ "mlp 0.373444 [0.5438066465256798, 0.0] 0.068065 \n",
+ "\n",
+ " Cohen_kappa_test MCC_test \n",
+ "logistic 1.0 1.0 \n",
+ "ridge 1.0 1.0 \n",
+ "decision_tree 1.0 1.0 \n",
+ "knn 1.0 1.0 \n",
+ "naive_bayes 1.0 1.0 \n",
+ "gradient_boosting 1.0 1.0 \n",
+ "random_forest 1.0 1.0 \n",
+ "mlp 0.0 0.0 "
+ ]
+ },
+ "execution_count": 113,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Aggregate test-set quality per model: accuracy, F1, ROC AUC, Cohen's kappa and MCC.\n",
+    "class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
+    "    :,\n",
+    "    [\n",
+    "        \"Accuracy_test\",\n",
+    "        \"F1_test\",\n",
+    "        \"ROC_AUC_test\",\n",
+    "        \"Cohen_kappa_test\",\n",
+    "        \"MCC_test\",\n",
+    "    ],\n",
+    "]\n",
+    "# Displayed with the highest test ROC AUC first.\n",
+    "class_metrics.sort_values(\"ROC_AUC_test\", ascending=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Лучшая модель"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 114,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'logistic'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+    "# Rank the models by test-set MCC and take the label of the top row.\n",
+    "# NOTE(review): several models tie at MCC = 1.0 in the table above, so the winner\n",
+    "# depends on how sort_values orders equal values - confirm this is acceptable.\n",
+    "ranked_by_mcc = class_metrics.sort_values(by=\"MCC_test\", ascending=False)\n",
+    "best_model = str(ranked_by_mcc.index[0])\n",
+    "\n",
+    "display(best_model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Находим ошибки"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 115,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Error items count: 0'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Education Level | \n",
+ " Predicted | \n",
+ " Institution Type | \n",
+ " Gender | \n",
+ " Age | \n",
+ " Device | \n",
+ " IT Student | \n",
+ " Location | \n",
+ " Financial Condition | \n",
+ " Internet Type | \n",
+ " Network Type | \n",
+ " Flexibility Level | \n",
+ " Access Difficulty | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ "Empty DataFrame\n",
+ "Columns: [Education Level, Predicted, Institution Type, Gender, Age, Device, IT Student, Location, Financial Condition, Internet Type, Network Type, Flexibility Level, Access Difficulty]\n",
+ "Index: []"
+ ]
+ },
+ "execution_count": 115,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Transformed test features, reused by the prediction example in the next code cell.\n",
+    "preprocessing_result = pipeline_end.transform(X_test)\n",
+    "preprocessed_df = pd.DataFrame(\n",
+    "    data=preprocessing_result,\n",
+    "    columns=pipeline_end.get_feature_names_out(),\n",
+    ")\n",
+    "\n",
+    "y_pred = class_models[best_model][\"preds\"]\n",
+    "\n",
+    "# Test rows whose true label differs from the best model's prediction.\n",
+    "is_error = y_test[\"Access Difficulty\"] != y_pred\n",
+    "error_index = y_test.loc[is_error].index.tolist()\n",
+    "display(f\"Error items count: {len(error_index)}\")\n",
+    "\n",
+    "# Show the misclassified rows with the predicted label as the second column.\n",
+    "# NOTE(review): the resulting table still lists \"Access Difficulty\" among the X_test\n",
+    "# columns, i.e. the target looks like it is also an input feature - verify the split.\n",
+    "error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
+    "error_df = X_test.loc[error_index].copy()\n",
+    "error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
+    "error_df.sort_index()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Пример использования модели (конвейера) для предсказания"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 116,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Education Level | \n",
+ " Institution Type | \n",
+ " Gender | \n",
+ " Age | \n",
+ " Device | \n",
+ " IT Student | \n",
+ " Location | \n",
+ " Financial Condition | \n",
+ " Internet Type | \n",
+ " Network Type | \n",
+ " Flexibility Level | \n",
+ " Access Difficulty | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 450 | \n",
+ " School | \n",
+ " Private | \n",
+ " Female | \n",
+ " 11 | \n",
+ " Mobile | \n",
+ " No | \n",
+ " Town | \n",
+ " Poor | \n",
+ " Mobile Data | \n",
+ " 4G | \n",
+ " 1 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Education Level Institution Type Gender Age Device IT Student Location \\\n",
+ "450 School Private Female 11 Mobile No Town \n",
+ "\n",
+ " Financial Condition Internet Type Network Type Flexibility Level \\\n",
+ "450 Poor Mobile Data 4G 1 \n",
+ "\n",
+ " Access Difficulty \n",
+ "450 1 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Access Difficulty | \n",
+ " Institution Type_Public | \n",
+ " Device_Mobile | \n",
+ " Device_Tab | \n",
+ " Location_Town | \n",
+ " Financial Condition_Poor | \n",
+ " Financial Condition_Rich | \n",
+ " Internet Type_Wifi | \n",
+ " Network Type_3G | \n",
+ " Network Type_4G | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 450 | \n",
+ " 0.775454 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n",
+ "450 0.775454 0.0 1.0 0.0 \n",
+ "\n",
+ " Location_Town Financial Condition_Poor Financial Condition_Rich \\\n",
+ "450 1.0 1.0 0.0 \n",
+ "\n",
+ " Internet Type_Wifi Network Type_3G Network Type_4G \n",
+ "450 0.0 0.0 1.0 "
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'predicted: 1 (proba: [0.00310819 0.99689181])'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "'real: 1'"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+    "# Run the best model's full pipeline on a single test row and compare with the truth.\n",
+    "model = class_models[best_model][\"pipeline\"]\n",
+    "\n",
+    "example_id = 450\n",
+    "# Select with a list of labels so the result is a one-row frame that keeps each\n",
+    "# column's dtype (a Series round-trip plus .T would turn every column into object).\n",
+    "test = X_test.loc[[example_id]]\n",
+    "test_preprocessed = preprocessed_df.loc[[example_id]]\n",
+    "display(test)\n",
+    "display(test_preprocessed)\n",
+    "# Class probabilities and the hard prediction for that single row.\n",
+    "result_proba = model.predict_proba(test)[0]\n",
+    "result = model.predict(test)[0]\n",
+    "real = int(y_test.loc[example_id].values[0])\n",
+    "display(f\"predicted: {result} (proba: {result_proba})\")\n",
+    "display(f\"real: {real}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Подбираем гиперпараметры методом поиска по сетке."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 121,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'model__criterion': 'gini',\n",
+ " 'model__max_depth': 2,\n",
+ " 'model__max_features': 'sqrt',\n",
+ " 'model__n_estimators': 10}"
+ ]
+ },
+ "execution_count": 121,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "optimized_model_type = \"random_forest\"\n",
+    "random_state = 9\n",
+    "\n",
+    "# Reuse the already fitted random-forest pipeline as the estimator to tune.\n",
+    "random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
+    "\n",
+    "# The \"model__\" prefix routes each parameter to the pipeline step named \"model\".\n",
+    "param_grid = dict(\n",
+    "    model__n_estimators=[10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
+    "    model__max_features=[\"sqrt\", \"log2\", 2],\n",
+    "    model__max_depth=[2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
+    "    model__criterion=[\"gini\", \"entropy\", \"log_loss\"],\n",
+    ")\n",
+    "\n",
+    "# Exhaustive search over the grid on the training data, using all CPU cores.\n",
+    "# (The misspelt name gs_optomizer is kept because it is a notebook-level variable.)\n",
+    "gs_optomizer = GridSearchCV(random_forest_model, param_grid, n_jobs=-1)\n",
+    "gs_optomizer.fit(X_train, y_train.values.ravel())\n",
+    "gs_optomizer.best_params_"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Обучение модели с новыми гиперпараметрами"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 122,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Refit the random forest with the hyperparameters selected by the grid search.\n",
+    "# They are read from gs_optomizer.best_params_ (dropping the \"model__\" step prefix)\n",
+    "# instead of being copied by hand, so they cannot go stale when the search is re-run.\n",
+    "best_model_params = {\n",
+    "    name.split(\"__\", 1)[1]: value for name, value in gs_optomizer.best_params_.items()\n",
+    "}\n",
+    "optimized_model = ensemble.RandomForestClassifier(\n",
+    "    random_state=random_state, **best_model_params\n",
+    ")\n",
+    "\n",
+    "result = {}\n",
+    "\n",
+    "result[\"pipeline\"] = Pipeline(\n",
+    "    [(\"pipeline\", pipeline_end), (\"model\", optimized_model)]\n",
+    ").fit(X_train, y_train.values.ravel())\n",
+    "result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
+    "# Second predict_proba column on the test set, thresholded at 0.5 into 1/0 labels.\n",
+    "result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
+    "result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
+    "\n",
+    "result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
+    "result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
+    "result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
+    "result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
+    "result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
+    "result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
+    "result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
+    "# NOTE(review): F1 is a single scalar here, while the baseline row in the comparison\n",
+    "# table shows a per-class pair ([1.0, 1.0]) - presumably it was computed with\n",
+    "# average=None; confirm and align the two before comparing F1 values.\n",
+    "result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
+    "result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
+    "result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
+    "result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
+    "result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Формирование данных для сравнения старой и новой версий модели и оценка качества обеих версий"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 124,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Precision_train | \n",
+ " Precision_test | \n",
+ " Recall_train | \n",
+ " Recall_test | \n",
+ " Accuracy_train | \n",
+ " Accuracy_test | \n",
+ " F1_train | \n",
+ " F1_test | \n",
+ "
\n",
+ " \n",
+ " Name | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Old | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " [1.0, 1.0] | \n",
+ " [1.0, 1.0] | \n",
+ "
\n",
+ " \n",
+ " New | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n",
+ "Name \n",
+ "Old 1.0 1.0 1.0 1.0 1.0 \n",
+ "New 1.0 1.0 1.0 1.0 1.0 \n",
+ "\n",
+ " Accuracy_test F1_train F1_test \n",
+ "Name \n",
+ "Old 1.0 [1.0, 1.0] [1.0, 1.0] \n",
+ "New 1.0 1.0 1.0 "
+ ]
+ },
+ "execution_count": 124,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Side-by-side metrics of the original (\"Old\") and the tuned (\"New\") random forest.\n",
+    "# Both rows are built in a single constructor call instead of growing the frame row\n",
+    "# by row; the columns are limited to the keys present in result.\n",
+    "optimized_metrics = pd.DataFrame(\n",
+    "    [class_models[optimized_model_type], result],\n",
+    "    columns=list(result.keys()),\n",
+    "    index=pd.Index([\"Old\", \"New\"], name=\"Name\"),\n",
+    ")\n",
+    "\n",
+    "# Train/test precision, recall, accuracy and F1 of both versions.\n",
+    "optimized_metrics[\n",
+    "    [\n",
+    "        \"Precision_train\",\n",
+    "        \"Precision_test\",\n",
+    "        \"Recall_train\",\n",
+    "        \"Recall_test\",\n",
+    "        \"Accuracy_train\",\n",
+    "        \"Accuracy_test\",\n",
+    "        \"F1_train\",\n",
+    "        \"F1_test\",\n",
+    "    ]\n",
+    "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 125,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Accuracy_test | \n",
+ " F1_test | \n",
+ " ROC_AUC_test | \n",
+ " Cohen_kappa_test | \n",
+ " MCC_test | \n",
+ "
\n",
+ " \n",
+ " Name | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Old | \n",
+ " 1.0 | \n",
+ " [1.0, 1.0] | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " New | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n",
+ "Name \n",
+ "Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n",
+ "New 1.0 1.0 1.0 1.0 1.0"
+ ]
+ },
+ "execution_count": 125,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+    "# Aggregate test-set quality of the old and the tuned model.\n",
+    "optimized_metrics.loc[\n",
+    "    :,\n",
+    "    [\"Accuracy_test\", \"F1_test\", \"ROC_AUC_test\", \"Cohen_kappa_test\", \"MCC_test\"],\n",
+    "]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 127,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "