AIM-PIbd-32-Isaeva-A-I/lab_4/Lab4.ipynb

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter

df = pd.read_csv("./csv/Student Depression Dataset.csv")
print(df.columns)
Index(['id', 'Gender', 'Age', 'City', 'Profession', 'Academic Pressure',
       'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction',
       'Sleep Duration', 'Dietary Habits', 'Degree',
       'Have you ever had suicidal thoughts ?', 'Work/Study Hours',
       'Financial Stress', 'Family History of Mental Illness', 'Depression'],
      dtype='object')
In [3]:
print(df.head())
   id  Gender   Age           City Profession  Academic Pressure  \
0   2    Male  33.0  Visakhapatnam    Student                5.0   
1   8  Female  24.0      Bangalore    Student                2.0   
2  26    Male  31.0       Srinagar    Student                3.0   
3  30  Female  28.0       Varanasi    Student                3.0   
4  32  Female  25.0         Jaipur    Student                4.0   

   Work Pressure  CGPA  Study Satisfaction  Job Satisfaction  \
0            0.0  8.97                 2.0               0.0   
1            0.0  5.90                 5.0               0.0   
2            0.0  7.03                 5.0               0.0   
3            0.0  5.59                 2.0               0.0   
4            0.0  8.13                 3.0               0.0   

      Sleep Duration Dietary Habits   Degree  \
0          5-6 hours        Healthy  B.Pharm   
1          5-6 hours       Moderate      BSc   
2  Less than 5 hours        Healthy       BA   
3          7-8 hours       Moderate      BCA   
4          5-6 hours       Moderate   M.Tech   

  Have you ever had suicidal thoughts ?  Work/Study Hours  Financial Stress  \
0                                   Yes               3.0               1.0   
1                                    No               3.0               2.0   
2                                    No               9.0               1.0   
3                                   Yes               4.0               5.0   
4                                   Yes               1.0               1.0   

  Family History of Mental Illness  Depression  
0                               No           1  
1                              Yes           0  
2                              Yes           0  
3                              Yes           1  
4                               No           0  

Business goal of the study

Develop and deploy a system for predicting depression levels among students, making it possible to identify at-risk groups at an early stage. The results may be useful to psychologists, teachers, and the administration of educational institutions.

Description of the dataset

The dataset describes students' psychological state and contains the following fields:

  • id: identifier, integer
  • Gender: gender, string
  • Age: age, float
  • City: city, string
  • Profession: profession, string
  • Academic Pressure: academic pressure, float (1.00 to 5.00)
  • Work Pressure: work pressure, float (1.00 to 5.00)
  • CGPA: grade point average, float
  • Study Satisfaction: satisfaction with studies, float (1.00 to 5.00)
  • Job Satisfaction: job satisfaction, float (1.00 to 5.00)
  • Sleep Duration: sleep duration, string
  • Dietary Habits: dietary habits, string
  • Degree: degree (education), string
  • Have you ever had suicidal thoughts ?: string (yes/no)
  • Work/Study Hours: work/study hours, float
  • Financial Stress: financial stress, float (1.00 to 5.00)
  • Family History of Mental Illness: family history of mental illness, string (yes/no)
  • Depression: depression, boolean (1/0)
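
A quick sanity check of the declared types and 1.00–5.00 ranges listed above; a minimal sketch using only the DataFrame already loaded:

In [ ]:
print(df.dtypes)
print(df[['Academic Pressure', 'Work Pressure', 'Study Satisfaction',
          'Job Satisfaction', 'Financial Stress']].describe().loc[['min', 'max']])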

Data processing

In [4]:
df.isnull().sum()
Out[4]:
id                                       0
Gender                                   0
Age                                      0
City                                     0
Profession                               0
Academic Pressure                        0
Work Pressure                            0
CGPA                                     0
Study Satisfaction                       0
Job Satisfaction                         0
Sleep Duration                           0
Dietary Habits                           0
Degree                                   0
Have you ever had suicidal thoughts ?    0
Work/Study Hours                         0
Financial Stress                         3
Family History of Mental Illness         0
Depression                               0
dtype: int64
In [5]:
df.dropna(subset=['Financial Stress'], inplace=True)  # drop the 3 rows with missing values
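
Since only three rows are affected, dropping them is cheap. A minimal alternative sketch (not applied here): impute the missing values with the column median instead, which keeps every row:

In [ ]:
# hypothetical alternative to dropna; a no-op if the rows were already dropped
df['Financial Stress'] = df['Financial Stress'].fillna(df['Financial Stress'].median())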
In [6]:
features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', 
                      'Job Satisfaction', 'Work/Study Hours', 'Financial Stress', 'Depression']

plt.figure(figsize=(15, 10))
for i, feature in enumerate(features, 1):
    plt.subplot(3, 3, i)
    sns.boxplot(y=df[feature], color='skyblue')
    plt.title(f'Boxplot of {feature}')
    plt.ylabel(feature)

plt.tight_layout()
plt.show()
[Figure: boxplots of Age, Academic Pressure, Work Pressure, CGPA, Study Satisfaction, Job Satisfaction, Work/Study Hours, Financial Stress, and Depression]

The Age column contains many outliers; we replace them with the median value.

In [7]:
Q1 = df['Age'].quantile(0.25)
Q3 = df['Age'].quantile(0.75)
IQR = Q3 - Q1

# Values outside the 1.5*IQR fences are treated as outliers
threshold = 1.5 * IQR
outliers = (df['Age'] < (Q1 - threshold)) | (df['Age'] > (Q3 + threshold))

# Replace the outliers with the median age
median_age = df['Age'].median()
df.loc[outliers, 'Age'] = median_age

plt.figure(figsize=(8, 6))
sns.boxplot(y=df['Age'], color='skyblue')
plt.title('Boxplot of Age')
plt.ylabel('Age')
plt.show()
[Figure: boxplot of Age after outlier replacement]
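
The same recipe generalizes to any numeric column. A minimal helper sketch (the function name is ours, not from the original notebook):

In [ ]:
def replace_iqr_outliers_with_median(frame, column, k=1.5):
    """Replace values outside the k*IQR fences with the column median."""
    q1, q3 = frame[column].quantile([0.25, 0.75])
    iqr = q3 - q1
    mask = (frame[column] < q1 - k * iqr) | (frame[column] > q3 + k * iqr)
    frame.loc[mask, column] = frame[column].median()

replace_iqr_outliers_with_median(df, 'Age')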

Feature construction with label encoding

In [8]:
from sklearn.preprocessing import LabelEncoder

# Encode each categorical column with its own integer labels
categorical_cols = ['Gender', 'City', 'Dietary Habits', 'Degree',
                    'Have you ever had suicidal thoughts ?', 'Sleep Duration',
                    'Profession', 'Study Satisfaction',
                    'Family History of Mental Illness']
for col in categorical_cols:
    df[col] = LabelEncoder().fit_transform(df[col])
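
Note that LabelEncoder imposes an arbitrary numeric order on nominal categories such as City, which tree-based models tolerate but linear models can misread. A minimal alternative sketch using one-hot encoding on a freshly read, un-encoded copy of the data (the raw and df_onehot names are ours):

In [ ]:
raw = pd.read_csv("./csv/Student Depression Dataset.csv")
df_onehot = pd.get_dummies(
    raw,
    columns=['Gender', 'City', 'Dietary Habits', 'Sleep Duration'],
    drop_first=True,  # drop one dummy per column to avoid redundancy
)
print(df_onehot.shape)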

Splitting into features and the target variable

In [9]:
x = df.drop('Depression', axis=1)
y = df['Depression']
In [10]:
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
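
Since Depression is a binary label, a stratified split keeps the class ratio identical in the train and test sets. A minimal sketch of this alternative (the _s suffix is ours):

In [ ]:
x_train_s, x_test_s, y_train_s, y_test_s = train_test_split(
    x, y, test_size=0.2, random_state=42, stratify=y
)
# class shares should now match between the two sets
print(round(y_train_s.mean(), 3), round(y_test_s.mean(), 3))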

1) Lasso regression

In [15]:
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV

param_grid_lasso = {
    'alpha': [0.01, 0.1, 1.0, 10.0],
    'fit_intercept': [True, False],
}

# Create the GridSearchCV object
grid_search_lasso = GridSearchCV(
    estimator=Lasso(),
    param_grid=param_grid_lasso,
    cv=5,
    scoring='neg_mean_squared_error',
    n_jobs=-1
)

grid_search_lasso.fit(x_train, y_train)

# Print the best hyperparameters
print("Best hyperparameters for Lasso:")
print(grid_search_lasso.best_params_)
Best hyperparameters for Lasso:
{'alpha': 0.01, 'fit_intercept': False}
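
One attraction of Lasso is that the L1 penalty can zero out uninformative coefficients. A minimal sketch for inspecting them (the best_lasso and coef names are ours; the fitted grid search above is assumed):

In [ ]:
best_lasso = grid_search_lasso.best_estimator_
coef = pd.Series(best_lasso.coef_, index=x_train.columns)
print("Dropped by Lasso:", coef[coef == 0].index.tolist())
print(coef[coef != 0].sort_values(key=abs, ascending=False))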

2) Gradient boosting

In [14]:
from sklearn.ensemble import GradientBoostingRegressor

param_grid_gb = {
    'n_estimators': [50, 100, 200],
    'learning_rate': [0.01, 0.1, 0.2],
    'max_depth': [3, 5, 7],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    # 'auto' is no longer a valid max_features value in recent scikit-learn
    # (it caused 1,215 failed fits in the original run), so it is omitted
    'max_features': ['sqrt', 'log2']
}

grid_search_gb = GridSearchCV(
    estimator=GradientBoostingRegressor(),
    param_grid=param_grid_gb,
    cv=5,
    scoring='neg_mean_squared_error',
    n_jobs=-1
)

grid_search_gb.fit(x_train, y_train)

# Print the best hyperparameters
print("Best hyperparameters for Gradient Boosting:")
print(grid_search_gb.best_params_)
Best hyperparameters for Gradient Boosting:
{'learning_rate': 0.1, 'max_depth': 5, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 10, 'n_estimators': 100}
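
The grid above spans 486 parameter combinations (2,430 fits with cv=5). A minimal sketch of a cheaper alternative that samples the same grid randomly (the random_search_gb name is ours):

In [ ]:
from sklearn.model_selection import RandomizedSearchCV

random_search_gb = RandomizedSearchCV(
    estimator=GradientBoostingRegressor(),
    param_distributions=param_grid_gb,
    n_iter=50,  # 50 sampled combinations instead of all 486
    cv=5,
    scoring='neg_mean_squared_error',
    n_jobs=-1,
    random_state=42,
)
random_search_gb.fit(x_train, y_train)
print(random_search_gb.best_params_)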

3) k-nearest neighbors

In [16]:
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import GridSearchCV

param_grid_knn = {
    'n_neighbors': [3, 5, 7, 10],
    'weights': ['uniform', 'distance'],
    'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],
    'p': [1, 2]
}

grid_search_knn = GridSearchCV(
    estimator=KNeighborsRegressor(),
    param_grid=param_grid_knn,
    cv=5,
    scoring='neg_mean_squared_error',
    n_jobs=-1
)

grid_search_knn.fit(x_train, y_train)

# Print the best hyperparameters
print("Best hyperparameters for k-Nearest Neighbors:")
print(grid_search_knn.best_params_)
Best hyperparameters for k-Nearest Neighbors:
{'algorithm': 'ball_tree', 'n_neighbors': 10, 'p': 1, 'weights': 'distance'}
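
k-NN is distance-based, so features on larger scales (Age, Work/Study Hours) dominate the metric, which likely explains its weak scores below. A minimal sketch that standardizes the features inside a Pipeline before the same grid search (the knn_pipe and *_scaled names are ours; pipeline parameter names gain the knn__ prefix):

In [ ]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

knn_pipe = Pipeline([
    ('scaler', StandardScaler()),
    ('knn', KNeighborsRegressor()),
])
param_grid_knn_scaled = {f'knn__{k}': v for k, v in param_grid_knn.items()}
grid_search_knn_scaled = GridSearchCV(
    knn_pipe, param_grid_knn_scaled, cv=5,
    scoring='neg_mean_squared_error', n_jobs=-1,
)
grid_search_knn_scaled.fit(x_train, y_train)
print(grid_search_knn_scaled.best_params_)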

Prediction on the test set

In [128]:
# `model` (a baseline k-NN) and `model_forest` (a random forest) were fitted
# in earlier cells not shown here; the tuned models come from the grid searches.
model_lasso = grid_search_lasso.best_estimator_
model_gb = grid_search_gb.best_estimator_
model_knn = grid_search_knn.best_estimator_
y_pred = model.predict(x_test)
y_pred_forest = model_forest.predict(x_test)
y_pred_lasso = model_lasso.predict(x_test)
y_pred_gb = model_gb.predict(x_test)
y_pred_neighbors = model_knn.predict(x_test)

Model quality evaluation

  1. MSE (Mean Squared Error): the mean of the squared differences between predicted and actual values. Lower is better.
In [156]:
from sklearn.metrics import mean_squared_error
import numpy as np

mse1 = mean_squared_error(y_test, y_pred)
mse2 = mean_squared_error(y_test, y_pred_forest)
mse3 = mean_squared_error(y_test, y_pred_lasso)
mse4 = mean_squared_error(y_test, y_pred_gb)
mse5 = mean_squared_error(y_test, y_pred_neighbors)

mse1_rounded = round(mse1, 3)
mse2_rounded = round(mse2, 3)
mse3_rounded = round(mse3, 3)
mse4_rounded = round(mse4, 3)
mse5_rounded = round(mse5, 3)

print("Mean Squared Error (MSE):")
print(f"k-NN: \t\t\t{mse1_rounded}")
print(f"Random Forest: \t\t{mse2_rounded}")
print(f"Lasso: \t\t\t{mse3_rounded}")
print(f"Gradient Boosting: \t{mse4_rounded}")
print(f"k-Nearest Neighbors: \t{mse5_rounded}")
Mean Squared Error (MSE):
k-NN: 			0.213
Random Forest: 		0.118
Lasso: 			0.166
Gradient Boosting: 	0.113
k-Nearest Neighbors: 	0.326
  2. MAE (Mean Absolute Error): the mean of the absolute differences between predicted and actual values. Lower is better.
In [155]:
from sklearn.metrics import mean_absolute_error

mae1 = round(mean_absolute_error(y_test, y_pred),3)
mae2 = round(mean_absolute_error(y_test, y_pred_forest),3)
mae3 = round(mean_absolute_error(y_test, y_pred_lasso),3)
mae4 = round(mean_absolute_error(y_test, y_pred_gb),3)
mae5 = round(mean_absolute_error(y_test, y_pred_neighbors),3)
print("Mean Absolute Error (MAE):")
print(f"k-NN: \t\t\t{mae1}")
print(f"Random Forest: \t\t{mae2}")
print(f"Lasso: \t\t\t{mae3}")
print(f"Gradient Boosting: \t{mae4}")
print(f"k-Nearest Neighbors: \t{mae5}")
Mean Absolute Error (MAE):
k-NN: 			0.213
Random Forest: 		0.238
Lasso: 			0.366
Gradient Boosting: 	0.246
k-Nearest Neighbors: 	0.485
  3. R² (R-squared): shows how well the model explains the variability of the data. A value of 1 is a perfect fit, 0 means the model explains none of the variance, and negative values (as k-NN shows below) mean the model is worse than simply predicting the mean.
In [153]:
from sklearn.metrics import r2_score
r2_1 = r2_score(y_test, y_pred)
r2_2 = r2_score(y_test, y_pred_forest)
r2_3 = r2_score(y_test, y_pred_lasso)
r2_4 = r2_score(y_test, y_pred_gb)
r2_5 = r2_score(y_test, y_pred_neighbors)

r2_1_rounded = round(r2_1, 3)
r2_2_rounded = round(r2_2, 3)
r2_3_rounded = round(r2_3, 3)
r2_4_rounded = round(r2_4, 3)
r2_5_rounded = round(r2_5, 3)

print("\nR² (R-squared):")
print(f"k-NN: \t\t\t{r2_1_rounded}")
print(f"Random Forest: \t\t{r2_2_rounded}")
print(f"Lasso: \t\t\t{r2_3_rounded}")
print(f"Gradient Boosting: \t{r2_4_rounded}")
print(f"k-Nearest Neighbors: \t{r2_5_rounded}")

R² (R-squared):
k-NN: 			0.128
Random Forest: 		0.515
Lasso: 			0.319
Gradient Boosting: 	0.537
k-Nearest Neighbors: 	-0.337
  4. RMSE (Root Mean Squared Error): the square root of MSE, i.e., the typical deviation of predictions from the actual values in the target's own units. Lower is better.
In [151]:
rmse1 = np.sqrt(mse1)
rmse2 = np.sqrt(mse2)
rmse3 = np.sqrt(mse3)
rmse4 = np.sqrt(mse4)
rmse5 = np.sqrt(mse5)

rmse1_rounded = round(rmse1, 3)
rmse2_rounded = round(rmse2, 3)
rmse3_rounded = round(rmse3, 3)
rmse4_rounded = round(rmse4, 3)
rmse5_rounded = round(rmse5, 3)

print("Root Mean Squared Error (RMSE):")
print(f"k-NN: \t\t\t{rmse1_rounded}")
print(f"Random Forest: \t\t{rmse2_rounded}")
print(f"Lasso: \t\t\t{rmse3_rounded}")
print(f"Gradient Boosting: \t{rmse4_rounded}")
print(f"k-Nearest Neighbors: \t{rmse5_rounded}")
Root Mean Squared Error (RMSE):
k-NN: 			0.461
Random Forest: 		0.344
Lasso: 			0.407
Gradient Boosting: 	0.336
k-Nearest Neighbors: 	0.571
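
A minimal sketch that gathers the four metrics computed above into a single comparison table (the summary name is ours):

In [ ]:
summary = pd.DataFrame(
    {'MSE': [mse1, mse2, mse3, mse4, mse5],
     'MAE': [mae1, mae2, mae3, mae4, mae5],
     'R2': [r2_1, r2_2, r2_3, r2_4, r2_5],
     'RMSE': [rmse1, rmse2, rmse3, rmse4, rmse5]},
    index=['k-NN', 'Random Forest', 'Lasso', 'Gradient Boosting',
           'k-Nearest Neighbors'],
)
print(summary.round(3))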

Gradient boosting and random forest give the best results: gradient boosting has the lowest MSE and RMSE and the highest R², while random forest has the lowest MAE and is a close second on the other three metrics. Both clearly outperform Lasso and k-NN.

Random forest is therefore chosen as an accurate and stable training strategy; the final model is model_forest.

Using feature importances on the random forest, we also identified the main factors associated with depression:

In [19]:
from sklearn.ensemble import RandomForestRegressor

model_rf = RandomForestRegressor(n_estimators=100, random_state=42)
model_rf.fit(x_train, y_train)

feature_importances = model_rf.feature_importances_

feature_importance_df = pd.DataFrame({
    'Feature': x.columns,
    'Importance': feature_importances
}).sort_values(by='Importance', ascending=False)

print(feature_importance_df)
                                  Feature  Importance
13  Have you ever had suicidal thoughts ?    0.300542
5                       Academic Pressure    0.134276
0                                      id    0.087970
7                                    CGPA    0.079078
2                                     Age    0.066613
15                       Financial Stress    0.066330
3                                    City    0.059293
14                       Work/Study Hours    0.052275
12                                 Degree    0.049539
8                      Study Satisfaction    0.032944
11                         Dietary Habits    0.026140
10                         Sleep Duration    0.024435
16       Family History of Mental Illness    0.010547
1                                  Gender    0.009627
4                              Profession    0.000372
9                        Job Satisfaction    0.000017
6                           Work Pressure    0.000003
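
A minimal sketch that shows the same importances as a bar chart. Note that the high rank of id (an arbitrary identifier) suggests it should be dropped from the features before drawing substantive conclusions:

In [ ]:
plt.figure(figsize=(8, 6))
sns.barplot(data=feature_importance_df, x='Importance', y='Feature',
            color='skyblue')
plt.title('Random Forest Feature Importances')
plt.tight_layout()
plt.show()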