From c5a7faf9230067c250454b8d81fc24e077144c8d Mon Sep 17 00:00:00 2001 From: shirotame Date: Fri, 15 Nov 2024 00:44:23 +0400 Subject: [PATCH 1/2] 1 business goal of 2 --- lab_4/lab4.ipynb | 2391 ++++++++++++++++++++++++++++++++++++++++ lab_4/requirements.txt | Bin 0 -> 2088 bytes 2 files changed, 2391 insertions(+) create mode 100644 lab_4/lab4.ipynb create mode 100644 lab_4/requirements.txt diff --git a/lab_4/lab4.ipynb b/lab_4/lab4.ipynb new file mode 100644 index 0000000..733cd47 --- /dev/null +++ b/lab_4/lab4.ipynb @@ -0,0 +1,2391 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Лабораторная 4\n", + "\n", + "Датасет: Информация об онлайн обучении учеников\n", + "\n", + "Бизнес-цель 1: Улучшение доступа к онлайн-образованию для учеников с низким уровнем финансового обеспечения." + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Index(['Education Level', 'Institution Type', 'Gender', 'Age', 'Device',\n", + " 'IT Student', 'Location', 'Financial Condition', 'Internet Type',\n", + " 'Network Type', 'Flexibility Level'],\n", + " dtype='object')\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "from typing import Tuple\n", + "from pandas import DataFrame\n", + "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree, metrics, set_config\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.discriminant_analysis import StandardScaler\n", + "from sklearn.impute import SimpleImputer\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "from sklearn.compose import ColumnTransformer\n", + "from sklearn.discriminant_analysis import StandardScaler\n", + "from 
sklearn.impute import SimpleImputer\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "from sklearn.metrics import ConfusionMatrixDisplay\n", + "from sklearn.model_selection import GridSearchCV\n", + "\n", + "set_config(transform_output=\"pandas\")\n", + "df = pd.read_csv(\"..\\\\static\\\\csv\\\\students_adaptability_level_online_education.csv\")\n", + "print(df.columns)\n", + "\n", + "map_flexibility_to_int = {'Low': 0, 'Moderate': 1, 'High': 2}\n", + "\n", + "df['Flexibility Level'] = df['Flexibility Level'].map(map_flexibility_to_int).astype('int32')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Предварительно создадим колонку для работы с ней (ключевой фактор)" + ] + }, + { + "cell_type": "code", + "execution_count": 105, + "metadata": {}, + "outputs": [], + "source": [ + "fincond_mapping = {'Poor': 2, 'Mid': 1, 'Rich': 0}\n", + "internet_type_mapping = {'Mobile Data': 1, 'Wifi': 0}\n", + "device_mapping = {'Mobile': 1, 'Computer': 0}\n", + "network_type = {'2G': 2, '3G': 1, '4G': 0}\n", + "\n", + "df['Financial Score'] = df['Financial Condition'].map(fincond_mapping)\n", + "df['Internet Score'] = df['Internet Type'].map(internet_type_mapping)\n", + "df['Device Score'] = df['Device'].map(device_mapping)\n", + "df['Network Score'] = df['Network Type'].map(network_type)\n", + "\n", + "df['Access Difficulty Score'] = df['Financial Score'] + df['Internet Score'] + df['Device Score'] + df['Network Score']\n", + "\n", + "df['Access Difficulty'] = (df['Access Difficulty Score'] >= 3).astype(int)\n", + "df.drop(columns=['Financial Score', 'Device Score', 'Internet Score', 'Network Score', 'Access Difficulty Score'], inplace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Формируем выборки" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'X_train'" + ] + }, + 
"metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Education LevelInstitution TypeGenderAgeDeviceIT StudentLocationFinancial ConditionInternet TypeNetwork TypeFlexibility LevelAccess Difficulty
649SchoolPublicMale18MobileNoTownMidWifi4G10
637SchoolPrivateFemale9MobileNoTownMidMobile Data4G11
68SchoolPublicFemale11MobileNoTownMidWifi4G00
276UniversityPrivateFemale18MobileYesTownMidMobile Data3G01
547SchoolPublicMale11MobileNoTownMidWifi4G10
.......................................
1097UniversityPrivateMale23MobileYesTownRichWifi4G00
854SchoolPublicFemale18MobileNoTownMidMobile Data4G01
756UniversityPublicMale18ComputerNoTownMidWifi3G10
133CollegePublicMale18MobileNoTownPoorMobile Data4G01
53UniversityPublicMale27MobileYesRuralPoorMobile Data4G11
\n", + "

964 rows × 12 columns

\n", + "
" + ], + "text/plain": [ + " Education Level Institution Type Gender Age Device IT Student \\\n", + "649 School Public Male 18 Mobile No \n", + "637 School Private Female 9 Mobile No \n", + "68 School Public Female 11 Mobile No \n", + "276 University Private Female 18 Mobile Yes \n", + "547 School Public Male 11 Mobile No \n", + "... ... ... ... ... ... ... \n", + "1097 University Private Male 23 Mobile Yes \n", + "854 School Public Female 18 Mobile No \n", + "756 University Public Male 18 Computer No \n", + "133 College Public Male 18 Mobile No \n", + "53 University Public Male 27 Mobile Yes \n", + "\n", + " Location Financial Condition Internet Type Network Type \\\n", + "649 Town Mid Wifi 4G \n", + "637 Town Mid Mobile Data 4G \n", + "68 Town Mid Wifi 4G \n", + "276 Town Mid Mobile Data 3G \n", + "547 Town Mid Wifi 4G \n", + "... ... ... ... ... \n", + "1097 Town Rich Wifi 4G \n", + "854 Town Mid Mobile Data 4G \n", + "756 Town Mid Wifi 3G \n", + "133 Town Poor Mobile Data 4G \n", + "53 Rural Poor Mobile Data 4G \n", + "\n", + " Flexibility Level Access Difficulty \n", + "649 1 0 \n", + "637 1 1 \n", + "68 0 0 \n", + "276 0 1 \n", + "547 1 0 \n", + "... ... ... \n", + "1097 0 0 \n", + "854 0 1 \n", + "756 1 0 \n", + "133 0 1 \n", + "53 1 1 \n", + "\n", + "[964 rows x 12 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'y_train'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Access Difficulty
6490
6371
680
2761
5470
......
10970
8541
7560
1331
531
\n", + "

964 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " Access Difficulty\n", + "649 0\n", + "637 1\n", + "68 0\n", + "276 1\n", + "547 0\n", + "... ...\n", + "1097 0\n", + "854 1\n", + "756 0\n", + "133 1\n", + "53 1\n", + "\n", + "[964 rows x 1 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'X_test'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Education LevelInstitution TypeGenderAgeDeviceIT StudentLocationFinancial ConditionInternet TypeNetwork TypeFlexibility LevelAccess Difficulty
265SchoolPrivateFemale9MobileNoTownPoorWifi4G11
358SchoolPrivateFemale10MobileNoTownMidMobile Data3G11
316UniversityPrivateMale23TabNoTownMidWifi4G10
907SchoolPrivateFemale9MobileNoTownPoorMobile Data4G11
1042UniversityPrivateMale23MobileNoTownMidMobile Data3G11
.......................................
421SchoolPrivateFemale10MobileNoTownMidMobile Data3G11
936UniversityPrivateMale23TabNoTownRichWifi4G20
722UniversityPrivateMale23MobileYesRuralPoorMobile Data3G11
1075UniversityPrivateMale23ComputerYesTownMidWifi4G00
577UniversityPrivateMale23MobileYesTownMidWifi4G00
\n", + "

241 rows × 12 columns

\n", + "
" + ], + "text/plain": [ + " Education Level Institution Type Gender Age Device IT Student \\\n", + "265 School Private Female 9 Mobile No \n", + "358 School Private Female 10 Mobile No \n", + "316 University Private Male 23 Tab No \n", + "907 School Private Female 9 Mobile No \n", + "1042 University Private Male 23 Mobile No \n", + "... ... ... ... ... ... ... \n", + "421 School Private Female 10 Mobile No \n", + "936 University Private Male 23 Tab No \n", + "722 University Private Male 23 Mobile Yes \n", + "1075 University Private Male 23 Computer Yes \n", + "577 University Private Male 23 Mobile Yes \n", + "\n", + " Location Financial Condition Internet Type Network Type \\\n", + "265 Town Poor Wifi 4G \n", + "358 Town Mid Mobile Data 3G \n", + "316 Town Mid Wifi 4G \n", + "907 Town Poor Mobile Data 4G \n", + "1042 Town Mid Mobile Data 3G \n", + "... ... ... ... ... \n", + "421 Town Mid Mobile Data 3G \n", + "936 Town Rich Wifi 4G \n", + "722 Rural Poor Mobile Data 3G \n", + "1075 Town Mid Wifi 4G \n", + "577 Town Mid Wifi 4G \n", + "\n", + " Flexibility Level Access Difficulty \n", + "265 1 1 \n", + "358 1 1 \n", + "316 1 0 \n", + "907 1 1 \n", + "1042 1 1 \n", + "... ... ... \n", + "421 1 1 \n", + "936 2 0 \n", + "722 1 1 \n", + "1075 0 0 \n", + "577 0 0 \n", + "\n", + "[241 rows x 12 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'y_test'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Access Difficulty
2651
3581
3160
9071
10421
......
4211
9360
7221
10750
5770
\n", + "

241 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " Access Difficulty\n", + "265 1\n", + "358 1\n", + "316 0\n", + "907 1\n", + "1042 1\n", + "... ...\n", + "421 1\n", + "936 0\n", + "722 1\n", + "1075 0\n", + "577 0\n", + "\n", + "[241 rows x 1 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def split_stratified_into_train_val_test(\n", + " df_input,\n", + " stratify_colname=\"y\",\n", + " frac_train=0.6,\n", + " frac_val=0.15,\n", + " frac_test=0.25,\n", + " random_state=None,\n", + ") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n", + " \n", + " if frac_train + frac_val + frac_test != 1.0:\n", + " raise ValueError(\n", + " \"fractions %f, %f, %f do not add up to 1.0\"\n", + " % (frac_train, frac_val, frac_test)\n", + " )\n", + " if stratify_colname not in df_input.columns:\n", + " raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n", + " X = df_input\n", + " y = df_input[\n", + " [stratify_colname]\n", + " ]\n", + " df_train, df_temp, y_train, y_temp = train_test_split(\n", + " X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n", + " )\n", + " if frac_val <= 0:\n", + " assert len(df_input) == len(df_train) + len(df_temp)\n", + " return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n", + " \n", + " relative_frac_test = frac_test / (frac_val + frac_test)\n", + " df_val, df_test, y_val, y_test = train_test_split(\n", + " df_temp,\n", + " y_temp,\n", + " stratify=y_temp,\n", + " test_size=relative_frac_test,\n", + " random_state=random_state,\n", + " )\n", + " assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n", + " return df_train, df_val, df_test, y_train, y_val, y_test\n", + "\n", + "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n", + " df, stratify_colname=\"Access Difficulty\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=9\n", + ")\n", + "\n", + 
"display(\"X_train\", X_train)\n", + "display(\"y_train\", y_train)\n", + "\n", + "display(\"X_test\", X_test)\n", + "display(\"y_test\", y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Пропущенные значения по столбцам:\n", + "Education Level 0\n", + "Institution Type 0\n", + "Gender 0\n", + "Age 0\n", + "Device 0\n", + "IT Student 0\n", + "Location 0\n", + "Financial Condition 0\n", + "Internet Type 0\n", + "Network Type 0\n", + "Flexibility Level 0\n", + "Access Difficulty 0\n", + "dtype: int64\n", + "\n", + "Статистический обзор данных:\n", + " Age Flexibility Level Access Difficulty\n", + "count 1205.000000 1205.000000 1205.000000\n", + "mean 17.065560 0.684647 0.624896\n", + "std 5.830369 0.618221 0.484351\n", + "min 9.000000 0.000000 0.000000\n", + "25% 11.000000 0.000000 0.000000\n", + "50% 18.000000 1.000000 1.000000\n", + "75% 23.000000 1.000000 1.000000\n", + "max 27.000000 2.000000 1.000000\n" + ] + } + ], + "source": [ + "null_values = df.isnull().sum()\n", + "print(\"Пропущенные значения по столбцам:\")\n", + "print(null_values)\n", + "\n", + "stat_summary = df.describe()\n", + "print(\"\\nСтатистический обзор данных:\")\n", + "print(stat_summary)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Формируем конвейер для классификации данных и проверяем его" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Access DifficultyInstitution Type_PublicDevice_MobileDevice_TabLocation_TownFinancial Condition_PoorFinancial Condition_RichInternet Type_WifiNetwork Type_3GNetwork Type_4G
649-1.2895671.01.00.01.00.00.01.00.01.0
6370.7754540.01.00.01.00.00.00.00.01.0
68-1.2895671.01.00.01.00.00.01.00.01.0
2760.7754540.01.00.01.00.00.00.01.00.0
547-1.2895671.01.00.01.00.00.01.00.01.0
.................................
1097-1.2895670.01.00.01.00.01.01.00.01.0
8540.7754541.01.00.01.00.00.00.00.01.0
756-1.2895671.00.00.01.00.00.01.01.00.0
1330.7754541.01.00.01.01.00.00.00.01.0
530.7754541.01.00.00.01.00.00.00.01.0
\n", + "

964 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n", + "649 -1.289567 1.0 1.0 0.0 \n", + "637 0.775454 0.0 1.0 0.0 \n", + "68 -1.289567 1.0 1.0 0.0 \n", + "276 0.775454 0.0 1.0 0.0 \n", + "547 -1.289567 1.0 1.0 0.0 \n", + "... ... ... ... ... \n", + "1097 -1.289567 0.0 1.0 0.0 \n", + "854 0.775454 1.0 1.0 0.0 \n", + "756 -1.289567 1.0 0.0 0.0 \n", + "133 0.775454 1.0 1.0 0.0 \n", + "53 0.775454 1.0 1.0 0.0 \n", + "\n", + " Location_Town Financial Condition_Poor Financial Condition_Rich \\\n", + "649 1.0 0.0 0.0 \n", + "637 1.0 0.0 0.0 \n", + "68 1.0 0.0 0.0 \n", + "276 1.0 0.0 0.0 \n", + "547 1.0 0.0 0.0 \n", + "... ... ... ... \n", + "1097 1.0 0.0 1.0 \n", + "854 1.0 0.0 0.0 \n", + "756 1.0 0.0 0.0 \n", + "133 1.0 1.0 0.0 \n", + "53 0.0 1.0 0.0 \n", + "\n", + " Internet Type_Wifi Network Type_3G Network Type_4G \n", + "649 1.0 0.0 1.0 \n", + "637 0.0 0.0 1.0 \n", + "68 1.0 0.0 1.0 \n", + "276 0.0 1.0 0.0 \n", + "547 1.0 0.0 1.0 \n", + "... ... ... ... 
\n", + "1097 1.0 0.0 1.0 \n", + "854 0.0 0.0 1.0 \n", + "756 1.0 1.0 0.0 \n", + "133 0.0 0.0 1.0 \n", + "53 0.0 0.0 1.0 \n", + "\n", + "[964 rows x 10 columns]" + ] + }, + "execution_count": 108, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "columns_to_drop = ['Age', 'Education Level', 'Gender', 'IT Student', 'Flexibility Level']\n", + "num_columns = [\n", + " column\n", + " for column in df.columns\n", + " if column not in columns_to_drop and df[column].dtype != \"object\"\n", + "]\n", + "cat_columns = [\n", + " column\n", + " for column in df.columns\n", + " if column not in columns_to_drop and df[column].dtype == \"object\"\n", + "]\n", + "\n", + "num_imputer = SimpleImputer(strategy=\"median\")\n", + "num_scaler = StandardScaler()\n", + "preprocessing_num = Pipeline(\n", + " [\n", + " (\"imputer\", num_imputer),\n", + " (\"scaler\", num_scaler),\n", + " ]\n", + ")\n", + "\n", + "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n", + "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n", + "preprocessing_cat = Pipeline(\n", + " [\n", + " (\"imputer\", cat_imputer),\n", + " (\"encoder\", cat_encoder),\n", + " ]\n", + ")\n", + "\n", + "features_preprocessing = ColumnTransformer(\n", + " verbose_feature_names_out=False,\n", + " transformers=[\n", + " (\"prepocessing_num\", preprocessing_num, num_columns),\n", + " (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n", + " ],\n", + " remainder=\"passthrough\"\n", + ")\n", + "\n", + "drop_columns = ColumnTransformer(\n", + " verbose_feature_names_out=False,\n", + " transformers=[\n", + " (\"drop_columns\", \"drop\", columns_to_drop),\n", + " ],\n", + " remainder=\"passthrough\",\n", + ")\n", + "\n", + "\n", + "pipeline_end = Pipeline(\n", + " [\n", + " (\"features_preprocessing\", features_preprocessing),\n", + " (\"drop_columns\", drop_columns),\n", + " ]\n", + ")\n", + "\n", + "preprocessing_result = 
pipeline_end.fit_transform(X_train)\n", + "preprocessed_df = pd.DataFrame(\n", + " preprocessing_result,\n", + " columns=pipeline_end.get_feature_names_out(),\n", + ")\n", + "\n", + "preprocessed_df" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Формируем набор моделей" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [], + "source": [ + "class_models = {\n", + " \"logistic\": {\"model\": linear_model.LogisticRegression()},\n", + " \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n", + " \"decision_tree\": {\n", + " \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=9)\n", + " },\n", + " \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n", + " \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n", + " \"gradient_boosting\": {\n", + " \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n", + " },\n", + " \"random_forest\": {\n", + " \"model\": ensemble.RandomForestClassifier(\n", + " max_depth=11, class_weight=\"balanced\", random_state=9\n", + " )\n", + " },\n", + " \"mlp\": {\n", + " \"model\": neural_network.MLPClassifier(\n", + " hidden_layer_sizes=(7,),\n", + " max_iter=500,\n", + " early_stopping=True,\n", + " random_state=9,\n", + " )\n", + " },\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Обучаем модели и тестируем их" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: logistic\n", + "Model: ridge\n", + "Model: decision_tree\n", + "Model: knn\n", + "Model: naive_bayes\n", + "Model: gradient_boosting\n", + "Model: random_forest\n", + "Model: mlp\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n", + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\metrics\\_classification.py:1531: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n", + " _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n" + ] + } + ], + "source": [ + "for model_name in class_models.keys():\n", + " print(f\"Model: {model_name}\")\n", + " model = class_models[model_name][\"model\"]\n", + "\n", + " model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n", + " model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n", + "\n", + " y_train_predict = model_pipeline.predict(X_train)\n", + " y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n", + " y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n", + "\n", + " class_models[model_name][\"pipeline\"] = model_pipeline\n", + " class_models[model_name][\"probs\"] = y_test_probs\n", + " class_models[model_name][\"preds\"] = y_test_predict\n", + "\n", + " class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n", + " y_train, y_train_predict\n", + " )\n", + " class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n", + " y_test, y_test_predict\n", + " )\n", + " class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n", + " y_train, y_train_predict\n", + " )\n", + " class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n", + " y_test, y_test_predict\n", + " )\n", + " class_models[model_name][\"Accuracy_train\"] = 
metrics.accuracy_score(\n", + " y_train, y_train_predict\n", + " )\n", + " class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n", + " y_test, y_test_predict\n", + " )\n", + " class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n", + " y_test, y_test_probs\n", + " )\n", + " class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict, average=None)\n", + " class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict, average=None)\n", + " class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n", + " y_test, y_test_predict\n", + " )\n", + " class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n", + " y_test, y_test_predict\n", + " )\n", + " class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n", + " y_test, y_test_predict\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n", + "for index, key in enumerate(class_models.keys()):\n", + " c_matrix = class_models[key][\"Confusion_matrix\"]\n", + " disp = ConfusionMatrixDisplay(\n", + " confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n", + " ).plot(ax=ax.flat[index])\n", + " disp.ax_.set_title(key)\n", + "\n", + "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Точность, полнота, верность (аккуратность), F-мера" + ] + }, + { + "cell_type": "code", + "execution_count": 112, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Precision_trainPrecision_testRecall_trainRecall_testAccuracy_trainAccuracy_testF1_trainF1_test
logistic1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
ridge1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
decision_tree1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
knn1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
naive_bayes1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
gradient_boosting1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
random_forest1.01.01.01.01.0000001.000000[1.0, 1.0][1.0, 1.0]
mlp0.00.00.00.00.3755190.373444[0.5460030165912518, 0.0][0.5438066465256798, 0.0]
\n", + "
" + ], + "text/plain": [ + " Precision_train Precision_test Recall_train Recall_test \\\n", + "logistic 1.0 1.0 1.0 1.0 \n", + "ridge 1.0 1.0 1.0 1.0 \n", + "decision_tree 1.0 1.0 1.0 1.0 \n", + "knn 1.0 1.0 1.0 1.0 \n", + "naive_bayes 1.0 1.0 1.0 1.0 \n", + "gradient_boosting 1.0 1.0 1.0 1.0 \n", + "random_forest 1.0 1.0 1.0 1.0 \n", + "mlp 0.0 0.0 0.0 0.0 \n", + "\n", + " Accuracy_train Accuracy_test F1_train \\\n", + "logistic 1.000000 1.000000 [1.0, 1.0] \n", + "ridge 1.000000 1.000000 [1.0, 1.0] \n", + "decision_tree 1.000000 1.000000 [1.0, 1.0] \n", + "knn 1.000000 1.000000 [1.0, 1.0] \n", + "naive_bayes 1.000000 1.000000 [1.0, 1.0] \n", + "gradient_boosting 1.000000 1.000000 [1.0, 1.0] \n", + "random_forest 1.000000 1.000000 [1.0, 1.0] \n", + "mlp 0.375519 0.373444 [0.5460030165912518, 0.0] \n", + "\n", + " F1_test \n", + "logistic [1.0, 1.0] \n", + "ridge [1.0, 1.0] \n", + "decision_tree [1.0, 1.0] \n", + "knn [1.0, 1.0] \n", + "naive_bayes [1.0, 1.0] \n", + "gradient_boosting [1.0, 1.0] \n", + "random_forest [1.0, 1.0] \n", + "mlp [0.5438066465256798, 0.0] " + ] + }, + "execution_count": 112, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", + " [\n", + " \"Precision_train\",\n", + " \"Precision_test\",\n", + " \"Recall_train\",\n", + " \"Recall_test\",\n", + " \"Accuracy_train\",\n", + " \"Accuracy_test\",\n", + " \"F1_train\",\n", + " \"F1_test\",\n", + " ]\n", + "]\n", + "class_metrics.sort_values(\n", + " by=\"Accuracy_test\", ascending=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Accuracy_testF1_testROC_AUC_testCohen_kappa_testMCC_test
logistic1.000000[1.0, 1.0]1.0000001.01.0
ridge1.000000[1.0, 1.0]1.0000001.01.0
decision_tree1.000000[1.0, 1.0]1.0000001.01.0
knn1.000000[1.0, 1.0]1.0000001.01.0
naive_bayes1.000000[1.0, 1.0]1.0000001.01.0
gradient_boosting1.000000[1.0, 1.0]1.0000001.01.0
random_forest1.000000[1.0, 1.0]1.0000001.01.0
mlp0.373444[0.5438066465256798, 0.0]0.0680650.00.0
\n", + "
" + ], + "text/plain": [ + " Accuracy_test F1_test ROC_AUC_test \\\n", + "logistic 1.000000 [1.0, 1.0] 1.000000 \n", + "ridge 1.000000 [1.0, 1.0] 1.000000 \n", + "decision_tree 1.000000 [1.0, 1.0] 1.000000 \n", + "knn 1.000000 [1.0, 1.0] 1.000000 \n", + "naive_bayes 1.000000 [1.0, 1.0] 1.000000 \n", + "gradient_boosting 1.000000 [1.0, 1.0] 1.000000 \n", + "random_forest 1.000000 [1.0, 1.0] 1.000000 \n", + "mlp 0.373444 [0.5438066465256798, 0.0] 0.068065 \n", + "\n", + " Cohen_kappa_test MCC_test \n", + "logistic 1.0 1.0 \n", + "ridge 1.0 1.0 \n", + "decision_tree 1.0 1.0 \n", + "knn 1.0 1.0 \n", + "naive_bayes 1.0 1.0 \n", + "gradient_boosting 1.0 1.0 \n", + "random_forest 1.0 1.0 \n", + "mlp 0.0 0.0 " + ] + }, + "execution_count": 113, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", + " [\n", + " \"Accuracy_test\",\n", + " \"F1_test\",\n", + " \"ROC_AUC_test\",\n", + " \"Cohen_kappa_test\",\n", + " \"MCC_test\",\n", + " ]\n", + "]\n", + "class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Лучшая модель" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'logistic'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n", + "\n", + "display(best_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Находим ошибки" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Error items count: 0'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Education LevelPredictedInstitution TypeGenderAgeDeviceIT StudentLocationFinancial ConditionInternet TypeNetwork TypeFlexibility LevelAccess Difficulty
\n", + "
" + ], + "text/plain": [ + "Empty DataFrame\n", + "Columns: [Education Level, Predicted, Institution Type, Gender, Age, Device, IT Student, Location, Financial Condition, Internet Type, Network Type, Flexibility Level, Access Difficulty]\n", + "Index: []" + ] + }, + "execution_count": 115, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preprocessing_result = pipeline_end.transform(X_test)\n", + "preprocessed_df = pd.DataFrame(\n", + " preprocessing_result,\n", + " columns=pipeline_end.get_feature_names_out(),\n", + ")\n", + "\n", + "y_pred = class_models[best_model][\"preds\"]\n", + "\n", + "error_index = y_test[y_test[\"Access Difficulty\"] != y_pred].index.tolist()\n", + "display(f\"Error items count: {len(error_index)}\")\n", + "\n", + "error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n", + "error_df = X_test.loc[error_index].copy()\n", + "error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n", + "error_df.sort_index()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Пример использования модели (конвейера) для предсказания" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Education LevelInstitution TypeGenderAgeDeviceIT StudentLocationFinancial ConditionInternet TypeNetwork TypeFlexibility LevelAccess Difficulty
450SchoolPrivateFemale11MobileNoTownPoorMobile Data4G11
\n", + "
" + ], + "text/plain": [ + " Education Level Institution Type Gender Age Device IT Student Location \\\n", + "450 School Private Female 11 Mobile No Town \n", + "\n", + " Financial Condition Internet Type Network Type Flexibility Level \\\n", + "450 Poor Mobile Data 4G 1 \n", + "\n", + " Access Difficulty \n", + "450 1 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Access DifficultyInstitution Type_PublicDevice_MobileDevice_TabLocation_TownFinancial Condition_PoorFinancial Condition_RichInternet Type_WifiNetwork Type_3GNetwork Type_4G
4500.7754540.01.00.01.01.00.00.00.01.0
\n", + "
" + ], + "text/plain": [ + " Access Difficulty Institution Type_Public Device_Mobile Device_Tab \\\n", + "450 0.775454 0.0 1.0 0.0 \n", + "\n", + " Location_Town Financial Condition_Poor Financial Condition_Rich \\\n", + "450 1.0 1.0 0.0 \n", + "\n", + " Internet Type_Wifi Network Type_3G Network Type_4G \n", + "450 0.0 0.0 1.0 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'predicted: 1 (proba: [0.00310819 0.99689181])'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'real: 1'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model = class_models[best_model][\"pipeline\"]\n", + "\n", + "example_id = 450\n", + "test = pd.DataFrame(X_test.loc[example_id, :]).T\n", + "test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n", + "display(test)\n", + "display(test_preprocessed)\n", + "result_proba = model.predict_proba(test)[0]\n", + "result = model.predict(test)[0]\n", + "real = int(y_test.loc[example_id].values[0])\n", + "display(f\"predicted: {result} (proba: {result_proba})\")\n", + "display(f\"real: {real}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Создаем гиперпараметры методом поиска по сетке. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'model__criterion': 'gini',\n", + " 'model__max_depth': 2,\n", + " 'model__max_features': 'sqrt',\n", + " 'model__n_estimators': 10}" + ] + }, + "execution_count": 121, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimized_model_type = 'random_forest'\n", + "random_state = 9\n", + "\n", + "random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n", + "\n", + "param_grid = {\n", + " \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n", + " \"model__max_features\": [\"sqrt\", \"log2\", 2],\n", + " \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n", + " \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n", + "}\n", + "\n", + "gs_optomizer = GridSearchCV(\n", + " estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n", + ")\n", + "gs_optomizer.fit(X_train, y_train.values.ravel())\n", + "gs_optomizer.best_params_" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Обучение модели с новыми гиперпараметрами" + ] + }, + { + "cell_type": "code", + "execution_count": 122, + "metadata": {}, + "outputs": [], + "source": [ + "optimized_model = ensemble.RandomForestClassifier(\n", + " random_state=random_state,\n", + " criterion=\"gini\",\n", + " max_depth=2,\n", + " max_features=\"sqrt\",\n", + " n_estimators=10,\n", + ")\n", + "\n", + "result = {}\n", + "\n", + "result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n", + "result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n", + "result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n", + "result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n", + "\n", + "result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n", + 
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n", + "result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n", + "result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n", + "result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n", + "result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n", + "result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n", + "result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n", + "result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n", + "result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n", + "result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n", + "result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Формирование данных для оценки старой и новой версии модели и сама оценка данных" + ] + }, + { + "cell_type": "code", + "execution_count": 124, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Precision_trainPrecision_testRecall_trainRecall_testAccuracy_trainAccuracy_testF1_trainF1_test
Name
Old1.01.01.01.01.01.0[1.0, 1.0][1.0, 1.0]
New1.01.01.01.01.01.01.01.0
\n", + "
" + ], + "text/plain": [ + " Precision_train Precision_test Recall_train Recall_test Accuracy_train \\\n", + "Name \n", + "Old 1.0 1.0 1.0 1.0 1.0 \n", + "New 1.0 1.0 1.0 1.0 1.0 \n", + "\n", + " Accuracy_test F1_train F1_test \n", + "Name \n", + "Old 1.0 [1.0, 1.0] [1.0, 1.0] \n", + "New 1.0 1.0 1.0 " + ] + }, + "execution_count": 124, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n", + "optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n", + " data=class_models[optimized_model_type]\n", + ")\n", + "optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n", + " data=result\n", + ")\n", + "optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n", + "optimized_metrics = optimized_metrics.set_index(\"Name\")\n", + "\n", + "optimized_metrics[\n", + " [\n", + " \"Precision_train\",\n", + " \"Precision_test\",\n", + " \"Recall_train\",\n", + " \"Recall_test\",\n", + " \"Accuracy_train\",\n", + " \"Accuracy_test\",\n", + " \"F1_train\",\n", + " \"F1_test\",\n", + " ]\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Accuracy_testF1_testROC_AUC_testCohen_kappa_testMCC_test
Name
Old1.0[1.0, 1.0]1.01.01.0
New1.01.01.01.01.0
\n", + "
" + ], + "text/plain": [ + " Accuracy_test F1_test ROC_AUC_test Cohen_kappa_test MCC_test\n", + "Name \n", + "Old 1.0 [1.0, 1.0] 1.0 1.0 1.0\n", + "New 1.0 1.0 1.0 1.0 1.0" + ] + }, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimized_metrics[\n", + " [\n", + " \"Accuracy_test\",\n", + " \"F1_test\",\n", + " \"ROC_AUC_test\",\n", + " \"Cohen_kappa_test\",\n", + " \"MCC_test\",\n", + " ]\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n", + ")\n", + "\n", + "for index in range(0, len(optimized_metrics)):\n", + " c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n", + " disp = ConfusionMatrixDisplay(\n", + " confusion_matrix=c_matrix, display_labels=[\"Low dif-ty\", \"High dif-ty\"]\n", + " ).plot(ax=ax.flat[index])\n", + "\n", + "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Модель идеально классифицировала объекты, которые относятся к \"High difficulty\" и \"Low difficulty\"." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/lab_4/requirements.txt b/lab_4/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..482bf7066078bcf01cf19630ca5993fc4a63377d GIT binary patch literal 2088 zcmZvd?@rr55XA3ur9Mg!aUirGc!5flD)j+Sf(_&kj++<)Jbc^v?f7!0RLCg4o7u;AfsY$KguJu=?vvi?nnm(sb={9Y(|CH*~rIp?`dT#LJqk5NUowJZdlRkf%EBw87&k?Um;>FjMW5BgcOYE{za5!Zf&=XDwDs@w|U` z{)^6EMh^CGRSEp4*ikOEPael;(o3M8gjh+B`^W@OFy?V51QB4`L?r6Aa%Mcn?i9re z7ZloFjud_no2~NIiu*c5Lb>^naV5=`+S}@BRYN1>L~roeYvsKf7g4Rxe8jj0cDvA^ zkIc(-6V)8FMqT7^MRV~mU%apiZSe|u8=WOLm@0jZstT%4)Ma*bp-NZ!@BC|>tre{> z=d^q9EtFrYwkr0s%Usu5F_!~pPNOya?4z~VP5 z0=|yB5vQXb^x3B8h(6AK74g1DKQ_VPr+(uUQ1)<+#ccFq$y8^dHZN|NZ76dbSmzrg?{9{gWZg=kB#Pdt(!Hv}SL7!< z(NW+1?zFy(x_kMZJAiJ|TXV^Mx)}8A6va0|@i}=jmtOigrsu^uK@_KVih!qX^gN5F zBbj~X^+R#)#mhb|noh&7IUBjGZ0XAT*|76Hb(821b1<%eE?Mm}Mcb*q{!|T9^A@UA 
zTB}A=MO@t9O2<5t=ip+x=ediS{@hpVcojL^S@<4Q$IB>3nF+M8^Rss48&>WNZZ-P3kiIC0gYdHZ zN*-!*&-|5X1M4XyPMM=#s@2jiJ9R$tO|LAAa3*(nKls^SE4!L%=onuOm=mRR&bl4N djn36VJOm;;Z#vAxF>0uVLfg#=GaG&J{sPG%F-rgd literal 0 HcmV?d00001 From e8a85313538cd8f8a65c3e4e7cd7a4b442b2cdf1 Mon Sep 17 00:00:00 2001 From: shirotame Date: Fri, 15 Nov 2024 16:47:21 +0400 Subject: [PATCH 2/2] 2 business goals of 2 --- lab_4/lab4.ipynb | 1603 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 1572 insertions(+), 31 deletions(-) diff --git a/lab_4/lab4.ipynb b/lab_4/lab4.ipynb index 733cd47..c06896c 100644 --- a/lab_4/lab4.ipynb +++ b/lab_4/lab4.ipynb @@ -8,12 +8,13 @@ "\n", "Датасет: Информация об онлайн обучении учеников\n", "\n", - "Бизнес-цель 1: Улучшение доступа к онлайн-образованию для учеников с низким уровнем финансового обеспечения." + "## Бизнес-цель 1: \n", + "Улучшение доступа к онлайн-образованию для учеников с низким уровнем финансового обеспечения." ] }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -31,7 +32,6 @@ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", - "import seaborn as sns\n", "from typing import Tuple\n", "from pandas import DataFrame\n", "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree, metrics, set_config\n", @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -96,7 +96,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -880,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -934,7 +934,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -1162,7 +1162,7 @@ "[964 rows x 10 columns]" ] }, - "execution_count": 108, + "execution_count": 6, "metadata": 
{}, "output_type": "execute_result" } @@ -1241,7 +1241,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -1281,7 +1281,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -1362,11 +1362,13 @@ { "cell_type": "markdown", "metadata": {}, - "source": [] + "source": [ + "Матрица неточностей" + ] }, { "cell_type": "code", - "execution_count": 111, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -1402,7 +1404,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1561,7 +1563,7 @@ "mlp [0.5438066465256798, 0.0] " ] }, - "execution_count": 112, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1593,7 +1595,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -1715,7 +1717,7 @@ "mlp 0.0 0.0 " ] }, - "execution_count": 113, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -1742,7 +1744,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -1770,7 +1772,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1829,7 +1831,7 @@ "Index: []" ] }, - "execution_count": 115, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1841,12 +1843,12 @@ " columns=pipeline_end.get_feature_names_out(),\n", ")\n", "\n", - "y_pred = class_models[best_model][\"preds\"]\n", + "y_new_pred = class_models[best_model][\"preds\"]\n", "\n", - "error_index = y_test[y_test[\"Access Difficulty\"] != y_pred].index.tolist()\n", + "error_index = y_test[y_test[\"Access Difficulty\"] != y_new_pred].index.tolist()\n", "display(f\"Error items count: {len(error_index)}\")\n", "\n", - "error_predicted = pd.Series(y_pred, 
index=y_test.index).loc[error_index]\n", + "error_predicted = pd.Series(y_new_pred, index=y_test.index).loc[error_index]\n", "error_df = X_test.loc[error_index].copy()\n", "error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n", "error_df.sort_index()" @@ -1861,7 +1863,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -2041,9 +2043,17 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 15, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n", + " _data = np.array(data, dtype=dtype, copy=copy,\n" + ] + }, { "data": { "text/plain": [ @@ -2053,7 +2063,7 @@ " 'model__n_estimators': 10}" ] }, - "execution_count": 121, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -2087,7 +2097,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2129,7 +2139,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -2213,7 +2223,7 @@ "New 1.0 1.0 1.0 " ] }, - "execution_count": 124, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -2245,7 +2255,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -2312,7 +2322,7 @@ "New 1.0 1.0 1.0 1.0 1.0" ] }, - "execution_count": 125, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -2331,7 +2341,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -2365,6 +2375,1537 @@ "source": [ "Модель идеально классифицировала объекты, которые относятся к \"High difficulty\" и \"Low difficulty\"." 
] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Бизнес-цель 2: \n", + "Повышение удовлетворенности учеников онлайн-обучением на основе их устройств, типа соединения и местоположения.\n", + "\n", + "Регрессионная модель" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'X_train'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Education LevelInstitution TypeGenderAgeDeviceIT StudentLocationFinancial ConditionInternet TypeNetwork Type
294SchoolPublicFemale9MobileNoTownRichMobile Data4G
876SchoolPrivateMale11MobileNoTownMidMobile Data3G
382SchoolPrivateMale11MobileNoTownMidMobile Data3G
634UniversityPublicFemale23MobileNoTownMidWifi3G
906SchoolPublicFemale11MobileNoTownMidWifi3G
.................................
1044CollegePrivateFemale18MobileNoTownMidWifi4G
1095UniversityPrivateFemale23ComputerYesTownRichWifi4G
1130SchoolPrivateMale11MobileNoTownPoorWifi4G
860UniversityPrivateMale23MobileNoTownMidMobile Data4G
1126UniversityPrivateMale23ComputerYesRuralMidMobile Data3G
\n", + "

964 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " Education Level Institution Type Gender Age Device IT Student \\\n", + "294 School Public Female 9 Mobile No \n", + "876 School Private Male 11 Mobile No \n", + "382 School Private Male 11 Mobile No \n", + "634 University Public Female 23 Mobile No \n", + "906 School Public Female 11 Mobile No \n", + "... ... ... ... ... ... ... \n", + "1044 College Private Female 18 Mobile No \n", + "1095 University Private Female 23 Computer Yes \n", + "1130 School Private Male 11 Mobile No \n", + "860 University Private Male 23 Mobile No \n", + "1126 University Private Male 23 Computer Yes \n", + "\n", + " Location Financial Condition Internet Type Network Type \n", + "294 Town Rich Mobile Data 4G \n", + "876 Town Mid Mobile Data 3G \n", + "382 Town Mid Mobile Data 3G \n", + "634 Town Mid Wifi 3G \n", + "906 Town Mid Wifi 3G \n", + "... ... ... ... ... \n", + "1044 Town Mid Wifi 4G \n", + "1095 Town Rich Wifi 4G \n", + "1130 Town Poor Wifi 4G \n", + "860 Town Mid Mobile Data 4G \n", + "1126 Rural Mid Mobile Data 3G \n", + "\n", + "[964 rows x 10 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'y_train'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Flexibility Level
2940
8761
3820
6340
9060
......
10441
10952
11300
8600
11260
\n", + "

964 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " Flexibility Level\n", + "294 0\n", + "876 1\n", + "382 0\n", + "634 0\n", + "906 0\n", + "... ...\n", + "1044 1\n", + "1095 2\n", + "1130 0\n", + "860 0\n", + "1126 0\n", + "\n", + "[964 rows x 1 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'X_test'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Education LevelInstitution TypeGenderAgeDeviceIT StudentLocationFinancial ConditionInternet TypeNetwork Type
101SchoolPrivateFemale11ComputerNoTownMidWifi4G
946CollegePrivateMale18MobileNoTownMidWifi4G
306CollegePublicMale18TabYesTownMidWifi4G
109UniversityPrivateFemale23MobileNoTownMidWifi3G
1061UniversityPrivateMale23ComputerYesRuralMidMobile Data3G
.................................
908SchoolPrivateMale10MobileNoTownRichWifi4G
1135UniversityPrivateFemale18ComputerYesTownMidWifi4G
894SchoolPrivateFemale10MobileNoTownPoorMobile Data3G
866SchoolPrivateMale11MobileNoTownMidMobile Data3G
1006UniversityPrivateFemale23ComputerNoTownRichWifi4G
\n", + "

241 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " Education Level Institution Type Gender Age Device IT Student \\\n", + "101 School Private Female 11 Computer No \n", + "946 College Private Male 18 Mobile No \n", + "306 College Public Male 18 Tab Yes \n", + "109 University Private Female 23 Mobile No \n", + "1061 University Private Male 23 Computer Yes \n", + "... ... ... ... ... ... ... \n", + "908 School Private Male 10 Mobile No \n", + "1135 University Private Female 18 Computer Yes \n", + "894 School Private Female 10 Mobile No \n", + "866 School Private Male 11 Mobile No \n", + "1006 University Private Female 23 Computer No \n", + "\n", + " Location Financial Condition Internet Type Network Type \n", + "101 Town Mid Wifi 4G \n", + "946 Town Mid Wifi 4G \n", + "306 Town Mid Wifi 4G \n", + "109 Town Mid Wifi 3G \n", + "1061 Rural Mid Mobile Data 3G \n", + "... ... ... ... ... \n", + "908 Town Rich Wifi 4G \n", + "1135 Town Mid Wifi 4G \n", + "894 Town Poor Mobile Data 3G \n", + "866 Town Mid Mobile Data 3G \n", + "1006 Town Rich Wifi 4G \n", + "\n", + "[241 rows x 10 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "'y_test'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Flexibility Level
1011
9461
3061
1092
10611
......
9081
11351
8940
8660
10061
\n", + "

241 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " Flexibility Level\n", + "101 1\n", + "946 1\n", + "306 1\n", + "109 2\n", + "1061 1\n", + "... ...\n", + "908 1\n", + "1135 1\n", + "894 0\n", + "866 0\n", + "1006 1\n", + "\n", + "[241 rows x 1 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import math\n", + "from sklearn.pipeline import make_pipeline\n", + "from sklearn.ensemble import RandomForestRegressor\n", + "from sklearn.preprocessing import PolynomialFeatures\n", + "from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n", + "\n", + "random_state = 9\n", + "map_flexibility_to_int = {'Low': 0, 'Moderate': 1, 'High': 2}\n", + "\n", + "df = pd.read_csv(\"..\\\\static\\\\csv\\\\students_adaptability_level_online_education.csv\")\n", + "\n", + "df['Flexibility Level'] = df['Flexibility Level'].map(map_flexibility_to_int).astype('int32')\n", + "\n", + "def split_into_train_test(\n", + " df_input: DataFrame,\n", + " target_colname: str,\n", + " frac_train: float = 0.8,\n", + " random_state: int = None,\n", + ") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n", + " \n", + " if not (0 < frac_train < 1):\n", + " raise ValueError(\"Fraction must be between 0 and 1.\")\n", + " \n", + " if target_colname not in df_input.columns:\n", + " raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n", + " \n", + " X = df_input.drop(columns=[target_colname])\n", + " y = df_input[[target_colname]]\n", + "\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y,\n", + " test_size=(1.0 - frac_train),\n", + " random_state=random_state\n", + " )\n", + " return X_train, X_test, y_train, y_test\n", + "\n", + "X_train, X_test, y_train, y_test = split_into_train_test(\n", + " df, \n", + " target_colname=\"Flexibility Level\", \n", + " frac_train=0.8, \n", + " random_state=42\n", + ")\n", + "\n", + "display(\"X_train\", X_train)\n", + "display(\"y_train\", y_train)\n", + "\n", + 
"display(\"X_test\", X_test)\n", + "display(\"y_test\", y_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Выполним one-hot encoding, чтобы избавиться от категориальных признаков." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
AgeEducation Level_SchoolEducation Level_UniversityInstitution Type_PublicGender_MaleDevice_MobileDevice_TabIT Student_YesLocation_TownFinancial Condition_PoorFinancial Condition_RichInternet Type_WifiNetwork Type_3GNetwork Type_4G
2949TrueFalseTrueFalseTrueFalseFalseTrueFalseTrueFalseFalseTrue
87611TrueFalseFalseTrueTrueFalseFalseTrueFalseFalseFalseTrueFalse
38211TrueFalseFalseTrueTrueFalseFalseTrueFalseFalseFalseTrueFalse
63423FalseTrueTrueFalseTrueFalseFalseTrueFalseFalseTrueTrueFalse
90611TrueFalseTrueFalseTrueFalseFalseTrueFalseFalseTrueTrueFalse
.............................................
104418FalseFalseFalseFalseTrueFalseFalseTrueFalseFalseTrueFalseTrue
109523FalseTrueFalseFalseFalseFalseTrueTrueFalseTrueTrueFalseTrue
113011TrueFalseFalseTrueTrueFalseFalseTrueTrueFalseTrueFalseTrue
86023FalseTrueFalseTrueTrueFalseFalseTrueFalseFalseFalseFalseTrue
112623FalseTrueFalseTrueFalseFalseTrueFalseFalseFalseFalseTrueFalse
\n", + "

964 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " Age Education Level_School Education Level_University \\\n", + "294 9 True False \n", + "876 11 True False \n", + "382 11 True False \n", + "634 23 False True \n", + "906 11 True False \n", + "... ... ... ... \n", + "1044 18 False False \n", + "1095 23 False True \n", + "1130 11 True False \n", + "860 23 False True \n", + "1126 23 False True \n", + "\n", + " Institution Type_Public Gender_Male Device_Mobile Device_Tab \\\n", + "294 True False True False \n", + "876 False True True False \n", + "382 False True True False \n", + "634 True False True False \n", + "906 True False True False \n", + "... ... ... ... ... \n", + "1044 False False True False \n", + "1095 False False False False \n", + "1130 False True True False \n", + "860 False True True False \n", + "1126 False True False False \n", + "\n", + " IT Student_Yes Location_Town Financial Condition_Poor \\\n", + "294 False True False \n", + "876 False True False \n", + "382 False True False \n", + "634 False True False \n", + "906 False True False \n", + "... ... ... ... \n", + "1044 False True False \n", + "1095 True True False \n", + "1130 False True True \n", + "860 False True False \n", + "1126 True False False \n", + "\n", + " Financial Condition_Rich Internet Type_Wifi Network Type_3G \\\n", + "294 True False False \n", + "876 False False True \n", + "382 False False True \n", + "634 False True True \n", + "906 False True True \n", + "... ... ... ... \n", + "1044 False True False \n", + "1095 True True False \n", + "1130 False True False \n", + "860 False False False \n", + "1126 False False True \n", + "\n", + " Network Type_4G \n", + "294 True \n", + "876 False \n", + "382 False \n", + "634 False \n", + "906 False \n", + "... ... 
\n", + "1044 True \n", + "1095 True \n", + "1130 True \n", + "860 True \n", + "1126 False \n", + "\n", + "[964 rows x 14 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cat_features = ['Education Level', 'Institution Type', 'Gender', 'Device', 'IT Student', 'Location', 'Financial Condition', 'Internet Type', 'Network Type']\n", + "\n", + "X_test = pd.get_dummies(X_test, columns=cat_features, drop_first=True)\n", + "X_train = pd.get_dummies(X_train, columns=cat_features, drop_first=True)\n", + "\n", + "X_test\n", + "X_train" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Определение перечня алгоритмов решения задачи регрессии." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: linear\n", + "Model: linear_poly\n", + "Model: linear_interact\n", + "Model: ridge\n", + "Model: decision_tree\n", + "Model: knn\n", + "Model: random_forest\n", + "Model: mlp\n" + ] + } + ], + "source": [ + "models = {\n", + " \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n", + " \"linear_poly\": {\n", + " \"model\": make_pipeline(\n", + " PolynomialFeatures(degree=2),\n", + " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", + " )\n", + " },\n", + " \"linear_interact\": {\n", + " \"model\": make_pipeline(\n", + " PolynomialFeatures(interaction_only=True),\n", + " linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n", + " )\n", + " },\n", + " \"ridge\": {\"model\": linear_model.RidgeCV()},\n", + " \"decision_tree\": {\n", + " \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n", + " },\n", + " \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n", + " \"random_forest\": {\n", + " \"model\": ensemble.RandomForestRegressor(\n", + " max_depth=7, random_state=random_state, 
n_jobs=-1\n", + " )\n", + " },\n", + " \"mlp\": {\n", + " \"model\": neural_network.MLPRegressor(\n", + " activation=\"tanh\",\n", + " hidden_layer_sizes=(3),\n", + " max_iter=500,\n", + " early_stopping=True,\n", + " random_state=random_state,\n", + " )\n", + " },\n", + "}\n", + "\n", + "for model_name in models.keys():\n", + " print(f\"Model: {model_name}\")\n", + "\n", + " fitted_model = models[model_name][\"model\"].fit(\n", + " X_train.values, y_train.values.ravel()\n", + " )\n", + " y_train_pred = fitted_model.predict(X_train.values)\n", + " y_test_pred = fitted_model.predict(X_test.values)\n", + " models[model_name][\"fitted\"] = fitted_model\n", + " models[model_name][\"train_preds\"] = y_train_pred\n", + " models[model_name][\"preds\"] = y_test_pred\n", + " models[model_name][\"RMSE_train\"] = math.sqrt(\n", + " metrics.mean_squared_error(y_train, y_train_pred)\n", + " )\n", + " models[model_name][\"RMSE_test\"] = math.sqrt(\n", + " metrics.mean_squared_error(y_test, y_test_pred)\n", + " )\n", + " models[model_name][\"RMAE_test\"] = math.sqrt(\n", + " metrics.mean_absolute_error(y_test, y_test_pred)\n", + " )\n", + " models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_pred)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Выводим результаты оценки." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
 RMSE_trainRMSE_testRMAE_testR2_test
random_forest0.3839130.4154420.5649530.581728
knn0.4026960.4600200.5828000.487148
decision_tree0.4310060.4658110.5824630.474156
linear_interact0.4379740.4768280.6042170.448987
linear_poly0.4371460.4769200.6052060.448773
ridge0.5366850.5644210.6822690.227951
linear0.5366520.5648340.6828420.226821
mlp0.5827200.6209610.7278960.065525
\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n", + " [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n", + "]\n", + "reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n", + " cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n", + ").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Выводим лучшую модель." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'random_forest'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n", + "\n", + "display(best_model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Подбираем гиперпараметры методом поиска по сетке." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 36 candidates, totalling 180 fits\n", + "Лучшие параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n", + "Лучший результат (MSE): 0.15015918754440927\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n" + ] + } + ], + "source": [ + "X = df[['Device', 'Financial Condition', 'Internet Type']]\n", + "y = df['Flexibility Level'] # Целевая переменная для регрессии\n", + "\n", + "model = RandomForestRegressor() \n", + "\n", + "param_grid = {\n", + " 'n_estimators': [50, 100, 200], \n", + " 'max_depth': [None, 10, 20, 30], \n", + " 'min_samples_split': [2, 5, 10] \n", + "}\n", + "\n", + "grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n", + " scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n", + "\n", + "grid_search.fit(X_train, y_train)\n", + "\n", + "print(\"Лучшие параметры:\", grid_search.best_params_)\n", + "print(\"Лучший результат (MSE):\", -grid_search.best_score_)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Обучаем модель с новыми гиперпараметрами и сравниваем новые данные со старыми." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fitting 5 folds for each of 36 candidates, totalling 180 fits\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n", + "d:\\ulstu\\cr3\\sem1\\MAI\\AIM-PIbd-31-Makarov-DV\\.venv\\Lib\\site-packages\\sklearn\\base.py:1473: DataConversionWarning: A column-vector y was passed when a 1d array was expected. 
Please change the shape of y to (n_samples,), for example using ravel().\n", + " return fit_method(estimator, *args, **kwargs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Старые параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 200}\n", + "Лучший результат (MSE) на старых параметрах: 0.14998947697586934\n", + "\n", + "Новые параметры: {'max_depth': 30, 'min_samples_split': 2, 'n_estimators': 50}\n", + "Лучший результат (MSE) на новых параметрах: 0.18737177399159283\n", + "Среднеквадратическая ошибка (MSE) на тестовых данных: 0.13671335461532685\n", + "Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.3697476904800446\n" + ] + } + ], + "source": [ + "# Old data\n", + "\n", + "old_param_grid = param_grid\n", + "old_grid_search = grid_search\n", + "old_grid_search.fit(X_train, y_train)\n", + "\n", + "old_best_params = old_grid_search.best_params_\n", + "old_best_mse = -old_grid_search.best_score_ \n", + "\n", + "# New data\n", + "\n", + "new_param_grid = {\n", + " 'n_estimators': [50],\n", + " 'max_depth': [30],\n", + " 'min_samples_split': [2]\n", + " }\n", + "new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n", + " param_grid=new_param_grid,\n", + " scoring='neg_mean_squared_error', cv=2)\n", + "\n", + "new_grid_search.fit(X_train, y_train)\n", + "\n", + "new_best_params = new_grid_search.best_params_\n", + "new_best_mse = -new_grid_search.best_score_\n", + "\n", + "new_best_model = RandomForestRegressor(**new_best_params)\n", + "new_best_model.fit(X_train, y_train)\n", + "\n", + "old_best_model = RandomForestRegressor(**old_best_params)\n", + "old_best_model.fit(X_train, y_train)\n", + "\n", + "y_new_pred = new_best_model.predict(X_test)\n", + "y_old_pred = old_best_model.predict(X_test)\n", + "\n", + "mse = metrics.mean_squared_error(y_test, y_new_pred)\n", + "rmse = np.sqrt(mse)\n", + "\n", + "print(\"Старые параметры:\", old_best_params)\n", + "print(\"Лучший результат (MSE) на 
старых параметрах:\", old_best_mse)\n", + "print(\"\\nНовые параметры:\", new_best_params)\n", + "print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n", + "print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n", + "print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Визуализация данных" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(10, 6))\n", + "plt.plot(y_test.values, label='Истинные значения', color='blue', linewidth=2)\n", + "plt.plot(y_old_pred, label='Предсказанные значения (старые данные)', color='red', linestyle='--', linewidth=2)\n", + "plt.plot(y_new_pred, label='Предсказанные значения (новые данные)', color='green', linestyle='-', linewidth=2)\n", + "\n", + "plt.title('Сравнение предсказанных и истинных значений')\n", + "plt.xlabel('Подбор параметров')\n", + "plt.ylabel('Значения')\n", + "plt.grid()\n", + "plt.legend(loc ='lower right')\n", + "plt.show()" + ] } ], "metadata": {