AIM-PIbd-31-Makarov-DV/lab_4/lab4.ipynb

2392 lines
207 KiB
Plaintext
Raw Normal View History

2024-11-15 00:44:23 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Лабораторная 4\n",
"\n",
"Датасет: Информация об онлайн обучении учеников\n",
"\n",
"Бизнес-цель 1: Улучшение доступа к онлайн-образованию для учеников с низким уровнем финансового обеспечения."
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Education Level', 'Institution Type', 'Gender', 'Age', 'Device',\n",
" 'IT Student', 'Location', 'Financial Condition', 'Internet Type',\n",
" 'Network Type', 'Flexibility Level'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from typing import Tuple\n",
"from pandas import DataFrame\n",
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree, metrics, set_config\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\"..\\\\static\\\\csv\\\\students_adaptability_level_online_education.csv\")\n",
"print(df.columns)\n",
"\n",
"map_flexibility_to_int = {'Low': 0, 'Moderate': 1, 'High': 2}\n",
"\n",
"df['Flexibility Level'] = df['Flexibility Level'].map(map_flexibility_to_int).astype('int32')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Предварительно создадим колонку для работы с ней (ключевой фактор)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
"fincond_mapping = {'Poor': 2, 'Mid': 1, 'Rich': 0}\n",
"internet_type_mapping = {'Mobile Data': 1, 'Wifi': 0}\n",
"device_mapping = {'Mobile': 1, 'Computer': 0}\n",
"network_type = {'2G': 2, '3G': 1, '4G': 0}\n",
"\n",
"df['Financial Score'] = df['Financial Condition'].map(fincond_mapping)\n",
"df['Internet Score'] = df['Internet Type'].map(internet_type_mapping)\n",
"df['Device Score'] = df['Device'].map(device_mapping)\n",
"df['Network Score'] = df['Network Type'].map(network_type)\n",
"\n",
"df['Access Difficulty Score'] = df['Financial Score'] + df['Internet Score'] + df['Device Score'] + df['Network Score']\n",
"\n",
"df['Access Difficulty'] = (df['Access Difficulty Score'] >= 3).astype(int)\n",
"df.drop(columns=['Financial Score', 'Device Score', 'Internet Score', 'Network Score', 'Access Difficulty Score'], inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем выборки"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>649</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>276</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1097</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>854</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>University</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Computer</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>College</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53</th>\n",
" <td>University</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>27</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Rural</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student \\\n",
"649 School Public Male 18 Mobile No \n",
"637 School Private Female 9 Mobile No \n",
"68 School Public Female 11 Mobile No \n",
"276 University Private Female 18 Mobile Yes \n",
"547 School Public Male 11 Mobile No \n",
"... ... ... ... ... ... ... \n",
"1097 University Private Male 23 Mobile Yes \n",
"854 School Public Female 18 Mobile No \n",
"756 University Public Male 18 Computer No \n",
"133 College Public Male 18 Mobile No \n",
"53 University Public Male 27 Mobile Yes \n",
"\n",
" Location Financial Condition Internet Type Network Type \\\n",
"649 Town Mid Wifi 4G \n",
"637 Town Mid Mobile Data 4G \n",
"68 Town Mid Wifi 4G \n",
"276 Town Mid Mobile Data 3G \n",
"547 Town Mid Wifi 4G \n",
"... ... ... ... ... \n",
"1097 Town Rich Wifi 4G \n",
"854 Town Mid Mobile Data 4G \n",
"756 Town Mid Wifi 3G \n",
"133 Town Poor Mobile Data 4G \n",
"53 Rural Poor Mobile Data 4G \n",
"\n",
" Flexibility Level Access Difficulty \n",
"649 1 0 \n",
"637 1 1 \n",
"68 0 0 \n",
"276 0 1 \n",
"547 1 0 \n",
"... ... ... \n",
"1097 0 0 \n",
"854 0 1 \n",
"756 1 0 \n",
"133 0 1 \n",
"53 1 1 \n",
"\n",
"[964 rows x 12 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>649</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>276</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1097</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>854</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Access Difficulty\n",
"649 0\n",
"637 1\n",
"68 0\n",
"276 1\n",
"547 0\n",
"... ...\n",
"1097 0\n",
"854 1\n",
"756 0\n",
"133 1\n",
"53 1\n",
"\n",
"[964 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>265</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>358</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>10</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>316</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Tab</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>907</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1042</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>10</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>936</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Tab</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>722</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Rural</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1075</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Computer</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>577</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>241 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student \\\n",
"265 School Private Female 9 Mobile No \n",
"358 School Private Female 10 Mobile No \n",
"316 University Private Male 23 Tab No \n",
"907 School Private Female 9 Mobile No \n",
"1042 University Private Male 23 Mobile No \n",
"... ... ... ... ... ... ... \n",
"421 School Private Female 10 Mobile No \n",
"936 University Private Male 23 Tab No \n",
"722 University Private Male 23 Mobile Yes \n",
"1075 University Private Male 23 Computer Yes \n",
"577 University Private Male 23 Mobile Yes \n",
"\n",
" Location Financial Condition Internet Type Network Type \\\n",
"265 Town Poor Wifi 4G \n",
"358 Town Mid Mobile Data 3G \n",
"316 Town Mid Wifi 4G \n",
"907 Town Poor Mobile Data 4G \n",
"1042 Town Mid Mobile Data 3G \n",
"... ... ... ... ... \n",
"421 Town Mid Mobile Data 3G \n",
"936 Town Rich Wifi 4G \n",
"722 Rural Poor Mobile Data 3G \n",
"1075 Town Mid Wifi 4G \n",
"577 Town Mid Wifi 4G \n",
"\n",
" Flexibility Level Access Difficulty \n",
"265 1 1 \n",
"358 1 1 \n",
"316 1 0 \n",
"907 1 1 \n",
"1042 1 1 \n",
"... ... ... \n",
"421 1 1 \n",
"936 2 0 \n",
"722 1 1 \n",
"1075 0 0 \n",
"577 0 0 \n",
"\n",
"[241 rows x 12 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>265</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>358</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>316</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>907</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1042</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>936</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>722</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1075</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>577</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>241 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Access Difficulty\n",
"265 1\n",
"358 1\n",
"316 0\n",
"907 1\n",
"1042 1\n",
"... ...\n",
"421 1\n",
"936 0\n",
"722 1\n",
"1075 0\n",
"577 0\n",
"\n",
"[241 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ]\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" \n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Access Difficulty\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пропущенные значения по столбцам:\n",
"Education Level 0\n",
"Institution Type 0\n",
"Gender 0\n",
"Age 0\n",
"Device 0\n",
"IT Student 0\n",
"Location 0\n",
"Financial Condition 0\n",
"Internet Type 0\n",
"Network Type 0\n",
"Flexibility Level 0\n",
"Access Difficulty 0\n",
"dtype: int64\n",
"\n",
"Статистический обзор данных:\n",
" Age Flexibility Level Access Difficulty\n",
"count 1205.000000 1205.000000 1205.000000\n",
"mean 17.065560 0.684647 0.624896\n",
"std 5.830369 0.618221 0.484351\n",
"min 9.000000 0.000000 0.000000\n",
"25% 11.000000 0.000000 0.000000\n",
"50% 18.000000 1.000000 1.000000\n",
"75% 23.000000 1.000000 1.000000\n",
"max 27.000000 2.000000 1.000000\n"
]
}
],
"source": [
"null_values = df.isnull().sum()\n",
"print(\"Пропущенные значения по столбцам:\")\n",
"print(null_values)\n",
"\n",
"stat_summary = df.describe()\n",
"print(\"\\nСтатистический обзор данных:\")\n",
"print(stat_summary)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем конвеер для классификации данных и проверка конвеера"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" <th>Institution Type_Public</th>\n",
" <th>Device_Mobile</th>\n",
" <th>Device_Tab</th>\n",
" <th>Location_Town</th>\n",
" <th>Financial Condition_Poor</th>\n",
" <th>Financial Condition_Rich</th>\n",
" <th>Internet Type_Wifi</th>\n",
" <th>Network Type_3G</th>\n",
" <th>Network Type_4G</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>649</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>0.775454</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>276</th>\n",
" <td>0.775454</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1097</th>\n",
" <td>-1.289567</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>854</th>\n",
" <td>0.775454</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>0.775454</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53</th>\n",
" <td>0.775454</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n",
"649 -1.289567 1.0 1.0 0.0 \n",
"637 0.775454 0.0 1.0 0.0 \n",
"68 -1.289567 1.0 1.0 0.0 \n",
"276 0.775454 0.0 1.0 0.0 \n",
"547 -1.289567 1.0 1.0 0.0 \n",
"... ... ... ... ... \n",
"1097 -1.289567 0.0 1.0 0.0 \n",
"854 0.775454 1.0 1.0 0.0 \n",
"756 -1.289567 1.0 0.0 0.0 \n",
"133 0.775454 1.0 1.0 0.0 \n",
"53 0.775454 1.0 1.0 0.0 \n",
"\n",
" Location_Town Financial Condition_Poor Financial Condition_Rich \\\n",
"649 1.0 0.0 0.0 \n",
"637 1.0 0.0 0.0 \n",
"68 1.0 0.0 0.0 \n",
"276 1.0 0.0 0.0 \n",
"547 1.0 0.0 0.0 \n",
"... ... ... ... \n",
"1097 1.0 0.0 1.0 \n",
"854 1.0 0.0 0.0 \n",
"756 1.0 0.0 0.0 \n",
"133 1.0 1.0 0.0 \n",
"53 0.0 1.0 0.0 \n",
"\n",
" Internet Type_Wifi Network Type_3G Network Type_4G \n",
"649 1.0 0.0 1.0 \n",
"637 0.0 0.0 1.0 \n",
"68 1.0 0.0 1.0 \n",
"276 0.0 1.0 0.0 \n",
"547 1.0 0.0 1.0 \n",
"... ... ... ... \n",
"1097 1.0 0.0 1.0 \n",
"854 0.0 0.0 1.0 \n",
"756 1.0 1.0 0.0 \n",
"133 0.0 0.0 1.0 \n",
"53 0.0 0.0 1.0 \n",
"\n",
"[964 rows x 10 columns]"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"columns_to_drop = ['Age', 'Education Level', 'Gender', 'IT Student', 'Flexibility Level']\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")\n",
"\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем набор моделей"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модели и тестируем их"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict, average=None)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict, average=None)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2QAAAQ9CAYAAAA2zo55AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxU5f4H8M9hG5AdlS0IcUPNBdM0XFEp1HtNs3tRs6so6M9Kc4lKbyXgVduvS7m0KVpYeVvMvKZXS9yuaS6YGm6oiQpuKIiKwszz+4PL1MgMMjLDzHnm876v87rynGfOeWbQ8+l75jnnKEIIASIiIiIiIqpzTrYeABERERERkaNiQUZERERERGQjLMiIiIiIiIhshAUZERERERGRjbAgIyIiIiIishEWZERERERERDbCgoyIiIiIiMhGWJARERERERHZCAsyIiIiIiIiG2FBRtLIyMiAoig4deqUVbZ/6tQpKIqCjIwMi2wvKysLiqIgKyvLItsjIiKSSVpaGhRFqVFfRVGQlpZm3QERWQkLMiIrW7hwocWKOCIiIiKSi4utB0CkFhEREbh58yZcXV3Net3ChQvRoEEDJCYmGrT36NEDN2/ehJubmwVHSUREJIdXXnkFU6dOtfUwiKyOBRlRDSmKAnd3d4ttz8nJyaLbIyIiksX169fh6ekJFxf+pyrJj1MWSWoLFy7EAw88AI1Gg9DQUDz77LO4evVqlX4LFixA48aN4eHhgU6dOmHr1q2IjY1FbGysvo+xa8gKCgowatQohIWFQaPRICQkBAMHDtRfx9aoUSMcOnQImzdvhqIoUBRFv01T15Dt3LkT/fv3h7+/Pzw9PdG2bVvMmzfPsh8MERGRnai8VuzXX3/Fk08+CX9/f3Tr1s3oNWS3bt3C5MmT0bBhQ3h7e+Oxxx7DmTNnjG43KysLHTt2hLu7O5o0aYL333/f5HVpn376KTp06AAPDw8EBARg6NChyMvLs8r7JboTTzuQtNLS0pCeno64uDg8/fTTOHLkCBYtWoSff/4Z27dv1089XLRoEcaPH4/u3btj8uTJOHXqFAYNGgR/f3+EhYVVu48nnngChw4dwoQJE9CoUSNcuHABGzZswOnTp9GoUSPMnTsXEyZMgJeXF15++WUAQFBQkMntbdiwAX/+858REhKCiRMnIjg4GDk5OVizZg0mTpxouQ+HiIjIzvz1r39Fs2bNMHv2bAghcOHChSp9kpOT8emnn+LJJ59Ely5d8OOPP+JPf/pTlX779u1D3759ERISgvT0dGi1WsyYMQMNGzas0nfWrFl49dVXkZCQgOTkZFy8eBHvvvsuevTogX379sHPz88ab5fod4JIEkuXLhUAxMmTJ8WFCxeEm5ubePTRR4VWq9X3ee+99wQAsWTJEiGEELdu3RL169cXDz30kCgrK9P3y8jIEABEz5499W0nT54UAMTSpUuFEEJcuXJFABBvvfVWteN64IEHDLZTadOmTQKA2LRpkxBCiPLychEZGSkiIiLElStXDPrqdLqafxBEREQqkpqaKgCIYcOGGW2vlJ2dLQCIZ555xqDfk08+KQCI1NRUfduAAQNEvXr1xNmzZ/Vtx44dEy4uLgbbPHXqlHB2dhazZs0y2OaBAweEi4tLlXYia+CURZLSxo0bcfv2bUyaNAlOTr//NR8zZgx8fHzw73//GwCwe/duXL58GWPGjDGYpz58+HD4+/tXuw8PDw+4ubkhKysLV65cqfWY9+3bh5MnT2LSpElVzsbV9La/REREajVu3Lhq169duxYA8Nxzzxm0T5o0yeBnrVaLjRs3YtCgQQgNDdW3N23aFP369TPo+/XXX0On0yEhIQGXLl3SL8HBwWjWrBk2bdpUi3dEVDOcskhS+u233wAAUVFRBu1ubm5o3Lixfn3l/zdt2tSgn4uLCxo1alTtPjQaDd544w08//zzCAoKwsMPP4w///nPGDFiBIKDg80ec25uLgCgdevWZr+WiIhI7SIjI6td/9tvv8HJyQlNmjQxaL8z6y9cuICbN29WyXagat4fO3YMQgg0a9bM6D7NvbMy0b1gQUZUC5MmTcKAAQOwatUqrF+/Hq+++ipee+01/Pjjj2jfvr2th0dERKQaHh4edb5PnU4HRVHw/fffw9nZucp6Ly+vOh8TOR5OWSQpRUREAACOHDli0H779m2cPHlSv77y/48fP27Qr7y8XH+nxLtp0qQJnn/+efznP//BwYMHcfv2bbzzzjv69TWdblh5xu/gwYM16k9ERORIIiIioNPp9DNKKt2Z9YGBgXB3d6+S7UDVvG/SpAmEEIiMjERcXFyV5eGHH7b8GyG6AwsyklJcXBzc3Nwwf/58CCH07R9//DGKior0d2Tq2LEj6tevjw8//BDl5eX6fpmZmXe9LuzGjRsoLS01aGvSpAm8vb1x69YtfZunp6fRW+3f6cEHH0RkZCTmzp1bpf8f3wMREZEjqrz+a/78+Qbtc+fONfjZ2dkZcXFxWLVqFc6dO6dvP378OL7//nuDvoMHD4azszPS09OrZK0QApcvX7bgOyAyjlMWSUoNGzbEtGnTkJ6ejr59++Kxxx7DkSNHsHDhQjz00EN46qmnAFRcU5aWloYJEyagd+/eSEhIwKlTp5CRkYEmTZpU++3W0aNH0adPHyQkJKBVq1ZwcXHBN998g/Pnz2Po0KH6fh06dMCiRYswc+ZMNG3aFIGBgejdu3eV7Tk5OWHRokUYMGAAoqOjMWrUKISEhODw4cM4dOgQ1q9fb/kPioiISCWio6MxbNgwLFy4EEVFRejSpQt++OEHo9+EpaWl4T//+Q+6du2Kp59+GlqtFu+99x5at26N7Oxsfb8mTZpg5syZmDZtmv6xN97e3jh58iS++eYbjB07FikpKXX4LskRsSAjaaWlpaFhw4Z47733MHnyZAQEBGDs2LGYPXu2wUW648ePhxAC77zzDlJSUtCuXTusXr0azz33HNzd3U1uPzw8HMOGDcMPP/yATz75BC4uLmjRogVWrlyJJ554Qt9v+vTp+O233/Dmm2/i2rVr6Nmzp9GCDADi4+OxadMmpKen45133oFOp0OTJk0wZswYy30wREREKrVkyRI0bNgQmZmZWLVqFXr37o1///vfCA8PN+jXoUMHfP/990hJScGrr76K8PBwzJgxAzk5OTh8+LBB36lTp6J58+aYM2cO0tPTAVRk/KOPPorHHnuszt4bOS5FcC4UURU6nQ4NGzbE4MGD8eGHH9p6OERERGQBgwYNwqFDh3Ds2DFbD4VIj9eQkcMrLS2tMm98+fLlKCwsRGxsrG0GRURERLVy8+ZNg5+PHTuGtWvXMtvJ7vAbMnJ4WVlZmDx5Mv7617+ifv362Lt3Lz7++GO0bNkSe/bsgZubm62HSERERGYKCQlBYmKi/vmjixYtwq1bt7Bv3z6Tzx0jsgVeQ0YOr1GjRggPD8f8+fNRWFiIgIAAjBgxAq+//jqLMSIiIpXq27cvPvvsMxQUFECj0SAmJgazZ89mMUZ2h9+QERERERER2QivISMiqgNbtmzBgAEDEBoaCkVRsGrVKoP1iYmJUBTFYOnbt69Bn8LCQgwfPhw+Pj7w8/NDUlISSkpK6vBdEBERycNespkFGRFRHbh+/TratWuHBQsWmOzTt29f5Ofn65fPPvvMYP3w4cNx6NAhbNiwAWvWrMGWLVswduxYaw+diIhISvaSzbyGTKV0Oh3OnTsHb2/vah9eTCQjIQSuXbuG0NBQODlZ/rxSaWkpbt++XW0fNze3ap9Td6d+/fqhX79+1fbRaDQIDg42ui4nJwfr1q3Dzz//jI4dOwIA3n33XfTv3x9vv/02QkNDazwWIrIOZjM5Omvms8zZzIJMpc6dO1flIYhEjiYvLw9hYWEW3WZpaSkiI7xQcEFbbb/g4GDs37/f4MCv0Wig0Wjued9ZWVkIDAyEv78/evfujZkzZ6J+/foAgB07dsDPz09/wAeAuLg4ODk5YefOnXj88cfveb9EZBnMZqIKls5n2bOZBZlKeXt7AwDC33sBTh73/peM7l1k0gFbD8FhlaMM27B
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.375519</td>\n",
" <td>0.373444</td>\n",
" <td>[0.5460030165912518, 0.0]</td>\n",
" <td>[0.5438066465256798, 0.0]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test \\\n",
"logistic 1.0 1.0 1.0 1.0 \n",
"ridge 1.0 1.0 1.0 1.0 \n",
"decision_tree 1.0 1.0 1.0 1.0 \n",
"knn 1.0 1.0 1.0 1.0 \n",
"naive_bayes 1.0 1.0 1.0 1.0 \n",
"gradient_boosting 1.0 1.0 1.0 1.0 \n",
"random_forest 1.0 1.0 1.0 1.0 \n",
"mlp 0.0 0.0 0.0 0.0 \n",
"\n",
" Accuracy_train Accuracy_test F1_train \\\n",
"logistic 1.000000 1.000000 [1.0, 1.0] \n",
"ridge 1.000000 1.000000 [1.0, 1.0] \n",
"decision_tree 1.000000 1.000000 [1.0, 1.0] \n",
"knn 1.000000 1.000000 [1.0, 1.0] \n",
"naive_bayes 1.000000 1.000000 [1.0, 1.0] \n",
"gradient_boosting 1.000000 1.000000 [1.0, 1.0] \n",
"random_forest 1.000000 1.000000 [1.0, 1.0] \n",
"mlp 0.375519 0.373444 [0.5460030165912518, 0.0] \n",
"\n",
" F1_test \n",
"logistic [1.0, 1.0] \n",
"ridge [1.0, 1.0] \n",
"decision_tree [1.0, 1.0] \n",
"knn [1.0, 1.0] \n",
"naive_bayes [1.0, 1.0] \n",
"gradient_boosting [1.0, 1.0] \n",
"random_forest [1.0, 1.0] \n",
"mlp [0.5438066465256798, 0.0] "
]
},
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>0.373444</td>\n",
" <td>[0.5438066465256798, 0.0]</td>\n",
" <td>0.068065</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test \\\n",
"logistic 1.000000 [1.0, 1.0] 1.000000 \n",
"ridge 1.000000 [1.0, 1.0] 1.000000 \n",
"decision_tree 1.000000 [1.0, 1.0] 1.000000 \n",
"knn 1.000000 [1.0, 1.0] 1.000000 \n",
"naive_bayes 1.000000 [1.0, 1.0] 1.000000 \n",
"gradient_boosting 1.000000 [1.0, 1.0] 1.000000 \n",
"random_forest 1.000000 [1.0, 1.0] 1.000000 \n",
"mlp 0.373444 [0.5438066465256798, 0.0] 0.068065 \n",
"\n",
" Cohen_kappa_test MCC_test \n",
"logistic 1.0 1.0 \n",
"ridge 1.0 1.0 \n",
"decision_tree 1.0 1.0 \n",
"knn 1.0 1.0 \n",
"naive_bayes 1.0 1.0 \n",
"gradient_boosting 1.0 1.0 \n",
"random_forest 1.0 1.0 \n",
"mlp 0.0 0.0 "
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лучшая модель"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Находим ошибки"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Predicted</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Education Level, Predicted, Institution Type, Gender, Age, Device, IT Student, Location, Financial Condition, Internet Type, Network Type, Flexibility Level, Access Difficulty]\n",
"Index: []"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"Access Difficulty\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student Location \\\n",
"450 School Private Female 11 Mobile No Town \n",
"\n",
" Financial Condition Internet Type Network Type Flexibility Level \\\n",
"450 Poor Mobile Data 4G 1 \n",
"\n",
" Access Difficulty \n",
"450 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" <th>Institution Type_Public</th>\n",
" <th>Device_Mobile</th>\n",
" <th>Device_Tab</th>\n",
" <th>Location_Town</th>\n",
" <th>Financial Condition_Poor</th>\n",
" <th>Financial Condition_Rich</th>\n",
" <th>Internet Type_Wifi</th>\n",
" <th>Network Type_3G</th>\n",
" <th>Network Type_4G</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>0.775454</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n",
"450 0.775454 0.0 1.0 0.0 \n",
"\n",
" Location_Town Financial Condition_Poor Financial Condition_Rich \\\n",
"450 1.0 1.0 0.0 \n",
"\n",
" Internet Type_Wifi Network Type_3G Network Type_4G \n",
"450 0.0 0.0 1.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [0.00310819 0.99689181])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Создаем гиперпараметры методом поиска по сетке. "
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 2,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_model_type = 'random_forest'\n",
"random_state = 9\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
" \"model__max_features\": [\"sqrt\", \"log2\", 2],\n",
" \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n",
" \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=2,\n",
" max_features=\"sqrt\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели и сама оценка данных"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n",
"Name \n",
"Old 1.0 1.0 1.0 1.0 1.0 \n",
"New 1.0 1.0 1.0 1.0 1.0 \n",
"\n",
" Accuracy_test F1_train F1_test \n",
"Name \n",
"Old 1.0 [1.0, 1.0] [1.0, 1.0] \n",
"New 1.0 1.0 1.0 "
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")\n",
"\n",
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n",
"Name \n",
"Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n",
"New 1.0 1.0 1.0 1.0 1.0"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4UAAAGsCAYAAABq7AJ3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABQDklEQVR4nO3dfVxUdd7/8ffhHhUGMWEgUTHvy9Q0TStFo7Ta0nQvq8tKDPVXqWVmZY/ydr2prm5MLa0syU232q0svVq71EKtzFLDbtZMzZIUtDJAdFFg5veHy9QEGIc5OMyc1/PxOI9tzpz5znfYkTef8/2e7zHcbrdbAAAAAABbCvF3BwAAAAAA/kNRCAAAAAA2RlEIAAAAADZGUQgAAAAANkZRCAAAAAA2RlEIAAAAADZGUQgAAAAANkZRCAAAAAA2FubvDgAAAktJSYlOnjxpWXsRERGKioqyrD0AAMwg1ygKAQAmlJSUKLVFI+UfLresTafTqX379gVcgAIAAh+5dgpFIQCgxk6ePKn8w+Xat62FYmN8vwKh6KhLqd2+18mTJwMqPAEAwYFcO4WiEABgWmxMiCXhCQBAfWD3XKMoBACYVu52qdxtTTsAAPib3XONohAAYJpLbrnke3pa0QYAAL6ye67Zd4wUAAAAAMBIIQDAPJdcsmKCjDWtAADgG7vnGkUhAMC0crdb5W7fp8hY0QYAAL6ye64xfRQAAAAAbIyRQgCAaXa/IB8AEFzsnmsUhQAA01xyq9zG4QkACC52zzWmjwIAAACAjTFSCAAwze7TbAAAwcXuucZIIQAAAADYGCOFAADT7L50NwAguNg91ygKAQCmuf6zWdEOAAD+ZvdcY/ooAAAAANgYI4UAANPKLVq624o2AADwld1zjaIQAGBaufvUZkU7AAD4m91zjemjAAAAAGBjjBQCAEyz+wX5AIDgYvdcoygEAJjmkqFyGZa0AwCAv9k915g+CgAAAAA2xkghAMA0l/vUZkU7AAD4m91zjaIQAGBauUXTbKxoAwAAX9k915g+CgAIGBs3btQ111yj5ORkGYahlStXVnvsbbfdJsMwNG/ePK/9R44c0fDhwxUbG6u4uDhlZmaquLi4bjsOAEAV6kuuURQCAEyrOKNqxWbGsWPH1LlzZz399NOnPe7NN9/Uxx9/rOTk5ErPDR8+XF999ZXWrl2r1atXa+PGjRozZoypfgAAgovdc43powCAgHHllVfqyiuvPO0xBw4c0Pjx4/Xuu+/q6quv9npu586dWrNmjT799FN1795dkrRgwQJdddVVeuyxx6oMWwAA6kp9yTVGCgEAprnchmWbJBUVFXltJ06cqF2/XC7dfPPNuvfee3XuuedWen7z5s2Ki4vzBKckpaenKyQkRFu2bKndDwMAEPDsnmsUhQAA06yeZpOSkiKHw+HZ5s6dW6t+PfLIIwoLC9Odd95Z5fP5+flKSEjw2hcWFqb4+Hjl5+fX6j0BAIHP7rnG9FEAgN/l5uYqNjbW8zgyMtJ0G9u2bdNTTz2l7du3yzACc/U3AEBwCLRcY6QQAGBauUIs2yQpNjbWa6tNeG7atEmHDx9W8+bNFRYWprCwMH3//fe655571LJlS0mS0+nU4cOHvV5XVlamI0eOyOl0+vxzAQAEJrvnGiOFAADT3L+5bsLXdqxy8803Kz093WvfgAEDdPPNN2vkyJGSpF69eqmgoEDbtm1Tt27dJEnvvfeeXC6XevbsaVlfAACBxe65RlEIAAgYxcXF2rNnj+fxvn37lJOTo/j4eDVv3lxNmjTxOj48PFxOp1Pt2rWTJHXo0EEDBw7U6NGjtXjxYpWWlmrcuHG64YYbWHkUAHDG1ZdcoygEAJhWm3sxVdeOGVu3blW/fv08jydOnChJGjFihLKysmrUxvLlyzVu3DhddtllCgkJ0dChQzV//nxT/QAABBe755rhdrvdpl4BALCtoqIiORwO/fPzVDWM8f2y9GNHXbry/H0qLCz0uiAfAIAzgVw7hYVmAAAAAMDGmD4KADDNJUMuC84rusRkFQCA/9k91xgpBAAAAAAbY6QQAGCavy7IBwCgLtg91ygKAQCmlbtDVO72fbJJOWudAQDqAbvnGtNHAQAAAMDGGCkEAJh26oJ836fIWNEGAAC+snuuURQCAExzKUTlNl6lDQAQXOyea0wfBQAAAAAbY6QQAGCa3S/IBwAEF7vnGkUhAMA0l0JsfZNfAEBwsXuuMX0UAAAAAGyMkUIAgGnlbkPlbgtu8mtBGwAA+MruucZIIQAAAADYGCOFAADTyi1aurs8QK+9AAAEF7vnGkUhAMA0lztELgtWaXMF6CptAIDgYvdcY/ooAAAAANgYI4UAANPsPs0GABBc7J5rFIUAANNcsmaFNZfvXQEAwGd2zzWmjwIAAACAjTFSCAAwzaUQuSw4r2hFGwAA+MruuUZRCAAwrdwdonILVmmzog0AAHxl91wLzF4DAAAAACzBSCEAwDSXDLlkxQX5vrcBAICv7J5rFIUAANPsPs0GABBc7J5rgdlrAAAAAIAlGCkEAJhm3U1+OTcJAPA/u+daYPYaAAAAAGAJRgoDlMvl0sGDBxUTEyPDCMwLWgGcWW63W0ePHlVycrJCQnw7J+hyG3K5Lbgg34I2EBzINQBmkWvWoSgMUAcPHlRKSoq/uwEgAOXm5qpZs2Y+teGyaJpNoN7kF9Yj1wDUFrnmO4rCABUTEyNJSll4r0KiI/3cG9Q3qZlf+LsLqIfKVKoP9I7n9wdQn5BrOB1yDVUh16xDURigKqbWhERHKqRBlJ97g/omzAj3dxdQH7lP/Y8VU/Nc7hC5LFh224o2EBzINZwOuYYqkWuWoSgEAJhWLkPlFtyg14o2AADwld1zLTBLWQAAAACAJRgpBACYZvdpNgCA4GL3XKMoBACYVi5rpsiU+94VAAB8ZvdcC8xSFgAAAABgCUYKAQCm2X2aDQAguNg91wKz1wAAAAAASzBSCAAwrdwdonILzoZa0QYAAL6ye65RFAIATHPLkMuCC/LdAXo/JwBAcLF7rgVmKQsAAAAAsARFIQDAtIppNlZsZmzcuFHXXHONkpOTZRiGVq5c6XmutLRU999/vzp16qSGDRsqOTlZt9xyiw4ePOjVxpEjRzR8+HDFxsYqLi5OmZmZKi4utuLHAgAIUHbPNYpCAIBpLrdh2WbGsWPH1LlzZz399NOVnjt+/Li2b9+uKVOmaPv27XrjjTe0a9cuXXvttV7HDR8+XF999ZXWrl2r1atXa+PGjRozZoxPPw8AQGCze65xTSEAIGBceeWVuvLKK6t8zuFwaO3atV77Fi5cqB49emj//v1q3ry5du7cqTVr1ujTTz9V9+7dJUkLFizQVVddpccee0zJycl1/hkAAKhQX3KNkUIAgGnlCrFsk6SioiKv7cSJE5b0s7CwUIZhKC4uTpK0efNmxcXFeYJTktLT0xUSEqItW7ZY8p4AgMBj91yjKAQAmGb1NJuUlBQ5HA7PNnfuXJ/7WFJSovvvv1833nijYmNjJUn5+flKSEjwOi4sLEzx8fHKz8/3+T0BAIHJ7rnG9FEAgN/l5uZ6Ak6SIiMjfWqvtLRUw4YNk9vt1qJFi3ztHgAApgRarlEUAgBMcylELgsmm1S0ERsb6xWevqgIzu+//17vvfeeV7tOp1OHDx/2Or6srExHjhyR0+m05P0BAIHH7rnG9FEAgGnlbsOyzUoVwbl7926tW7dOTZo08Xq+V69eKigo0LZt2zz73nvvPblcLvXs2dPSvgAAAofdc42RQgBAwCguLtaePXs8j/ft26ecnBzFx8crKSlJf/7zn7V9+3atXr1a5eXlnusp4uPjFRERoQ4dOmjgwIEaPXq0Fi9erNLSUo0bN0433HADK48CAM64+pJrFIUAANNqcy+m6toxY+vWrerXr5/n8cSJEyVJI0aM0PTp0/X2229Lkrp06eL1uvfff19paWmSpOXLl2vcuHG67LLLFBISoqFDh2r+/Pm1/xAAgIBn91yjKAQABIy0tDS53e5qnz/dcxXi4+O1YsUKK7sFAECt1JdcoygEAJjmdof
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Модель идеально классифицировала объекты, которые относятся к \"High difficulty\" и \"Low difficulty\"."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}