AIM-PIbd-31-Makarov-DV/lab_4/lab4.ipynb

3933 lines
531 KiB
Plaintext
Raw Normal View History

2024-11-15 00:44:23 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Лабораторная 4\n",
"\n",
"Датасет: Информация об онлайн обучении учеников\n",
"\n",
2024-11-15 16:47:21 +04:00
"## Бизнес-цель 1: \n",
"Улучшение доступа к онлайн-образованию для учеников с низким уровнем финансового обеспечения."
2024-11-15 00:44:23 +04:00
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 2,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['Education Level', 'Institution Type', 'Gender', 'Age', 'Device',\n",
" 'IT Student', 'Location', 'Financial Condition', 'Internet Type',\n",
" 'Network Type', 'Flexibility Level'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from typing import Tuple\n",
"from pandas import DataFrame\n",
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree, metrics, set_config\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"df = pd.read_csv(\"..\\\\static\\\\csv\\\\students_adaptability_level_online_education.csv\")\n",
"print(df.columns)\n",
"\n",
"map_flexibility_to_int = {'Low': 0, 'Moderate': 1, 'High': 2}\n",
"\n",
"df['Flexibility Level'] = df['Flexibility Level'].map(map_flexibility_to_int).astype('int32')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Предварительно создадим колонку для работы с ней (ключевой фактор)"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 3,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"fincond_mapping = {'Poor': 2, 'Mid': 1, 'Rich': 0}\n",
"internet_type_mapping = {'Mobile Data': 1, 'Wifi': 0}\n",
"device_mapping = {'Mobile': 1, 'Computer': 0}\n",
"network_type = {'2G': 2, '3G': 1, '4G': 0}\n",
"\n",
"df['Financial Score'] = df['Financial Condition'].map(fincond_mapping)\n",
"df['Internet Score'] = df['Internet Type'].map(internet_type_mapping)\n",
"df['Device Score'] = df['Device'].map(device_mapping)\n",
"df['Network Score'] = df['Network Type'].map(network_type)\n",
"\n",
"df['Access Difficulty Score'] = df['Financial Score'] + df['Internet Score'] + df['Device Score'] + df['Network Score']\n",
"\n",
"df['Access Difficulty'] = (df['Access Difficulty Score'] >= 3).astype(int)\n",
"df.drop(columns=['Financial Score', 'Device Score', 'Internet Score', 'Network Score', 'Access Difficulty Score'], inplace=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем выборки"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 4,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>649</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>276</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1097</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>854</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>University</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Computer</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>College</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53</th>\n",
" <td>University</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>27</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Rural</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student \\\n",
"649 School Public Male 18 Mobile No \n",
"637 School Private Female 9 Mobile No \n",
"68 School Public Female 11 Mobile No \n",
"276 University Private Female 18 Mobile Yes \n",
"547 School Public Male 11 Mobile No \n",
"... ... ... ... ... ... ... \n",
"1097 University Private Male 23 Mobile Yes \n",
"854 School Public Female 18 Mobile No \n",
"756 University Public Male 18 Computer No \n",
"133 College Public Male 18 Mobile No \n",
"53 University Public Male 27 Mobile Yes \n",
"\n",
" Location Financial Condition Internet Type Network Type \\\n",
"649 Town Mid Wifi 4G \n",
"637 Town Mid Mobile Data 4G \n",
"68 Town Mid Wifi 4G \n",
"276 Town Mid Mobile Data 3G \n",
"547 Town Mid Wifi 4G \n",
"... ... ... ... ... \n",
"1097 Town Rich Wifi 4G \n",
"854 Town Mid Mobile Data 4G \n",
"756 Town Mid Wifi 3G \n",
"133 Town Poor Mobile Data 4G \n",
"53 Rural Poor Mobile Data 4G \n",
"\n",
" Flexibility Level Access Difficulty \n",
"649 1 0 \n",
"637 1 1 \n",
"68 0 0 \n",
"276 0 1 \n",
"547 1 0 \n",
"... ... ... \n",
"1097 0 0 \n",
"854 0 1 \n",
"756 1 0 \n",
"133 0 1 \n",
"53 1 1 \n",
"\n",
"[964 rows x 12 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>649</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>276</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1097</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>854</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Access Difficulty\n",
"649 0\n",
"637 1\n",
"68 0\n",
"276 1\n",
"547 0\n",
"... ...\n",
"1097 0\n",
"854 1\n",
"756 0\n",
"133 1\n",
"53 1\n",
"\n",
"[964 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>265</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>358</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>10</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>316</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Tab</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>907</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1042</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>10</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>936</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Tab</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>722</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Rural</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1075</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Computer</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>577</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>241 rows × 12 columns</p>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student \\\n",
"265 School Private Female 9 Mobile No \n",
"358 School Private Female 10 Mobile No \n",
"316 University Private Male 23 Tab No \n",
"907 School Private Female 9 Mobile No \n",
"1042 University Private Male 23 Mobile No \n",
"... ... ... ... ... ... ... \n",
"421 School Private Female 10 Mobile No \n",
"936 University Private Male 23 Tab No \n",
"722 University Private Male 23 Mobile Yes \n",
"1075 University Private Male 23 Computer Yes \n",
"577 University Private Male 23 Mobile Yes \n",
"\n",
" Location Financial Condition Internet Type Network Type \\\n",
"265 Town Poor Wifi 4G \n",
"358 Town Mid Mobile Data 3G \n",
"316 Town Mid Wifi 4G \n",
"907 Town Poor Mobile Data 4G \n",
"1042 Town Mid Mobile Data 3G \n",
"... ... ... ... ... \n",
"421 Town Mid Mobile Data 3G \n",
"936 Town Rich Wifi 4G \n",
"722 Rural Poor Mobile Data 3G \n",
"1075 Town Mid Wifi 4G \n",
"577 Town Mid Wifi 4G \n",
"\n",
" Flexibility Level Access Difficulty \n",
"265 1 1 \n",
"358 1 1 \n",
"316 1 0 \n",
"907 1 1 \n",
"1042 1 1 \n",
"... ... ... \n",
"421 1 1 \n",
"936 2 0 \n",
"722 1 1 \n",
"1075 0 0 \n",
"577 0 0 \n",
"\n",
"[241 rows x 12 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>265</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>358</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>316</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>907</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1042</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>421</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>936</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>722</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1075</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>577</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>241 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Access Difficulty\n",
"265 1\n",
"358 1\n",
"316 0\n",
"907 1\n",
"1042 1\n",
"... ...\n",
"421 1\n",
"936 0\n",
"722 1\n",
"1075 0\n",
"577 0\n",
"\n",
"[241 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ]\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" \n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"Access Difficulty\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 5,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Пропущенные значения по столбцам:\n",
"Education Level 0\n",
"Institution Type 0\n",
"Gender 0\n",
"Age 0\n",
"Device 0\n",
"IT Student 0\n",
"Location 0\n",
"Financial Condition 0\n",
"Internet Type 0\n",
"Network Type 0\n",
"Flexibility Level 0\n",
"Access Difficulty 0\n",
"dtype: int64\n",
"\n",
"Статистический обзор данных:\n",
" Age Flexibility Level Access Difficulty\n",
"count 1205.000000 1205.000000 1205.000000\n",
"mean 17.065560 0.684647 0.624896\n",
"std 5.830369 0.618221 0.484351\n",
"min 9.000000 0.000000 0.000000\n",
"25% 11.000000 0.000000 0.000000\n",
"50% 18.000000 1.000000 1.000000\n",
"75% 23.000000 1.000000 1.000000\n",
"max 27.000000 2.000000 1.000000\n"
]
}
],
"source": [
"null_values = df.isnull().sum()\n",
"print(\"Пропущенные значения по столбцам:\")\n",
"print(null_values)\n",
"\n",
"stat_summary = df.describe()\n",
"print(\"\\nСтатистический обзор данных:\")\n",
"print(stat_summary)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем конвеер для классификации данных и проверка конвеера"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 6,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" <th>Institution Type_Public</th>\n",
" <th>Device_Mobile</th>\n",
" <th>Device_Tab</th>\n",
" <th>Location_Town</th>\n",
" <th>Financial Condition_Poor</th>\n",
" <th>Financial Condition_Rich</th>\n",
" <th>Internet Type_Wifi</th>\n",
" <th>Network Type_3G</th>\n",
" <th>Network Type_4G</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>649</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>637</th>\n",
" <td>0.775454</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>68</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>276</th>\n",
" <td>0.775454</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>547</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1097</th>\n",
" <td>-1.289567</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>854</th>\n",
" <td>0.775454</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>756</th>\n",
" <td>-1.289567</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>133</th>\n",
" <td>0.775454</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>53</th>\n",
" <td>0.775454</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n",
"649 -1.289567 1.0 1.0 0.0 \n",
"637 0.775454 0.0 1.0 0.0 \n",
"68 -1.289567 1.0 1.0 0.0 \n",
"276 0.775454 0.0 1.0 0.0 \n",
"547 -1.289567 1.0 1.0 0.0 \n",
"... ... ... ... ... \n",
"1097 -1.289567 0.0 1.0 0.0 \n",
"854 0.775454 1.0 1.0 0.0 \n",
"756 -1.289567 1.0 0.0 0.0 \n",
"133 0.775454 1.0 1.0 0.0 \n",
"53 0.775454 1.0 1.0 0.0 \n",
"\n",
" Location_Town Financial Condition_Poor Financial Condition_Rich \\\n",
"649 1.0 0.0 0.0 \n",
"637 1.0 0.0 0.0 \n",
"68 1.0 0.0 0.0 \n",
"276 1.0 0.0 0.0 \n",
"547 1.0 0.0 0.0 \n",
"... ... ... ... \n",
"1097 1.0 0.0 1.0 \n",
"854 1.0 0.0 0.0 \n",
"756 1.0 0.0 0.0 \n",
"133 1.0 1.0 0.0 \n",
"53 0.0 1.0 0.0 \n",
"\n",
" Internet Type_Wifi Network Type_3G Network Type_4G \n",
"649 1.0 0.0 1.0 \n",
"637 0.0 0.0 1.0 \n",
"68 1.0 0.0 1.0 \n",
"276 0.0 1.0 0.0 \n",
"547 1.0 0.0 1.0 \n",
"... ... ... ... \n",
"1097 1.0 0.0 1.0 \n",
"854 0.0 0.0 1.0 \n",
"756 1.0 1.0 0.0 \n",
"133 0.0 0.0 1.0 \n",
"53 0.0 0.0 1.0 \n",
"\n",
"[964 rows x 10 columns]"
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 6,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"columns_to_drop = ['Age', 'Education Level', 'Gender', 'IT Student', 'Flexibility Level']\n",
"num_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype != \"object\"\n",
"]\n",
"cat_columns = [\n",
" column\n",
" for column in df.columns\n",
" if column not in columns_to_drop and df[column].dtype == \"object\"\n",
"]\n",
"\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num, num_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")\n",
"\n",
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формируем набор моделей"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 7,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression()},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=11, class_weight=\"balanced\", random_state=9\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=9,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модели и тестируем их"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 8,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
]
}
],
"source": [
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict, average=None)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict, average=None)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
2024-11-15 16:47:21 +04:00
"source": [
"Матрица неточностей"
]
2024-11-15 00:44:23 +04:00
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 9,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2QAAAQ9CAYAAAA2zo55AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVxU5f4H8M9hG5AdlS0IcUPNBdM0XFEp1HtNs3tRs6so6M9Kc4lKbyXgVduvS7m0KVpYeVvMvKZXS9yuaS6YGm6oiQpuKIiKwszz+4PL1MgMMjLDzHnm876v87rynGfOeWbQ8+l75jnnKEIIASIiIiIiIqpzTrYeABERERERkaNiQUZERERERGQjLMiIiIiIiIhshAUZERERERGRjbAgIyIiIiIishEWZERERERERDbCgoyIiIiIiMhGWJARERERERHZCAsyIiIiIiIiG2FBRtLIyMiAoig4deqUVbZ/6tQpKIqCjIwMi2wvKysLiqIgKyvLItsjIiKSSVpaGhRFqVFfRVGQlpZm3QERWQkLMiIrW7hwocWKOCIiIiKSi4utB0CkFhEREbh58yZcXV3Net3ChQvRoEEDJCYmGrT36NEDN2/ehJubmwVHSUREJIdXXnkFU6dOtfUwiKyOBRlRDSmKAnd3d4ttz8nJyaLbIyIiksX169fh6ekJFxf+pyrJj1MWSWoLFy7EAw88AI1Gg9DQUDz77LO4evVqlX4LFixA48aN4eHhgU6dOmHr1q2IjY1FbGysvo+xa8gKCgowatQohIWFQaPRICQkBAMHDtRfx9aoUSMcOnQImzdvhqIoUBRFv01T15Dt3LkT/fv3h7+/Pzw9PdG2bVvMmzfPsh8MERGRnai8VuzXX3/Fk08+CX9/f3Tr1s3oNWS3bt3C5MmT0bBhQ3h7e+Oxxx7DmTNnjG43KysLHTt2hLu7O5o0aYL333/f5HVpn376KTp06AAPDw8EBARg6NChyMvLs8r7JboTTzuQtNLS0pCeno64uDg8/fTTOHLkCBYtWoSff/4Z27dv1089XLRoEcaPH4/u3btj8uTJOHXqFAYNGgR/f3+EhYVVu48nnngChw4dwoQJE9CoUSNcuHABGzZswOnTp9GoUSPMnTsXEyZMgJeXF15++WUAQFBQkMntbdiwAX/+858REhKCiRMnIjg4GDk5OVizZg0mTpxouQ+HiIjIzvz1r39Fs2bNMHv2bAghcOHChSp9kpOT8emnn+LJJ59Ely5d8OOPP+JPf/pTlX779u1D3759ERISgvT0dGi1WsyYMQMNGzas0nfWrFl49dVXkZCQgOTkZFy8eBHvvvsuevTogX379sHPz88ab5fod4JIEkuXLhUAxMmTJ8WFCxeEm5ubePTRR4VWq9X3ee+99wQAsWTJEiGEELdu3RL169cXDz30kCgrK9P3y8jIEABEz5499W0nT54UAMTSpUuFEEJcuXJFABBvvfVWteN64IEHDLZTadOmTQKA2LRpkxBCiPLychEZGSkiIiLElStXDPrqdLqafxBEREQqkpqaKgCIYcOGGW2vlJ2dLQCIZ555xqDfk08+KQCI1NRUfduAAQNEvXr1xNmzZ/Vtx44dEy4uLgbbPHXqlHB2dhazZs0y2OaBAweEi4tLlXYia+CURZLSxo0bcfv2bUyaNAlOTr//NR8zZgx8fHzw73//GwCwe/duXL58GWPGjDGYpz58+HD4+/tXuw8PDw+4ubkhKysLV65cqfWY9+3bh5MnT2LSpElVzsbV9La/REREajVu3Lhq169duxYA8Nxzzxm0T5o0yeBnrVaLjRs3YtCgQQgNDdW3N23aFP369TPo+/XXX0On0yEhIQGXLl3SL8HBwWjWrBk2bdpUi3dEVDOcskhS+u233wAAUVFRBu1ubm5o3Lixfn3l/zdt2tSgn4uLCxo1alTtPjQaDd544w08//zzCAoKwsMPP4w///nPGDFiBIKDg80ec25uLg
CgdevWZr+WiIhI7SIjI6td/9tvv8HJyQlNmjQxaL8z6y9cuICbN29WyXagat4fO3YMQgg0a9bM6D7NvbMy0b1gQUZUC5MmTcKAAQOwatUqrF+/Hq+++ipee+01/Pjjj2jfvr2th0dERKQaHh4edb5PnU4HRVHw/fffw9nZucp6Ly+vOh8TOR5OWSQpRUREAACOHDli0H779m2cPHlSv77y/48fP27Qr7y8XH+nxLtp0qQJnn/+efznP//BwYMHcfv2bbzzzjv69TWdblh5xu/gwYM16k9ERORIIiIioNPp9DNKKt2Z9YGBgXB3d6+S7UDVvG/SpAmEEIiMjERcXFyV5eGHH7b8GyG6AwsyklJcXBzc3Nwwf/58CCH07R9//DGKior0d2Tq2LEj6tevjw8//BDl5eX6fpmZmXe9LuzGjRsoLS01aGvSpAm8vb1x69YtfZunp6fRW+3f6cEHH0RkZCTmzp1bpf8f3wMREZEjqrz+a/78+Qbtc+fONfjZ2dkZcXFxWLVqFc6dO6dvP378OL7//nuDvoMHD4azszPS09OrZK0QApcvX7bgOyAyjlMWSUoNGzbEtGnTkJ6ejr59++Kxxx7DkSNHsHDhQjz00EN46qmnAFRcU5aWloYJEyagd+/eSEhIwKlTp5CRkYEmTZpU++3W0aNH0adPHyQkJKBVq1ZwcXHBN998g/Pnz2Po0KH6fh06dMCiRYswc+ZMNG3aFIGBgejdu3eV7Tk5OWHRokUYMGAAoqOjMWrUKISEhODw4cM4dOgQ1q9fb/kPioiISCWio6MxbNgwLFy4EEVFRejSpQt++OEHo9+EpaWl4T//+Q+6du2Kp59+GlqtFu+99x5at26N7Oxsfb8mTZpg5syZmDZtmv6xN97e3jh58iS++eYbjB07FikpKXX4LskRsSAjaaWlpaFhw4Z47733MHnyZAQEBGDs2LGYPXu2wUW648ePhxAC77zzDlJSUtCuXTusXr0azz33HNzd3U1uPzw8HMOGDcMPP/yATz75BC4uLmjRogVWrlyJJ554Qt9v+vTp+O233/Dmm2/i2rVr6Nmzp9GCDADi4+OxadMmpKen45133oFOp0OTJk0wZswYy30wREREKrVkyRI0bNgQmZmZWLVqFXr37o1///vfCA8PN+jXoUMHfP/990hJScGrr76K8PBwzJgxAzk5OTh8+LBB36lTp6J58+aYM2cO0tPTAVRk/KOPPorHHnuszt4bOS5FcC4UURU6nQ4NGzbE4MGD8eGHH9p6OERERGQBgwYNwqFDh3Ds2DFbD4VIj9eQkcMrLS2tMm98+fLlKCwsRGxsrG0GRURERLVy8+ZNg5+PHTuGtWvXMtvJ7vAbMnJ4WVlZmDx5Mv7617+ifv362Lt3Lz7++GO0bNkSe/bsgZubm62HSERERGYKCQlBYmKi/vmjixYtwq1bt7Bv3z6Tzx0jsgVeQ0YOr1GjRggPD8f8+fNRWFiIgIAAjBgxAq+//jqLMSIiIpXq27cvPvvsMxQUFECj0SAmJgazZ89mMUZ2h9+QERERERER2QivISMiqgNbtmzBgAEDEBoaCkVRsGrVKoP1iYmJUBTFYOnbt69Bn8LCQgwfPhw+Pj7w8/NDUlISSkpK6vBdEBERycNespkFGRFRHbh+/TratWuHBQsWmOzTt29f5Ofn65fPPvvMYP3w4cNx6NAhbNiwAWvWrMGWLVswduxYaw+diIhISvaSzbyGTKV0Oh3OnTsHb2/vah9eTCQjIQSuXbuG0NBQODlZ/rxSaWkpbt++XW0fNze3ap9Td6d+/fqhX79+1fbRaDQIDg42ui4nJwfr1q3Dzz//jI4dOwIA3n33XfTv3x9vv/02QkNDazwWIrIOZjM5Omvms8zZzIJMpc6dO1flIYhEjiYvLw9hYWEW3WZpaSkiI7xQcEFbbb/g4GDs37/f4MCv0Wig0Wjued9ZWVkIDAyEv78/evfujZkzZ6J+/foAgB07dsDPz0
9/wAeAuLg4ODk5YefOnXj88cfveb9EZBnMZqIKls5n2bOZBZlKeXt7AwDC33sBTh73/peM7l1k0gFbD8FhlaMM27B
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Draw one confusion matrix per trained model on a two-column grid.\n",
"# The row count is rounded up so an odd number of models still gets an axis each.\n",
"n_cols = 2\n",
"n_rows = max(1, (len(class_models) + n_cols - 1) // n_cols)\n",
"_, ax = plt.subplots(n_rows, n_cols, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
"    c_matrix = class_models[key][\"Confusion_matrix\"]\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    disp.ax_.set_title(key)\n",
"\n",
"# Hide grid cells that have no model to show.\n",
"for unused_ax in list(ax.flat)[len(class_models):]:\n",
"    unused_ax.set_visible(False)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 10,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.375519</td>\n",
" <td>0.373444</td>\n",
" <td>[0.5460030165912518, 0.0]</td>\n",
" <td>[0.5438066465256798, 0.0]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test \\\n",
"logistic 1.0 1.0 1.0 1.0 \n",
"ridge 1.0 1.0 1.0 1.0 \n",
"decision_tree 1.0 1.0 1.0 1.0 \n",
"knn 1.0 1.0 1.0 1.0 \n",
"naive_bayes 1.0 1.0 1.0 1.0 \n",
"gradient_boosting 1.0 1.0 1.0 1.0 \n",
"random_forest 1.0 1.0 1.0 1.0 \n",
"mlp 0.0 0.0 0.0 0.0 \n",
"\n",
" Accuracy_train Accuracy_test F1_train \\\n",
"logistic 1.000000 1.000000 [1.0, 1.0] \n",
"ridge 1.000000 1.000000 [1.0, 1.0] \n",
"decision_tree 1.000000 1.000000 [1.0, 1.0] \n",
"knn 1.000000 1.000000 [1.0, 1.0] \n",
"naive_bayes 1.000000 1.000000 [1.0, 1.0] \n",
"gradient_boosting 1.000000 1.000000 [1.0, 1.0] \n",
"random_forest 1.000000 1.000000 [1.0, 1.0] \n",
"mlp 0.375519 0.373444 [0.5460030165912518, 0.0] \n",
"\n",
" F1_test \n",
"logistic [1.0, 1.0] \n",
"ridge [1.0, 1.0] \n",
"decision_tree [1.0, 1.0] \n",
"knn [1.0, 1.0] \n",
"naive_bayes [1.0, 1.0] \n",
"gradient_boosting [1.0, 1.0] \n",
"random_forest [1.0, 1.0] \n",
"mlp [0.5438066465256798, 0.0] "
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 10,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per model: train/test precision, recall, accuracy and F1.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"        \"Accuracy_train\",\n",
"        \"Accuracy_test\",\n",
"        \"F1_train\",\n",
"        \"F1_test\",\n",
"    ],\n",
"]\n",
"# Show the models with the highest test accuracy first.\n",
"class_metrics.sort_values(\"Accuracy_test\", ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Площадь под ROC-кривой (ROC AUC), каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 11,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>logistic</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ridge</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>decision_tree</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>knn</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>naive_bayes</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gradient_boosting</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>random_forest</th>\n",
" <td>1.000000</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mlp</th>\n",
" <td>0.373444</td>\n",
" <td>[0.5438066465256798, 0.0]</td>\n",
" <td>0.068065</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test \\\n",
"logistic 1.000000 [1.0, 1.0] 1.000000 \n",
"ridge 1.000000 [1.0, 1.0] 1.000000 \n",
"decision_tree 1.000000 [1.0, 1.0] 1.000000 \n",
"knn 1.000000 [1.0, 1.0] 1.000000 \n",
"naive_bayes 1.000000 [1.0, 1.0] 1.000000 \n",
"gradient_boosting 1.000000 [1.0, 1.0] 1.000000 \n",
"random_forest 1.000000 [1.0, 1.0] 1.000000 \n",
"mlp 0.373444 [0.5438066465256798, 0.0] 0.068065 \n",
"\n",
" Cohen_kappa_test MCC_test \n",
"logistic 1.0 1.0 \n",
"ridge 1.0 1.0 \n",
"decision_tree 1.0 1.0 \n",
"knn 1.0 1.0 \n",
"naive_bayes 1.0 1.0 \n",
"gradient_boosting 1.0 1.0 \n",
"random_forest 1.0 1.0 \n",
"mlp 0.0 0.0 "
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 11,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One row per model: test accuracy and F1 next to ROC AUC, Cohen's kappa and MCC.\n",
"class_metrics = pd.DataFrame.from_dict(class_models, orient=\"index\").loc[\n",
"    :,\n",
"    [\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"        \"ROC_AUC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"        \"MCC_test\",\n",
"    ],\n",
"]\n",
"# Show the models with the highest test ROC AUC first.\n",
"class_metrics.sort_values(\"ROC_AUC_test\", ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лучшая модель"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 12,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# The best model is the first row after ordering by test MCC, highest first.\n",
"ranked_by_mcc = class_metrics.sort_values(by=\"MCC_test\", ascending=False)\n",
"best_model = str(ranked_by_mcc.index[0])\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Находим ошибки"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": null,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Predicted</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [Education Level, Predicted, Institution Type, Gender, Age, Device, IT Student, Location, Financial Condition, Internet Type, Network Type, Flexibility Level, Access Difficulty]\n",
"Index: []"
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 13,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preprocessed view of the test features, labelled with the pipeline's output names.\n",
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
"    preprocessing_result, columns=pipeline_end.get_feature_names_out()\n",
")\n",
"\n",
"# Stored predictions of the best model and the test rows where they differ from the label.\n",
"y_new_pred = class_models[best_model][\"preds\"]\n",
"is_wrong = y_test[\"Access Difficulty\"] != y_new_pred\n",
"error_index = y_test[is_wrong].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"# Misclassified rows with the predicted label inserted as the second column.\n",
"error_predicted = pd.Series(y_new_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Пример использования модели (конвейера) для предсказания"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 14,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" <th>Flexibility Level</th>\n",
" <th>Access Difficulty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student Location \\\n",
"450 School Private Female 11 Mobile No Town \n",
"\n",
" Financial Condition Internet Type Network Type Flexibility Level \\\n",
"450 Poor Mobile Data 4G 1 \n",
"\n",
" Access Difficulty \n",
"450 1 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Access Difficulty</th>\n",
" <th>Institution Type_Public</th>\n",
" <th>Device_Mobile</th>\n",
" <th>Device_Tab</th>\n",
" <th>Location_Town</th>\n",
" <th>Financial Condition_Poor</th>\n",
" <th>Financial Condition_Rich</th>\n",
" <th>Internet Type_Wifi</th>\n",
" <th>Network Type_3G</th>\n",
" <th>Network Type_4G</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>450</th>\n",
" <td>0.775454</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n",
"450 0.775454 0.0 1.0 0.0 \n",
"\n",
" Location_Town Financial Condition_Poor Financial Condition_Rich \\\n",
"450 1.0 1.0 0.0 \n",
"\n",
" Internet Type_Wifi Network Type_3G Network Type_4G \n",
"450 0.0 0.0 1.0 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [0.00310819 0.99689181])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Run the best pipeline on a single test row and compare the result with the true label.\n",
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 450\n",
"# Select with a list so the row stays a one-row DataFrame and every column keeps\n",
"# its dtype (a transposed Series would turn all columns into object).\n",
"test = X_test.loc[[example_id]]\n",
"test_preprocessed = preprocessed_df.loc[[example_id]]\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбираем гиперпараметры методом поиска по сетке."
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 15,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
2024-11-15 16:47:21 +04:00
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
2024-11-15 00:44:23 +04:00
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 2,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 15,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Tune the random forest pipeline with an exhaustive grid search.\n",
"optimized_model_type = 'random_forest'\n",
"random_state = 9\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"# The \"model__\" prefix addresses the parameters of the pipeline step named \"model\".\n",
"param_grid = dict(\n",
"    model__n_estimators=[10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n",
"    model__max_features=[\"sqrt\", \"log2\", 2],\n",
"    model__max_depth=[2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
"    model__criterion=[\"gini\", \"entropy\", \"log_loss\"],\n",
")\n",
"\n",
"gs_optomizer = GridSearchCV(random_forest_model, param_grid, n_jobs=-1)\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": null,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [],
"source": [
"# Rebuild the random forest with the hyperparameters chosen by the grid search.\n",
"# Reading them from best_params_ keeps this cell in sync with the search result;\n",
"# the \"model__\" pipeline-step prefix is stripped to get the estimator's own names.\n",
"best_model_params = {\n",
"    name.split(\"__\", 1)[1]: value\n",
"    for name, value in gs_optomizer.best_params_.items()\n",
"}\n",
"optimized_model = ensemble.RandomForestClassifier(\n",
"    random_state=random_state, **best_model_params\n",
")\n",
"\n",
"result = {}\n",
"\n",
"# Fit preprocessing + model on the training data, then predict on both splits.\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"# A positive-class probability above 0.5 is labelled 1, anything else 0.\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"# Per-class F1 (average=None) so the values are comparable with the per-class\n",
"# F1 lists shown for the baseline models.\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"], average=None)\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"], average=None)\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Формирование данных для оценки старой и новой версии модели и сама оценка данных"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 17,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Precision_train</th>\n",
" <th>Precision_test</th>\n",
" <th>Recall_train</th>\n",
" <th>Recall_test</th>\n",
" <th>Accuracy_train</th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_train</th>\n",
" <th>F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>[1.0, 1.0]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n",
"Name \n",
"Old 1.0 1.0 1.0 1.0 1.0 \n",
"New 1.0 1.0 1.0 1.0 1.0 \n",
"\n",
" Accuracy_test F1_train F1_test \n",
"Name \n",
"Old 1.0 [1.0, 1.0] [1.0, 1.0] \n",
"New 1.0 1.0 1.0 "
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 17,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Collect the stored random forest entry and the tuned result in one table,\n",
"# one row each, labelled \"Old\" and \"New\".\n",
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"for model_metrics in (class_models[optimized_model_type], result):\n",
"    optimized_metrics.loc[len(optimized_metrics)] = pd.Series(data=model_metrics)\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")\n",
"\n",
"optimized_metrics.loc[\n",
"    :,\n",
"    [\n",
"        \"Precision_train\",\n",
"        \"Precision_test\",\n",
"        \"Recall_train\",\n",
"        \"Recall_test\",\n",
"        \"Accuracy_train\",\n",
"        \"Accuracy_test\",\n",
"        \"F1_train\",\n",
"        \"F1_test\",\n",
"    ],\n",
"]"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 18,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Accuracy_test</th>\n",
" <th>F1_test</th>\n",
" <th>ROC_AUC_test</th>\n",
" <th>Cohen_kappa_test</th>\n",
" <th>MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th>Name</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>Old</th>\n",
" <td>1.0</td>\n",
" <td>[1.0, 1.0]</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>New</th>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n",
"Name \n",
"Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n",
"New 1.0 1.0 1.0 1.0 1.0"
]
},
2024-11-15 16:47:21 +04:00
"execution_count": 18,
2024-11-15 00:44:23 +04:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Test accuracy and F1 next to ROC AUC, Cohen's kappa and MCC for the Old/New rows.\n",
"optimized_metrics.loc[\n",
"    :,\n",
"    [\n",
"        \"Accuracy_test\",\n",
"        \"F1_test\",\n",
"        \"ROC_AUC_test\",\n",
"        \"Cohen_kappa_test\",\n",
"        \"MCC_test\",\n",
"    ],\n",
"]"
]
},
{
"cell_type": "code",
2024-11-15 16:47:21 +04:00
"execution_count": 19,
2024-11-15 00:44:23 +04:00
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA4UAAAGsCAYAAABq7AJ3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABQDklEQVR4nO3dfVxUdd7/8ffhHhUGMWEgUTHvy9Q0TStFo7Ta0nQvq8tKDPVXqWVmZY/ydr2prm5MLa0syU232q0svVq71EKtzFLDbtZMzZIUtDJAdFFg5veHy9QEGIc5OMyc1/PxOI9tzpz5znfYkTef8/2e7zHcbrdbAAAAAABbCvF3BwAAAAAA/kNRCAAAAAA2RlEIAAAAADZGUQgAAAAANkZRCAAAAAA2RlEIAAAAADZGUQgAAAAANkZRCAAAAAA2FubvDgAAAktJSYlOnjxpWXsRERGKioqyrD0AAMwg1ygKAQAmlJSUKLVFI+UfLresTafTqX379gVcgAIAAh+5dgpFIQCgxk6ePKn8w+Xat62FYmN8vwKh6KhLqd2+18mTJwMqPAEAwYFcO4WiEABgWmxMiCXhCQBAfWD3XKMoBACYVu52qdxtTTsAAPib3XONohAAYJpLbrnke3pa0QYAAL6ye67Zd4wUAAAAAMBIIQDAPJdcsmKCjDWtAADgG7vnGkUhAMC0crdb5W7fp8hY0QYAAL6ye64xfRQAAAAAbIyRQgCAaXa/IB8AEFzsnmsUhQAA01xyq9zG4QkACC52zzWmjwIAAACAjTFSCAAwze7TbAAAwcXuucZIIQAAAADYGCOFAADT7L50NwAguNg91ygKAQCmuf6zWdEOAAD+ZvdcY/ooAAAAANgYI4UAANPKLVq624o2AADwld1zjaIQAGBaufvUZkU7AAD4m91zjemjAAAAAGBjjBQCAEyz+wX5AIDgYvdcoygEAJjmkqFyGZa0AwCAv9k915g+CgAAAAA2xkghAMA0l/vUZkU7AAD4m91zjaIQAGBauUXTbKxoAwAAX9k915g+CgAIGBs3btQ111yj5ORkGYahlStXVnvsbbfdJsMwNG/ePK/9R44c0fDhwxUbG6u4uDhlZmaquLi4bjsOAEAV6kuuURQCAEyrOKNqxWbGsWPH1LlzZz399NOnPe7NN9/Uxx9/rOTk5ErPDR8+XF999ZXWrl2r1atXa+PGjRozZoypfgAAgovdc43powCAgHHllVfqyiuvPO0xBw4c0Pjx4/Xuu+/q6quv9npu586dWrNmjT799FN1795dkrRgwQJdddVVeuyxx6oMWwAA6kp9yTVGCgEAprnchmWbJBUVFXltJ06cqF2/XC7dfPPNuvfee3XuuedWen7z5s2Ki4vzBKckpaenKyQkRFu2bKndDwMAEPDsnmsUhQAA06yeZpOSkiKHw+HZ5s6dW6t+PfLIIwoLC9Odd95Z5fP5+flKSEjw2hcWFqb4+Hjl5+fX6j0BAIHP7rnG9FEAgN/l5uYqNjbW8zgyMtJ0G9u2bdNTTz2l7du3yzACc/U3AEBwCLRcY6QQAGBauUIs2yQpNjbWa6tNeG7atEmHDx9W8+bNFRYWprCwMH3//fe655571LJlS0mS0+nU4cOHvV5XVlamI0eOyOl0+vxzAQAEJrvnGiOFAADT3L+5bsLXdqxy8803Kz093WvfgAEDdPPNN2vkyJGSpF69eqmgoEDbtm1Tt27dJEnvvfeeXC6XevbsaVlfAACBxe65RlEIAAgYxcXF2rNnj+fxvn37lJOTo/j4eDVv3lxNmjTxOj48PFxOp1Pt2rWTJHXo0EEDBw7U6NGjtXjxYpWWlmrcuHG64YYbWHkUAHDG1ZdcoygEAJhWm3sxVdeOGVu3blW/fv08jydOnChJGjFihLKysmrUxvLlyzVu3DhddtllCgkJ0dChQzV//nxT/QAABBe755rhdrvdpl4BALCtoqIiORwO/fPzVDWM8f2y9GNHXb
ry/H0qLCz0uiAfAIAzgVw7hYVmAAAAAMDGmD4KADDNJUMuC84rusRkFQCA/9k91xgpBAAAAAAbY6QQAGCavy7IBwCgLtg91ygKAQCmlbtDVO72fbJJOWudAQDqAbvnGtNHAQAAAMDGGCkEAJh26oJ836fIWNEGAAC+snuuURQCAExzKUTlNl6lDQAQXOyea0wfBQAAAAAbY6QQAGCa3S/IBwAEF7vnGkUhAMA0l0JsfZNfAEBwsXuuMX0UAAAAAGyMkUIAgGnlbkPlbgtu8mtBGwAA+MruucZIIQAAAADYGCOFAADTyi1aurs8QK+9AAAEF7vnGkUhAMA0lztELgtWaXMF6CptAIDgYvdcY/ooAAAAANgYI4UAANPsPs0GABBc7J5rFIUAANNcsmaFNZfvXQEAwGd2zzWmjwIAAACAjTFSCAAwzaUQuSw4r2hFGwAA+MruuUZRCAAwrdwdonILVmmzog0AAHxl91wLzF4DAAAAACzBSCEAwDSXDLlkxQX5vrcBAICv7J5rFIUAANPsPs0GABBc7J5rgdlrAAAAAIAlGCkEAJhm3U1+OTcJAPA/u+daYPYaAAAAAGAJRgoDlMvl0sGDBxUTEyPDCMwLWgGcWW63W0ePHlVycrJCQnw7J+hyG3K5Lbgg34I2EBzINQBmkWvWoSgMUAcPHlRKSoq/uwEgAOXm5qpZs2Y+teGyaJpNoN7kF9Yj1wDUFrnmO4rCABUTEyNJSll4r0KiI/3cG9Q3qZlf+LsLqIfKVKoP9I7n9wdQn5BrOB1yDVUh16xDURigKqbWhERHKqRBlJ97g/omzAj3dxdQH7lP/Y8VU/Nc7hC5LFh224o2EBzINZwOuYYqkWuWoSgEAJhWLkPlFtyg14o2AADwld1zLTBLWQAAAACAJRgpBACYZvdpNgCA4GL3XKMoBACYVi5rpsiU+94VAAB8ZvdcC8xSFgAAAABgCUYKAQCm2X2aDQAguNg91wKz1wAAAAAASzBSCAAwrdwdonILzoZa0QYAAL6ye65RFAIATHPLkMuCC/LdAXo/JwBAcLF7rgVmKQsAAAAAsARFIQDAtIppNlZsZmzcuFHXXHONkpOTZRiGVq5c6XmutLRU999/vzp16qSGDRsqOTlZt9xyiw4ePOjVxpEjRzR8+HDFxsYqLi5OmZmZKi4utuLHAgAIUHbPNYpCAIBpLrdh2WbGsWPH1LlzZz399NOVnjt+/Li2b9+uKVOmaPv27XrjjTe0a9cuXXvttV7HDR8+XF999ZXWrl2r1atXa+PGjRozZoxPPw8AQGCze65xTSEAIGBceeWVuvLKK6t8zuFwaO3atV77Fi5cqB49emj//v1q3ry5du7cqTVr1ujTTz9V9+7dJUkLFizQVVddpccee0zJycl1/hkAAKhQX3KNkUIAgGnlCrFsk6SioiKv7cSJE5b0s7CwUIZhKC4uTpK0efNmxcXFeYJTktLT0xUSEqItW7ZY8p4AgMBj91yjKAQAmGb1NJuUlBQ5HA7PNnfuXJ/7WFJSovvvv1833nijYmNjJUn5+flKSEjwOi4sLEzx8fHKz8/3+T0BAIHJ7rnG9FEAgN/l5uZ6Ak6SIiMjfWqvtLRUw4YNk9vt1qJFi3ztHgAApgRarlEUAgBMcylELgsmm1S0ERsb6xWevqgIzu+//17vvfeeV7tOp1OHDx/2Or6srExHjhyR0+m05P0BAIHH7rnG9FEAgGnlbsOyzUoVwbl7926tW7dOTZo08Xq+V69eKigo0LZt2zz73nvvPblcLvXs2dPSvgAAAofdc42RQgBAwCguLtaePXs8j/ft26ecnBzFx8crKSlJf/7zn7V9+3atXr1a5eXlnusp4uPjFRERoQ4dOmjgwIEaPXq0Fi9erNLSUo0bN0433HADK48CAM64+pJrFIUAANNqcy+m6toxY+vWrerXr5/n8cSJEyVJI0aM0PTp0/X2229Lkrp06eL1uvfff19paWmSpOXLl2vcuH
G67LLLFBISoqFDh2r+/Pm1/xAAgIBn91yjKAQABIy0tDS53e5qnz/dcxXi4+O1YsUKK7sFAECt1JdcoygEAJjmdof
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Side-by-side confusion matrices, one panel per row of optimized_metrics.\n",
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False)\n",
"\n",
"for index, (model_name, c_matrix) in enumerate(\n",
"    optimized_metrics[\"Confusion_matrix\"].items()\n",
"):\n",
"    disp = ConfusionMatrixDisplay(\n",
"        confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n",
"    ).plot(ax=ax.flat[index])\n",
"    # Title each panel with its row label so the two models can be told apart.\n",
"    disp.ax_.set_title(str(model_name))\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
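"Модель идеально классифицировала объекты, которые относятся к \"High difficulty\" и \"Low difficulty\". Однако этот результат требует проверки: целевой столбец \"Access Difficulty\" присутствует среди признаков X_test и в предобработанных данных, поэтому идеальные метрики, вероятно, объясняются утечкой целевой переменной, а не качеством модели."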
]
2024-11-15 16:47:21 +04:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Бизнес-цель 2: \n",
"Повышение удовлетворенности учеников онлайн-обучением на основе данных об их устройствах, типе соединения и местоположении.\n",
"\n",
"Регрессионная модель"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>9</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>876</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>382</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>634</th>\n",
" <td>University</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>906</th>\n",
" <td>School</td>\n",
" <td>Public</td>\n",
" <td>Female</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1044</th>\n",
" <td>College</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1095</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>23</td>\n",
" <td>Computer</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1130</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1126</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Computer</td>\n",
" <td>Yes</td>\n",
" <td>Rural</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student \\\n",
"294 School Public Female 9 Mobile No \n",
"876 School Private Male 11 Mobile No \n",
"382 School Private Male 11 Mobile No \n",
"634 University Public Female 23 Mobile No \n",
"906 School Public Female 11 Mobile No \n",
"... ... ... ... ... ... ... \n",
"1044 College Private Female 18 Mobile No \n",
"1095 University Private Female 23 Computer Yes \n",
"1130 School Private Male 11 Mobile No \n",
"860 University Private Male 23 Mobile No \n",
"1126 University Private Male 23 Computer Yes \n",
"\n",
" Location Financial Condition Internet Type Network Type \n",
"294 Town Rich Mobile Data 4G \n",
"876 Town Mid Mobile Data 3G \n",
"382 Town Mid Mobile Data 3G \n",
"634 Town Mid Wifi 3G \n",
"906 Town Mid Wifi 3G \n",
"... ... ... ... ... \n",
"1044 Town Mid Wifi 4G \n",
"1095 Town Rich Wifi 4G \n",
"1130 Town Poor Wifi 4G \n",
"860 Town Mid Mobile Data 4G \n",
"1126 Rural Mid Mobile Data 3G \n",
"\n",
"[964 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Flexibility Level</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>876</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>382</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>634</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>906</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1044</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1095</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1130</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1126</th>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Flexibility Level\n",
"294 0\n",
"876 1\n",
"382 0\n",
"634 0\n",
"906 0\n",
"... ...\n",
"1044 1\n",
"1095 2\n",
"1130 0\n",
"860 0\n",
"1126 0\n",
"\n",
"[964 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Education Level</th>\n",
" <th>Institution Type</th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Device</th>\n",
" <th>IT Student</th>\n",
" <th>Location</th>\n",
" <th>Financial Condition</th>\n",
" <th>Internet Type</th>\n",
" <th>Network Type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>101</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>11</td>\n",
" <td>Computer</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>946</th>\n",
" <td>College</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>306</th>\n",
" <td>College</td>\n",
" <td>Public</td>\n",
" <td>Male</td>\n",
" <td>18</td>\n",
" <td>Tab</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>23</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1061</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>23</td>\n",
" <td>Computer</td>\n",
" <td>Yes</td>\n",
" <td>Rural</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>908</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>10</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1135</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>18</td>\n",
" <td>Computer</td>\n",
" <td>Yes</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>894</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>10</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Poor</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>866</th>\n",
" <td>School</td>\n",
" <td>Private</td>\n",
" <td>Male</td>\n",
" <td>11</td>\n",
" <td>Mobile</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Mid</td>\n",
" <td>Mobile Data</td>\n",
" <td>3G</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1006</th>\n",
" <td>University</td>\n",
" <td>Private</td>\n",
" <td>Female</td>\n",
" <td>23</td>\n",
" <td>Computer</td>\n",
" <td>No</td>\n",
" <td>Town</td>\n",
" <td>Rich</td>\n",
" <td>Wifi</td>\n",
" <td>4G</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>241 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" Education Level Institution Type Gender Age Device IT Student \\\n",
"101 School Private Female 11 Computer No \n",
"946 College Private Male 18 Mobile No \n",
"306 College Public Male 18 Tab Yes \n",
"109 University Private Female 23 Mobile No \n",
"1061 University Private Male 23 Computer Yes \n",
"... ... ... ... ... ... ... \n",
"908 School Private Male 10 Mobile No \n",
"1135 University Private Female 18 Computer Yes \n",
"894 School Private Female 10 Mobile No \n",
"866 School Private Male 11 Mobile No \n",
"1006 University Private Female 23 Computer No \n",
"\n",
" Location Financial Condition Internet Type Network Type \n",
"101 Town Mid Wifi 4G \n",
"946 Town Mid Wifi 4G \n",
"306 Town Mid Wifi 4G \n",
"109 Town Mid Wifi 3G \n",
"1061 Rural Mid Mobile Data 3G \n",
"... ... ... ... ... \n",
"908 Town Rich Wifi 4G \n",
"1135 Town Mid Wifi 4G \n",
"894 Town Poor Mobile Data 3G \n",
"866 Town Mid Mobile Data 3G \n",
"1006 Town Rich Wifi 4G \n",
"\n",
"[241 rows x 10 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Flexibility Level</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>101</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>946</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>306</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>109</th>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1061</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>908</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1135</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>894</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>866</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1006</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>241 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" Flexibility Level\n",
"101 1\n",
"946 1\n",
"306 1\n",
"109 2\n",
"1061 1\n",
"... ...\n",
"908 1\n",
"1135 1\n",
"894 0\n",
"866 0\n",
"1006 1\n",
"\n",
"[241 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import math\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"# Seed shared by the stochastic regressors defined further down.\n",
"random_state = 9\n",
"# Ordinal encoding of the target so that it can be used as a numeric regression label.\n",
"map_flexibility_to_int = {'Low': 0, 'Moderate': 1, 'High': 2}\n",
"\n",
"# Forward slashes keep the relative path valid on Windows, Linux and macOS alike.\n",
"df = pd.read_csv(\"../static/csv/students_adaptability_level_online_education.csv\")\n",
"\n",
"df['Flexibility Level'] = df['Flexibility Level'].map(map_flexibility_to_int).astype('int32')\n",
"\n",
"def split_into_train_test(\n",
"    df_input: DataFrame,\n",
"    target_colname: str,\n",
"    frac_train: float = 0.8,\n",
"    random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
"    \"\"\"Split a frame into train/test features and train/test target.\n",
"\n",
"    Args:\n",
"        df_input: frame holding both the features and the target column.\n",
"        target_colname: name of the column to predict.\n",
"        frac_train: share of rows assigned to the training part, strictly between 0 and 1.\n",
"        random_state: seed forwarded to ``train_test_split``.\n",
"\n",
"    Returns:\n",
"        ``(X_train, X_test, y_train, y_test)``; both targets are one-column DataFrames.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``frac_train`` is not inside (0, 1) or the target column is missing.\n",
"    \"\"\"\n",
"    fraction_is_valid = 0 < frac_train < 1\n",
"    if not fraction_is_valid:\n",
"        raise ValueError(\"Fraction must be between 0 and 1.\")\n",
"\n",
"    if target_colname not in df_input.columns:\n",
"        raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
"\n",
"    features = df_input.drop(columns=[target_colname])\n",
"    # Double brackets keep the target two-dimensional (a one-column DataFrame).\n",
"    target = df_input[[target_colname]]\n",
"\n",
"    holdout_share = 1.0 - frac_train\n",
"    X_train, X_test, y_train, y_test = train_test_split(\n",
"        features, target, test_size=holdout_share, random_state=random_state\n",
"    )\n",
"    return X_train, X_test, y_train, y_test\n",
"\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
"    df, target_colname=\"Flexibility Level\", frac_train=0.8, random_state=42\n",
")\n",
"\n",
"# Show every part of the split under its own caption.\n",
"for caption, part in (\n",
"    (\"X_train\", X_train),\n",
"    (\"y_train\", y_train),\n",
"    (\"X_test\", X_test),\n",
"    (\"y_test\", y_test),\n",
"):\n",
"    display(caption, part)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выполним one-hot encoding, чтобы избавиться от категориальных признаков."
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Age</th>\n",
" <th>Education Level_School</th>\n",
" <th>Education Level_University</th>\n",
" <th>Institution Type_Public</th>\n",
" <th>Gender_Male</th>\n",
" <th>Device_Mobile</th>\n",
" <th>Device_Tab</th>\n",
" <th>IT Student_Yes</th>\n",
" <th>Location_Town</th>\n",
" <th>Financial Condition_Poor</th>\n",
" <th>Financial Condition_Rich</th>\n",
" <th>Internet Type_Wifi</th>\n",
" <th>Network Type_3G</th>\n",
" <th>Network Type_4G</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>294</th>\n",
" <td>9</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>876</th>\n",
" <td>11</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>382</th>\n",
" <td>11</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>634</th>\n",
" <td>23</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>906</th>\n",
" <td>11</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1044</th>\n",
" <td>18</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1095</th>\n",
" <td>23</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1130</th>\n",
" <td>11</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>23</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1126</th>\n",
" <td>23</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>True</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>964 rows × 14 columns</p>\n",
"</div>"
],
"text/plain": [
" Age Education Level_School Education Level_University \\\n",
"294 9 True False \n",
"876 11 True False \n",
"382 11 True False \n",
"634 23 False True \n",
"906 11 True False \n",
"... ... ... ... \n",
"1044 18 False False \n",
"1095 23 False True \n",
"1130 11 True False \n",
"860 23 False True \n",
"1126 23 False True \n",
"\n",
" Institution Type_Public Gender_Male Device_Mobile Device_Tab \\\n",
"294 True False True False \n",
"876 False True True False \n",
"382 False True True False \n",
"634 True False True False \n",
"906 True False True False \n",
"... ... ... ... ... \n",
"1044 False False True False \n",
"1095 False False False False \n",
"1130 False True True False \n",
"860 False True True False \n",
"1126 False True False False \n",
"\n",
" IT Student_Yes Location_Town Financial Condition_Poor \\\n",
"294 False True False \n",
"876 False True False \n",
"382 False True False \n",
"634 False True False \n",
"906 False True False \n",
"... ... ... ... \n",
"1044 False True False \n",
"1095 True True False \n",
"1130 False True True \n",
"860 False True False \n",
"1126 True False False \n",
"\n",
" Financial Condition_Rich Internet Type_Wifi Network Type_3G \\\n",
"294 True False False \n",
"876 False False True \n",
"382 False False True \n",
"634 False True True \n",
"906 False True True \n",
"... ... ... ... \n",
"1044 False True False \n",
"1095 True True False \n",
"1130 False True False \n",
"860 False False False \n",
"1126 False False True \n",
"\n",
" Network Type_4G \n",
"294 True \n",
"876 False \n",
"382 False \n",
"634 False \n",
"906 False \n",
"... ... \n",
"1044 True \n",
"1095 True \n",
"1130 True \n",
"860 True \n",
"1126 False \n",
"\n",
"[964 rows x 14 columns]"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cat_features = ['Education Level', 'Institution Type', 'Gender', 'Device', 'IT Student', 'Location', 'Financial Condition', 'Internet Type', 'Network Type']\n",
"\n",
"# The dummy layout is defined by the training part; drop_first removes one level per feature.\n",
"X_train = pd.get_dummies(X_train, columns=cat_features, drop_first=True)\n",
"# Encode every level present in the test part, then align it to the training columns:\n",
"# levels unseen in training are discarded and levels absent from the test part become all-False,\n",
"# so both frames always expose the same columns in the same order (the models consume .values).\n",
"X_test = pd.get_dummies(X_test, columns=cat_features).reindex(\n",
"    columns=X_train.columns, fill_value=False\n",
")\n",
"\n",
"X_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Определение перечня алгоритмов решения задачи регрессии."
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: linear\n",
"Model: linear_poly\n",
"Model: linear_interact\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"# Candidate regressors, keyed by a short name; the evaluation loop stores its results next to \"model\".\n",
"models = {\n",
"    \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
"    \"linear_poly\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(degree=2),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"linear_interact\": {\n",
"        \"model\": make_pipeline(\n",
"            PolynomialFeatures(interaction_only=True),\n",
"            linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
"        )\n",
"    },\n",
"    \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
"    \"decision_tree\": {\n",
"        \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
"    },\n",
"    \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
"    \"random_forest\": {\n",
"        \"model\": ensemble.RandomForestRegressor(\n",
"            max_depth=7, random_state=random_state, n_jobs=-1\n",
"        )\n",
"    },\n",
"    \"mlp\": {\n",
"        \"model\": neural_network.MLPRegressor(\n",
"            activation=\"tanh\",\n",
"            # One hidden layer of 3 units; the trailing comma makes this an actual tuple.\n",
"            hidden_layer_sizes=(3,),\n",
"            max_iter=500,\n",
"            early_stopping=True,\n",
"            random_state=random_state,\n",
"        )\n",
"    },\n",
"}\n",
"\n",
"for model_name, spec in models.items():\n",
"    print(f\"Model: {model_name}\")\n",
"\n",
"    # Estimators are fed bare arrays; the one-column target frame is flattened to 1-D.\n",
"    fitted_model = spec[\"model\"].fit(X_train.values, y_train.values.ravel())\n",
"    y_train_pred = fitted_model.predict(X_train.values)\n",
"    y_test_pred = fitted_model.predict(X_test.values)\n",
"\n",
"    spec[\"fitted\"] = fitted_model\n",
"    spec[\"train_preds\"] = y_train_pred\n",
"    spec[\"preds\"] = y_test_pred\n",
"\n",
"    train_mse = metrics.mean_squared_error(y_train, y_train_pred)\n",
"    test_mse = metrics.mean_squared_error(y_test, y_test_pred)\n",
"    test_mae = metrics.mean_absolute_error(y_test, y_test_pred)\n",
"\n",
"    spec[\"RMSE_train\"] = math.sqrt(train_mse)\n",
"    spec[\"RMSE_test\"] = math.sqrt(test_mse)\n",
"    # RMAE here is the square root of the mean absolute error.\n",
"    spec[\"RMAE_test\"] = math.sqrt(test_mae)\n",
"    spec[\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выводим результаты оценки."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_a8fc9_row0_col0, #T_a8fc9_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row0_col2, #T_a8fc9_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row0_col3, #T_a8fc9_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row1_col0 {\n",
" background-color: #238a8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row1_col1 {\n",
" background-color: #1f978b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row1_col2, #T_a8fc9_row2_col2 {\n",
" background-color: #6100a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row1_col3 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row2_col0 {\n",
" background-color: #1f998a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row2_col1 {\n",
" background-color: #1f9a8a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row2_col3 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row3_col0, #T_a8fc9_row4_col0 {\n",
" background-color: #1e9d89;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row3_col1, #T_a8fc9_row4_col1 {\n",
" background-color: #1fa088;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row3_col2, #T_a8fc9_row4_col2 {\n",
" background-color: #7801a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row3_col3, #T_a8fc9_row4_col3 {\n",
" background-color: #c03a83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row5_col0, #T_a8fc9_row6_col0 {\n",
" background-color: #67cc5c;\n",
" color: #000000;\n",
"}\n",
"#T_a8fc9_row5_col1 {\n",
" background-color: #5cc863;\n",
" color: #000000;\n",
"}\n",
"#T_a8fc9_row5_col2, #T_a8fc9_row6_col2 {\n",
" background-color: #bd3786;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row5_col3, #T_a8fc9_row6_col3 {\n",
" background-color: #8405a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_a8fc9_row6_col1 {\n",
" background-color: #5ec962;\n",
" color: #000000;\n",
"}\n",
"#T_a8fc9_row7_col0, #T_a8fc9_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_a8fc9\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_a8fc9_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_a8fc9_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_a8fc9_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_a8fc9_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_a8fc9_row0_col0\" class=\"data row0 col0\" >0.383913</td>\n",
" <td id=\"T_a8fc9_row0_col1\" class=\"data row0 col1\" >0.415442</td>\n",
" <td id=\"T_a8fc9_row0_col2\" class=\"data row0 col2\" >0.564953</td>\n",
" <td id=\"T_a8fc9_row0_col3\" class=\"data row0 col3\" >0.581728</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row1\" class=\"row_heading level0 row1\" >knn</th>\n",
" <td id=\"T_a8fc9_row1_col0\" class=\"data row1 col0\" >0.402696</td>\n",
" <td id=\"T_a8fc9_row1_col1\" class=\"data row1 col1\" >0.460020</td>\n",
" <td id=\"T_a8fc9_row1_col2\" class=\"data row1 col2\" >0.582800</td>\n",
" <td id=\"T_a8fc9_row1_col3\" class=\"data row1 col3\" >0.487148</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_a8fc9_row2_col0\" class=\"data row2 col0\" >0.431006</td>\n",
" <td id=\"T_a8fc9_row2_col1\" class=\"data row2 col1\" >0.465811</td>\n",
" <td id=\"T_a8fc9_row2_col2\" class=\"data row2 col2\" >0.582463</td>\n",
" <td id=\"T_a8fc9_row2_col3\" class=\"data row2 col3\" >0.474156</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
" <td id=\"T_a8fc9_row3_col0\" class=\"data row3 col0\" >0.437974</td>\n",
" <td id=\"T_a8fc9_row3_col1\" class=\"data row3 col1\" >0.476828</td>\n",
" <td id=\"T_a8fc9_row3_col2\" class=\"data row3 col2\" >0.604217</td>\n",
" <td id=\"T_a8fc9_row3_col3\" class=\"data row3 col3\" >0.448987</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row4\" class=\"row_heading level0 row4\" >linear_poly</th>\n",
" <td id=\"T_a8fc9_row4_col0\" class=\"data row4 col0\" >0.437146</td>\n",
" <td id=\"T_a8fc9_row4_col1\" class=\"data row4 col1\" >0.476920</td>\n",
" <td id=\"T_a8fc9_row4_col2\" class=\"data row4 col2\" >0.605206</td>\n",
" <td id=\"T_a8fc9_row4_col3\" class=\"data row4 col3\" >0.448773</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row5\" class=\"row_heading level0 row5\" >ridge</th>\n",
" <td id=\"T_a8fc9_row5_col0\" class=\"data row5 col0\" >0.536685</td>\n",
" <td id=\"T_a8fc9_row5_col1\" class=\"data row5 col1\" >0.564421</td>\n",
" <td id=\"T_a8fc9_row5_col2\" class=\"data row5 col2\" >0.682269</td>\n",
" <td id=\"T_a8fc9_row5_col3\" class=\"data row5 col3\" >0.227951</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row6\" class=\"row_heading level0 row6\" >linear</th>\n",
" <td id=\"T_a8fc9_row6_col0\" class=\"data row6 col0\" >0.536652</td>\n",
" <td id=\"T_a8fc9_row6_col1\" class=\"data row6 col1\" >0.564834</td>\n",
" <td id=\"T_a8fc9_row6_col2\" class=\"data row6 col2\" >0.682842</td>\n",
" <td id=\"T_a8fc9_row6_col3\" class=\"data row6 col3\" >0.226821</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_a8fc9_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_a8fc9_row7_col0\" class=\"data row7 col0\" >0.582720</td>\n",
" <td id=\"T_a8fc9_row7_col1\" class=\"data row7 col1\" >0.620961</td>\n",
" <td id=\"T_a8fc9_row7_col2\" class=\"data row7 col2\" >0.727896</td>\n",
" <td id=\"T_a8fc9_row7_col3\" class=\"data row7 col3\" >0.065525</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x239cb3282d0>"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metric_columns = [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[metric_columns]\n",
"\n",
"# Lowest test RMSE first; the two column groups are shaded with separate colour maps.\n",
"ranked_metrics = reg_metrics.sort_values(by=\"RMSE_test\")\n",
"rmse_shaded = ranked_metrics.style.background_gradient(\n",
"    cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
")\n",
"rmse_shaded.background_gradient(\n",
"    cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Выводим лучшую модель."
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'random_forest'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# After sorting by test RMSE the first index label is the name of the winning model.\n",
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").index[0])\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Подбираем гиперпараметры методом поиска по сетке."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n",
"Лучшие параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n",
"Лучший результат (MSE): 0.15015918754440927\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
}
],
"source": [
"# Tune the random forest on the encoded training split (X_train / y_train) produced above.\n",
"# The seed makes the search reproducible between runs.\n",
"model = RandomForestRegressor(random_state=random_state)\n",
"\n",
"param_grid = {\n",
"    'n_estimators': [50, 100, 200],\n",
"    'max_depth': [None, 10, 20, 30],\n",
"    'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
"                           scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"# y_train is a one-column DataFrame; flatten it to 1-D to avoid sklearn's DataConversionWarning.\n",
"grid_search.fit(X_train, y_train.values.ravel())\n",
"\n",
"print(\"Лучшие параметры:\", grid_search.best_params_)\n",
"# The scorer is the negated MSE, so the sign is flipped back for reporting.\n",
"print(\"Лучший результат (MSE):\", -grid_search.best_score_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Обучаем модель с новыми гиперпараметрами и сравниваем новые результаты со старыми."
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n",
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n",
" return fit_method(estimator, *args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Старые параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 200}\n",
"Лучший результат (MSE) на старых параметрах: 0.14998947697586934\n",
"\n",
"Новые параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n",
"Лучший результат (MSE) на новых параметрах: 0.18737177399159283\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.13671335461532685\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.3697476904800446\n"
]
}
],
"source": [
"# Old data\n",
"\n",
"old_param_grid = param_grid\n",
"old_grid_search = grid_search\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ \n",
"\n",
"# New data\n",
"\n",
"new_param_grid = {\n",
" 'n_estimators': [50],\n",
" 'max_depth': [30],\n",
" 'min_samples_split': [2]\n",
" }\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid,\n",
" scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_\n",
"\n",
"new_best_model = RandomForestRegressor(**new_best_params)\n",
"new_best_model.fit(X_train, y_train)\n",
"\n",
"old_best_model = RandomForestRegressor(**old_best_params)\n",
"old_best_model.fit(X_train, y_train)\n",
"\n",
"y_new_pred = new_best_model.predict(X_test)\n",
"y_old_pred = old_best_model.predict(X_test)\n",
"\n",
"mse = metrics.mean_squared_error(y_test, y_new_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Визуализация данных"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA14AAAIjCAYAAAATE8pZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydeZwcVdX3f7eqe9bMkj2Z7AkhLCEsQWPACCJbwi6PCPrK4q7wPCAKPiiyCkEWBRXxUZEoiLIpCEQg7FuABAgQIGTfZ7JMMvt0dy33/aOnq25V3eqp7q7qqpq5Xz75UFVTXX27llv33HPO7xBKKYVAIBAIBAKBQCAQCAJDCrsBAoFAIBAIBAKBQDDQEYaXQCAQCAQCgUAgEASMMLwEAoFAIBAIBAKBIGCE4SUQCAQCgUAgEAgEASMML4FAIBAIBAKBQCAIGGF4CQQCgUAgEAgEAkHACMNLIBAIBAKBQCAQCAJGGF4CgUAgEAgEAoFAEDDC8BIIBAKBQCAQCASCgBGGl0AgEAgEAoHAV5544gmsWLHCWH/00Ufx4YcfhtcggSACCMNLIBhgrFu3Dt/5zncwdepUVFVVob6+HkceeSTuuOMO9Pb2ht08gUAgEAwCPvjgA1x88cVYs2YN3njjDXz3u99FZ2dn2M0SCEKFUEpp2I0QCAT+8OSTT+JLX/oSKisrce6552LmzJnIZDJ49dVX8cgjj+D888/HH/7wh7CbKRAIBIIBzq5du3DEEUdg7dq1AIAvfvGLeOSRR0JulUAQLsLwEggGCBs2bMCsWbMwfvx4PP/88xg7dqzl72vXrsWTTz6Jiy++OKQWCgQCgWAwkU6nsXLlStTU1GD//fcPuzkCQeiIUEOBYIBw8803o6urC3fffbfD6AKAffbZx2J0EUJw0UUX4W9/+xtmzJiBqqoqzJ49Gy+//LLlc5s2bcL3v/99zJgxA9XV1Rg+fDi+9KUvYePGjZb9Fi1aBEKI8a+mpgYHHXQQ/vSnP1n2O//88zFkyBBH+x5++GEQQvDiiy9atr/55ps48cQT0dDQgJqaGhx11FF47bXXLPtcc801IIRg9+7dlu3Lly8HIQSLFi2yfP/kyZMt+23ZsgXV1dUghDh+13/+8x/MmzcPtbW1qKurw0knneQpT8F+Puz/rrnmGkf7V61ahbPOOgv19fUYPnw4Lr74YqRSKcex77vvPsyePRvV1dUYNmwYzj77bGzZsoXbDrfvt5/nVCqFa665Bvvuuy+qqqowduxYfPGLX8S6desAABs3bnScy87OTsyePRtTpkxBc3Ozsf3WW2/FEUccgeHDh6O6uhqzZ8/Gww8/bPm+1tZWzJ8/H+PHj0dlZSXGjh2Lr371q9i0aZNlPy/Hyv3Oiy66yLH95JNPtlzv3O+49dZbHfvOnDkTRx99tLH+4osvghDC/b4c9vvp6quvhiRJeO655yz7ffvb30ZFRQXee+8912Plfgd7bwDALbfcAkKIpW1+fD7fNc+dp3z/zj//fADmvc4+O7quY9asWdznz+vzf/TRR2PmzJmOfW+99VbH902ePBknn3yy63nJXcvc8T/++GNUV1fj3HPPtez36quvQpZl/PjHP3Y9FpB9Zg844AAMGTIE9fX1+MxnPoNHH33Usk8h7X/sscdw0kknoampCZWVlZg2bRquv/56aJpm+Szv+vLOP+Ct7yr0etjvoWXLlhn3A6+dlZWVmD17Nvbff/+C7mOBYKCSCLsBAoHAHx5//HFMnToVRxxxhOfPvPTSS3jggQfwP//zP6isrMTvfvc7nHjiiXjrrbeMAcOyZcvw+uuv4+yzz8b48eOxceNG3HXXXTj66KPx0UcfoaamxnLMX/3qVxgxYgQ6Ojrw5z//Gd/61rcwefJkHHvssQX/pueffx7z58/H7NmzjQHtPffcg2OOOQavvPIKPv3pTxd8TB5XXXUV18C59957cd555+GEE07AL37xC/T09OCuu+
7CZz/7Wbz77rsOA47HddddhylTphjrXV1d+N73vsfd96yzzsLkyZOxcOFCvPHGG/j1r3+NvXv34q9//auxzw033ICf/exnOOuss/DNb34Tu3btwm9+8xt87nOfw7vvvovGxkbHcY877jhjgLls2TL8+te/tvxd0zScfPLJeO6553D22Wfj4osvRmdnJ5YsWYKVK1di2rRpjmMqioIzzzwTmzdvxmuvvWYx9u+44w6ceuqp+OpXv4pMJoN//OMf+NKXvoQnnngCJ510EgAgk8mgrq4OF198MYYPH45169bhN7/5Dd5//3188MEHBR0rSlx55ZV4/PHH8Y1vfAMffPAB6urq8PTTT+OPf/wjrr/+ehx88MEFHa+trQ0LFy4suj1un+/vmh977LG49957jf3/+c9/4l//+pdlG+++yHHvvfdarmPU2H///XH99dfjsssuw3/913/h1FNPRXd3N84//3zst99+uO666/J+vru7G2eccQYmT56M3t5eLFq0CGeeeSaWLl1aVL+0aNEiDBkyBJdeeimGDBmC559/HldddRU6Ojpwyy23FHw8P/ouL/RnoOYo9T4WCAYMVCAQxJ729nYKgJ522mmePwOAAqDLly83tm3atIlWVVXRM844w9jW09Pj+OzSpUspAPrXv/7V2HbPPfdQAHTDhg3GttWrV1MA9Oabbza2nXfeebS2ttZxzIceeogCoC+88AKllFJd1+n06dPpCSecQHVdt7RnypQp9LjjjjO2XX311RQA3bVrl+WYy5YtowDoPffcY/n+SZMmGesrV66kkiTR+fPnW9rf2dlJGxsb6be+9S3LMVtaWmhDQ4Nju53c+Vi2bJll+65duygAevXVVzvaf+qpp1r2/f73v08B0Pfee49SSunGjRupLMv0hhtusOz3wQcf0EQi4dieyWQoAHrRRRcZ2+znmVJK//znP1MA9Je//KXjd+TO/YYNG4xzqes6/epXv0pramrom2++6fiM/Z7JZDJ05syZ9JhjjnHsy3LzzTdTAHT37t0FHwsAvfDCCx3HPOmkkyzXO/c7brnlFse+Bx54ID3qqKOM9RdeeIECoA899JBrm+33E6XZ61FRUUG/+c1v0r1799Jx48bRww8/nCqK4noc9new98bll19OR40aRWfPnm1pW6mf93LNWXL3KA/7s59KpejEiRONZ8r+/Hl5/iml9KijjqIHHnigY99bbrnF0ddMmjSJnnTSSdz2UWpeS/b4mqbRz372s3T06NF09+7d9MILL6SJRMLxzHph586dFAC99dZbi2o/r5/9zne+Q2tqamgqlTK2EULoVVddZdnPfv4L6bsKvR7sPbR48WIKgJ544omOe6PU+1ggGKiIUEOBYADQ0dEBAKirqyvoc3PnzsXs2bON9YkTJ+K0007D008/bYS4VFdXG39XFAWtra3YZ5990NjYiHfeecdxzL1792L37t1Yv349fvWrX0GWZRx11FGO/Xbv3m35Z1e7WrFiBdasWYOvfOUraG1tNfbr7u7GF77wBbz88svQdd3ymT179liO2d7e3u85uOKKK3DYYYfhS1/6kmX7kiVL0NbWhnPOOcdyTFmWMWfOHLzwwgv9HrtQLrzwQsv6f//3fwMAFi9eDCDrddB1HWeddZalTWPGjMH06dMdbcp58aqqqvJ+7yOPPIIRI0YY38diDyECgMsuuwx/+9vf8OCDD3Jn99l7Zu/evWhvb8e8efO490tnZyd27tyJpUuX4u9//zsOPPBADBs2rKhjpVIpx32lKAr3N/f09Dj2tYd1sW3cvXs32trauH+3M3PmTFx77bX405/+hBNOOAG7d+/GX/7yFyQShQWZbNu2Db/5zW/ws5/9jBsOVsrnC73mhXDnnXeitbUVV199tes+/T3/OTRNc+zb09PD3VdRFOzevRutra1QVbXfdkqShEWLFqGrqwvz58/H7373O1xxxRU4/PDDPf3O3PetW7cON910EyRJwpFHHllU+9
n7PHe/zZs3Dz09PVi1apXxt1GjRmHr1q1521VM3+X1euSglOKKK67AmWeeiTlz5uTdt9T7WCAYSIhQQ4FgAFBfXw8
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 6))\n",
"plt.plot(y_test.values, label='Истинные значения', color='blue', linewidth=2)\n",
"plt.plot(y_old_pred, label='Предсказанные значения (старые данные)', color='red', linestyle='--', linewidth=2)\n",
"plt.plot(y_new_pred, label='Предсказанные значения (новые данные)', color='green', linestyle='-', linewidth=2)\n",
"\n",
"plt.title('Сравнение предсказанных и истинных значений')\n",
"plt.xlabel('Подбор параметров')\n",
"plt.ylabel('Значения')\n",
"plt.grid()\n",
"plt.legend(loc ='lower right')\n",
"plt.show()"
]
2024-11-15 00:44:23 +04:00
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}