AIM-PIbd-32-Isaeva-A-I/lab_4/Lab4.ipynb

796 lines
150 KiB
Plaintext
Raw Normal View History

2024-12-21 00:25:04 +04:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'Gender', 'Age', 'City', 'Profession', 'Academic Pressure',\n",
" 'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction',\n",
" 'Sleep Duration', 'Dietary Habits', 'Degree',\n",
" 'Have you ever had suicidal thoughts ?', 'Work/Study Hours',\n",
" 'Financial Stress', 'Family History of Mental Illness', 'Depression'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from matplotlib.ticker import FuncFormatter\n",
2024-12-21 12:33:06 +04:00
"from sklearn.pipeline import Pipeline\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
"from sklearn.model_selection import GridSearchCV\n",
"from sklearn.linear_model import Lasso\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
2024-12-21 00:25:04 +04:00
"\n",
"df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n",
"print(df.columns)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id Gender Age City Profession Academic Pressure \\\n",
"0 2 Male 33.0 Visakhapatnam Student 5.0 \n",
"1 8 Female 24.0 Bangalore Student 2.0 \n",
"2 26 Male 31.0 Srinagar Student 3.0 \n",
"3 30 Female 28.0 Varanasi Student 3.0 \n",
"4 32 Female 25.0 Jaipur Student 4.0 \n",
"\n",
" Work Pressure CGPA Study Satisfaction Job Satisfaction \\\n",
"0 0.0 8.97 2.0 0.0 \n",
"1 0.0 5.90 5.0 0.0 \n",
"2 0.0 7.03 5.0 0.0 \n",
"3 0.0 5.59 2.0 0.0 \n",
"4 0.0 8.13 3.0 0.0 \n",
"\n",
" Sleep Duration Dietary Habits Degree \\\n",
"0 5-6 hours Healthy B.Pharm \n",
"1 5-6 hours Moderate BSc \n",
"2 Less than 5 hours Healthy BA \n",
"3 7-8 hours Moderate BCA \n",
"4 5-6 hours Moderate M.Tech \n",
"\n",
" Have you ever had suicidal thoughts ? Work/Study Hours Financial Stress \\\n",
"0 Yes 3.0 1.0 \n",
"1 No 3.0 2.0 \n",
"2 No 9.0 1.0 \n",
"3 Yes 4.0 5.0 \n",
"4 Yes 1.0 1.0 \n",
"\n",
" Family History of Mental Illness Depression \n",
"0 No 1 \n",
"1 Yes 0 \n",
"2 Yes 0 \n",
"3 Yes 1 \n",
"4 No 0 \n"
]
}
],
"source": [
"print(df.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес-цель исследования\n",
"Разработать и внедрить систему прогнозирования уровня депрессии среди обучающихся, которая позволит выявить группы риска на ранних этапах. Результаты исследования могут быть полезны психологам, педагогам и администрации учебных заведений.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Описание набора данных для анализа\n",
"Набор данных содержит информацию о психологическом состоянии обучающихся и включает следующие поля:\n",
"- id идентификатор, число\n",
"- Gender пол, строка\n",
"- Age возраст, дробное число\n",
"- City город, строка\n",
"- Profession профессия, строка\n",
"- Academic Pressure академическое давление, дробное число (от 1.00 до 5.00)\n",
"- Work Pressure рабочее давление, дробное число (от 1.00 до 5.00)\n",
"- CGPA средний балл (GPA), дробное число\n",
"- Study Satisfaction удовлетворенность учебой, дробное число (от 1.00 до 5.00)\n",
"- Job Satisfaction удовлетворенность работой, дробное число (от 1.00 до 5.00)\n",
"- Sleep Duration продолжительность сна, строка\n",
"- Dietary Habits пищевые привычки, строка\n",
"- Degree степень (образование), строка\n",
"- Have you ever had suicidal thoughts? Были ли у вас когда-либо суицидальные мысли? строка (yes/no)\n",
"- Work/Study Hours часы работы/учебы, дробное число\n",
"- Financial Stress финансовый стресс, дробное число (от 1.00 до 5.00)\n",
"- Family History of Mental Illness семейный анамнез психических заболеваний, строка (yes/no)\n",
"- Depression депрессия, булевое значение (1/0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Обработка данных"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"id 0\n",
"Gender 0\n",
"Age 0\n",
"City 0\n",
"Profession 0\n",
"Academic Pressure 0\n",
"Work Pressure 0\n",
"CGPA 0\n",
"Study Satisfaction 0\n",
"Job Satisfaction 0\n",
"Sleep Duration 0\n",
"Dietary Habits 0\n",
"Degree 0\n",
"Have you ever had suicidal thoughts ? 0\n",
"Work/Study Hours 0\n",
"Financial Stress 3\n",
"Family History of Mental Illness 0\n",
"Depression 0\n",
"dtype: int64"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.isnull().sum()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df.dropna(subset=['Financial Stress'], inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABdEAAAPdCAYAAABlRyFLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVhU5f//8RejLAoCgoKoKIak4YKmue9hRqaYpmWLotjqkkt9yj7llrmVqeVSGYmlZakplrnghrnmknxyyZSsLBEXFMQFlDm/P/wx3yZAQcEj+HxcF5fOfc6cec2Avg/vuec+DoZhGAIAAAAAAAAAANlYzA4AAAAAAAAAAMDtiiY6AAAAAAAAAAC5oIkOAAAAAAAAAEAuaKIDAAAAAAAAAJALmugAAAAAAAAAAOSCJjoAAAAAAAAAALmgiQ4AAAAAAAAAQC5oogMAAAAAAAAAkAua6AAAAAAAAAAA5IImOnAHcnBw0KhRo8yOYWfHjh1q1qyZXF1d5eDgoD179pgdCQBwm6KOFSwzX8+AgABFRESY8tgAgKKN84GbExAQoIcfftjsGECRQRMdKEDR0dFycHCw+/Lx8VHbtm21YsUKs+PdtP3792vUqFH6/fffC/S4ly9fVvfu3ZWcnKwpU6bo888/V9WqVa97v++//14ODg6qWLGirFZrgWYCgDsRdezGUMduDwEBAdl+dlu2bKklS5aYHQ0AihTOB25MXs8HfvzxRzk4OGjKlCnZtoWHh8vBwUFz5szJtq1Vq1aqVKlSgWa+UW3atLH7+fDy8tJ9992nTz/9lHMaFFslzQ4AFEdjxoxRtWrVZBiGkpKSFB0drYceekjffvttkX6nd//+/Ro9erTatGmjgICAAjtuQkKC/vjjD82ePVv9+vXL8/3mz5+vgIAA/f7771q3bp1CQ0MLLBMA3MmoY/lDHbtxBw8elMVScPN66tWrp2HDhkmSjh07po8++khdu3bVrFmz9PzzzxfY4wDAnYDzgfzJ6/nAvffeq9KlS2vTpk0aMmSI3bYtW7aoZMmS2rx5s/r06WMbz8jI0I4dO9SpU6cCy3uzKleurPHjx0uSTp48qc8++0yRkZH69ddfNWHCBJPTAQWPJjpQCMLCwtSwYUPb7cjISPn6+urLL78s0icbheXEiROSJE9Pzzzf5/z584qJidH48eM1Z84czZ8//45sPgBAYaCO5Q917MY5OzsX6PEqVaqkp556yna7V69eql69uqZMmZJrE/3KlSuyWq1ycnIq0CyF6fz583J1dTU7BoBijvOB/Mnr+UDJkiXVuHFjbd682W784MGDOnXqlJ544glt2rTJbtuuXbt06dIltWjR4qZzXrhwQaVLl77p43h4eNjV3Oeee041atTQ9OnT9dZbb8nR0THbfaxWqzIyMuTi4nLTj3+rXLp0SU5OTgX6pj+KJn4CgFvA09NTpUqVUsmS9u9bnT9/XsOGDZO/v7+cnZ1Vo0YNvfvuuzIMQ5J08eJF1axZUzVr1tTFixdt90tOTpafn5+aNWumzMxMSVJERITc3Nz022+/qUOHDnJ1dVXFihU1ZswY2/Gu5aefflJYWJjc3d3l5uam+++/X9u2bbNtj46OVvfu3SVJbdu2tX1sa8OGDdc87rp169SyZUu5urrK09NT4eHhOnDggG17RESEWrduLUnq3r27HBwc1KZNm+vmXbJkiS5evKju3bvr8ccf1zfffKNLly5l2+/ixYsaNGiQypUrpzJlyqhz5876+++/c1w/7++//1bfvn3l6+srZ2dn1apVS59++ul1swBAcUcdM6+OXbp0SaNGjdLdd98tFxcX+fn5qWvXrkpISLDt8+6776pZs2by9vZWqVKl1KBBAy1atCjbsdLT0zVkyBCVL1/eVhP/+uuvHPPlpSZu2LBBDg4O+vrrrzV69GhVqlRJZcqU0aOPPqqUlBSlp6dr8ODB8vHxkZubm/r06aP09HS7Y+S0JvrZs2c1ZMgQBQQEyNnZWZUrV1avXr106tSp676u/1ahQgXdc889OnLkiCTp999/l4ODg959911NnTpVgYGBcnZ21v79+yVJv/zyix599FF5eXnJxcVFDRs21LJly+yOefnyZY0ePVpBQUFycXGRt7e3WrRoodjYWNs+x48fV58+fVS5cmU5OzvLz89P4eHhdksH5LaW779fk6xlFeLi4vTiiy/Kx8dHlStXtm1fsWKF7We0TJky6tixo/bt25fv1woArofzgYI7H2jRooWSkpJ0+PBh29jmzZvl7u6uZ5991tZQ/+e2rPtlmTlzpmrVqiVnZ2dVrFhR/fv319mzZ+0ep02bNqpdu7Z27dqlVq1aqXTp0nr99ddzzTV37lyVLFlSr7zyyjVfj5yULl1aTZo00fnz53Xy5ElJV2vdgAEDNH/+fFvWlStXSsr7798ffPCBatWqpdKlS6ts2bJq2LChvvjiC9v2c+fOafDgwbbzBh8fH7Vv3167d++27ZPbNVjatGlj933KOrdZsGCB3njjDVWqVEmlS5dWamqqJGn79u168MEH5eHhodKlS6t169bZ3gxB8cVMdKAQpKSk6NSpUzIMQydOnNAHH3ygtLQ0u3dpDcNQ586dtX79ekVGRqpevXpatWqVXnnlFf3999+aMmWKSpUqpblz56p58+b673//q/fee0+S1L9/f6WkpCg6OlolSpSwHTMzM1MPPvigmjRpokmTJmnlypUaOXKkrly5ojFjxuSad9++fWrZsqXc3d31n//8R46Ojvroo4/Upk0bxcXFqXHjxmrVqpUGDRqk999/X6+//rruueceSbL9mZM1a9YoLCxMd911l0aNGqWLFy/qgw8+UPPmzbV7924FBAToueeeU6VKlTRu3DgNGjRI9913n3x9fa/7Gs+fP19t27ZVhQoV9Pjjj+u1117Tt99+azshyhIREaGvv/5aTz/9tJo0aaK4uDh17Ngx2/GSkpLUpEkTW5EvX768VqxYocjISKWmpmrw4MHXzQQAxQV17Cqz61hmZqYefvhhrV27Vo8//rheeuklnTt3TrGxsdq7d68CAwMlSdOmTVPnzp315JNPKiMjQwsWLFD37t313Xff2dW8fv36ad68eXriiSfUrFkzrVu3rkBq4vjx41WqVCm99tprOnz4sD744AM5OjrKYrHozJkzGjVqlLZt26bo6GhVq1ZNI0aMyPV1SUtLU8uWLXXgwAH17dtX9957r06dOqVly5bpr7/+Urly5a772v7T5cuXdfToUXl7e9uNz5kzR5cuXdKzzz4rZ2dneXl5ad++fWrevLkqVaqk1157Ta6urvr666/VpUsXLV68WI888ogkadSoURo/frz69eunRo0aKTU1VTt37tTu3bvVvn17SVK3bt20b98+DRw4UAEBATpx4oRiY2P1559/3vDSAS+++KLKly+vESNG6Pz585Kkzz//XL1791aHDh00ceJEXbhwQbNmzVKLFi30008/FegyBQDuPJwPXFUY5wNZzfBNmzapevXqkq42yps0aaLGjRvL0dFRW7ZsUefOnW3bypQpo5CQEElXa9Ho0aMVGhqqF154QQcPHtSsWbO0Y8cObd682W4W+OnTpxUWFqbHH39cTz31VK65Pv74Yz3//PN6/fXXNXbs2FyzX8tvv/2mEiVK2M3GX7dunb7++msNGDBA5cqVU0BAQJ7PNWbPnq1Bgwbp0Ucf1UsvvaRLly7pf//7n7Zv364nnnhCkvT8889r0aJFGjBggIKDg3X69Glt2rRJBw4c0L333ntDz+Ott96Sk5OTXn75ZaWnp8vJyUnr1q1TWFiYGjRooJEjR8pisWjOnDlq166dfvjhBzVq1OiGHgtFiAGgwMyZM8eQlO3L2dnZiI6Ottt36dKlhiRj7NixduOPPvqo4eDgYBw+fNg2Nnz4cMNisRgbN240Fi5caEgypk6dane/3r17G5KMgQMH2sasVqvRsWNHw8nJyTh58qRtXJIxcuRI2+0uXboYTk5ORkJCgm3s2LFjRpkyZYxWrVrZxrIee/369Xl6PerVq2f4+PgYp0+fto3Fx8cbFovF6NW
"text/plain": [
"<Figure size 1500x1000 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', \n",
" 'Job Satisfaction', 'Work/Study Hours', 'Financial Stress', 'Depression']\n",
"\n",
"plt.figure(figsize=(15, 10))\n",
"for i, feature in enumerate(features, 1):\n",
" plt.subplot(3, 3, i)\n",
" sns.boxplot(y=df[feature], color='skyblue')\n",
" plt.title(f'Boxplot of {feature}')\n",
" plt.ylabel(feature)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В Age много выбросов. Сбалансируем данные"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq4AAAH9CAYAAADbDf7CAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmsklEQVR4nO3dfZTWdZ3/8ddwM4MKDKIywDqYqes9aWzJaJEISWjenNBy0xLD7AZ1FbZ1UXe92QqyDKzILVfFo5AdS3OtUMEU16O0huFNbqxQJid0RNEZQBkQrt8fu87PWUEhgYsPPh7nXOd0fa7v9bneA53j83z5Xt+pqVQqlQAAwDauU7UHAACAjSFcAQAognAFAKAIwhUAgCIIVwAAiiBcAQAognAFAKAIwhUAgCIIVwAAiiBcAbaimpqaXHrppdUeo4OHH344hx9+eHbaaafU1NRk/vz51R4JYL2EK7BdmDZtWmpqajo8+vTpk6FDh2bmzJnVHu8de/LJJ3PppZfm6aef3qz7rlmzJieffHKWLVuWyZMn58Ybb8wee+zxtu/75S9/mZqamvTv3z/r1q3brDMBbEiXag8AsDldfvnl2XPPPVOpVNLc3Jxp06blmGOOyR133JGPf/zj1R7vL/bkk0/msssuy5FHHpn3vOc9m23fRYsW5U9/+lOuueaanHnmmRv9vunTp+c973lPnn766fzqV7/K8OHDN9tMABvijCuwXRk5cmROO+20fOYzn8nf//3f5z/+4z/StWvX/OhHP6r2aNuk559/PknSq1evjX7PypUrc/vtt2fcuHE59NBDM3369C00HUBHwhXYrvXq1Ss77LBDunTp+A9MK1euzPjx49PY2Ji6urrsu++++da3vpVKpZIkefXVV7Pffvtlv/32y6uvvtr+vmXLlqVfv345/PDDs3bt2iTJ6NGj07179/zhD3/IiBEjstNOO6V///65/PLL2/d7K7/97W8zcuTI9OzZM927d8+wYcMyd+7c9tenTZuWk08+OUkydOjQ9ksh7rvvvrfc91e/+lU+/OEPZ6eddkqvXr1ywgkn5L/+67/aXx89enQ+8pGPJElOPvnk1NTU5Mgjj3zbeW+77ba8+uqrOfnkk3PKKafk1ltvzapVq9503Kuvvppzzz03u+66a3r06JHjjz8+f/7zn9d7ne+f//znfO5zn0tDQ0Pq6upy4IEH5rrrrnvbWYB3F+EKbFdaWlrywgsvZOnSpfnd736XL33pS1mxYkVOO+209mMqlUqOP/74TJ48OR/72Mfy7W9/O/vuu2++8pWvZNy4cUmSHXbYITfccEMWLlyYiy66qP29Y8eOTUtLS6ZNm5bOnTu3r69duzYf+9jH0tDQkCuuuCKDBg3KJZdckksuueQt5/3d736XD3/4w3n00UfzD//wD/mnf/qn/PGPf8yRRx6ZX//610mSIUOG5Nxzz02SXHjhhbnxxhtz4403Zv/999/gvrNnz86IESPy/PPP59JLL824cePy4IMP5ogjjmi/TvYLX/hCLrzwwiTJueeemxtvvLHDz7oh06dPz9ChQ9O3b9+ccsopWb58ee644443HTd69Oh897vfzTHHHJNvfOMb2WGHHXLssce+6bjm5uYMHjw4s2fPztlnn52rrroqe++9d8aMGZMpU6a87TzAu0gFYDtw/fXXV5K86VFXV1eZNm1ah2N/9rOfVZJUvvrVr3ZYP+mkkyo1NTWVhQsXtq9NmDCh0qlTp8r9999fueWWWypJKlOmTOnwvtNPP72SpHLOOee0r61bt65y7LHHVmpraytLly5tX09SueSSS9qfn3jiiZXa2trKokWL2teWLFlS6dGjR2XIkCHta69/9r333rtRfx6HHHJIpU+fPpUXX3yxfe3RRx+tdOrUqfLZz362fe3ee++tJKnccsstG7Vvc3NzpUuXLpVrrrmmfe3www+vnHDCCR2OmzdvXiVJ5bzzzuuwPnr06Df9GYwZM6bSr1+/ygsvvNDh2FNOOaVSX19feeWVVzZqNmD754wrsF2ZOnVqZs2alVmzZuWmm27K0KFDc+aZZ+bWW29tP+aXv/xlOnfu3H4W83Xjx49PpVLpcBeCSy+9NAceeGBOP/30fPnLX85HPvKRN73vdWeffXb7/66pqcnZZ5+d1atXZ/bs2es9fu3atbn77rtz4okn5r3vfW/7er9+/fLpT386DzzwQFpbWzf5z+DZZ5/N/PnzM3r06PTu3bt9feDAgfnoRz+aX/7yl5u85+tuvvnmdOrUKaNGjWpf+9u//dvMnDkzL730UvvanXfemST58pe/3OH955xzTofnlUolP/3pT3PcccelUqnkhRdeaH+MGDEiLS0teeSRR/7ieYHti3AFtisf/OAHM3z48AwfPjynnnpqfvGLX+SAAw5oj8gk+dOf/pT+/funR48eHd77+j+9/+lPf2pfq62tzXXXXZc//vGPWb58ea6//vrU1NS86XM7derUIT6T5K//+q+TZIO3sFq6dGleeeWV7Lvvvm96bf/998+6deuyePHijf/h/9fr829o3xdeeCErV67c5H2T5KabbsoHP/jBvPjii1m4cGEWLlyYQw89NKtXr84tt9zSYYZOnTplzz337PD+vffeu8PzpUuX5uWXX84Pf/jD7Lbbbh0eZ5xxRpL//wUyALfDArZrnTp1ytChQ3PVVVflqaeeyoEHHrjJe9x1111JklWrVuWpp556U4y9Wzz11FN5+OGHkyT77LPPm16fPn16zjrrrE3a8/V7wJ522mk5/fTT13vMwIEDN3FSYHslXIHt3muvvZYkWbFiRZJkjz32yOzZs7N8+fIOZ11///vft7/+usceeyyXX355zjjjjMyfPz9nnnlmHn/88dTX13f4jHXr1uUPf/hD+1nWJPnv//7vJNngfVd322237LjjjlmwYMGbXvv973+fTp06pbGxMUnWe5Z3Q16ff0P77rrrrtlpp502er/XTZ8+PV27ds2NN97Y4YtpSfLAAw/kO9/5Tp555pkMGDAge+yxR9atW5c//vGPHSJ34cKFHd632267pUePHlm7dq17wQJvy6UCwHZtzZo1ufvuu1NbW9t+KcAxxxyTtWvX5nvf+16HYydPnpyampqMHDmy/b2jR49O//79c9VVV2XatGlpbm7O+eefv97PeuN+lUol3/ve99K1a9cMGzZsvcd37tw5Rx99dG6//fYOlxM0NzdnxowZ+dCHPpSePXsmSXtovvzyy2/7M/fr1y+HHHJIbrjhhg7HP/HEE7n77rtzzDHHvO0e6zN9+vR8+MMfzqc+9amcdNJJHR5f+cpXkqT9frkjRoxIknz/+9/vsMd3v/vdDs87d+6cUaNG5ac//WmeeOKJN33m0qVL/6JZge2TM67AdmXmzJntZ06ff/75zJgxI0899VT+8R//sT0CjzvuuAwdOjQXXXRRnn766bzvfe/L3Xffndtvvz3nnXde9tprryTJV7/61cyfPz/33HNPevTokYEDB+af//mfc/HFF+ekk07qEIDdunXLnXfemdNPPz2HHXZYZs6cmV/84he58MILs9tuu21w3q9+9auZNWtWPvShD+XLX/5yunTpkh/84Adpa2vLFVdc0X7cIYccks6dO+cb3/hGWlpaUldXl6OOOip9+vRZ777f/OY3M3LkyDQ1NWXMmDF59dVX893vfjf19fVvuofqxvj1r3+dhQsXdvgC2hv91V/9Vd7//vdn+vTpueCCCzJo0KCMGjUqU6ZMyYsvvpjBgwdnzpw57Weh33gGedKkSbn33ntz2GGH5fOf/3wOOOCALFu2LI888khmz56dZcuWbfK8wHaqujc1ANg81nc7rG7dulUOOeSQytVXX11Zt25dh+OXL19eOf/88yv9+/evdO3atbLPPvtUvvnNb7YfN2/evEqXLl063OKqUqlUXnvttcoHPvCBSv/+/SsvvfRSpVL5n9th7bTTTpVFixZVjj766MqOO+5YaWhoqFxyySWVtWvXdnh//s+toCqVSuWRRx6pjBgxotK9e/fKjjvuWBk6dGjlwQcffNPPeM0111Te+973Vjp
"text/plain": [
"<Figure size 800x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"Q1 = df['Age'].quantile(0.25)\n",
"Q3 = df['Age'].quantile(0.75)\n",
"IQR = Q3 - Q1\n",
"\n",
"threshold = 1.5 * IQR\n",
"outliers = (df['Age'] < (Q1 - threshold)) | (df['Age'] > (Q3 + threshold))\n",
"\n",
"median_rating = df['Age'].median()\n",
"df.loc[outliers, 'Age'] = median_rating\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"sns.boxplot(y=df['Age'], color='skyblue')\n",
"plt.title('Boxplot of Age')\n",
"plt.ylabel('Age')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Конструирование признаков с помощью меток"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"le = LabelEncoder()\n",
"df['Gender'] = le.fit_transform(df['Gender'])\n",
"df['City'] = le.fit_transform(df['City'])\n",
"df['Dietary Habits'] = le.fit_transform(df['Dietary Habits'])\n",
"df['Degree'] = le.fit_transform(df['Degree'])\n",
"df['Have you ever had suicidal thoughts ?'] = le.fit_transform(df['Have you ever had suicidal thoughts ?'])\n",
"df['Sleep Duration'] = le.fit_transform(df['Sleep Duration'])\n",
"df['Profession'] = le.fit_transform(df['Profession'])\n",
"df['Study Satisfaction'] = le.fit_transform(df['Study Satisfaction'])\n",
"df['Family History of Mental Illness'] = le.fit_transform(df['Family History of Mental Illness'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"разделение на признаки и целевую переменную"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"x = df.drop('Depression', axis=1)\n",
"y = df['Depression']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)"
]
},
2024-12-21 12:33:06 +04:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Создание конвейера\n",
"\n",
"# Обработаем данные\n",
"# Определим категориальные и числовые признаки\n",
"categorical_features = ['Gender', 'City', 'Dietary Habits', 'Degree', 'Have you ever had suicidal thoughts ?', 'Profession', 'Family History of Mental Illness', 'Sleep Duration']\n",
"numerical_features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction', 'Work/Study Hours', 'Financial Stress']\n",
"\n",
"categorical_transformer = Pipeline(steps=[\n",
" ('onehot', OneHotEncoder(handle_unknown='ignore'))\n",
"])\n",
"\n",
"numerical_transformer = Pipeline(steps=[\n",
" ('scaler', StandardScaler())\n",
"])\n",
"\n",
"preprocessor = ColumnTransformer(\n",
" transformers=[\n",
" ('num', numerical_transformer, numerical_features),\n",
" ('cat', categorical_transformer, categorical_features)\n",
" ])\n",
"\n",
"# Построим модели\n",
"pipeline_lasso = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', Lasso())\n",
"])\n",
"\n",
"pipeline_gb = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', GradientBoostingRegressor())\n",
"])\n",
"\n",
"pipeline_knn = Pipeline(steps=[\n",
" ('preprocessor', preprocessor),\n",
" ('model', KNeighborsRegressor())\n",
"])"
]
},
2024-12-21 00:25:04 +04:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1) Метод регрессии Лассо\n"
]
},
{
"cell_type": "code",
2024-12-21 12:33:06 +04:00
"execution_count": 1,
2024-12-21 00:25:04 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие гиперпараметры для Lasso:\n",
2024-12-21 12:33:06 +04:00
"{'model__alpha': 0.01, 'model__fit_intercept': False}\n"
2024-12-21 00:25:04 +04:00
]
}
],
"source": [
"from sklearn.linear_model import Lasso\n",
"\n",
"param_grid_lasso = {\n",
2024-12-21 12:33:06 +04:00
" 'model__alpha': [0.01, 0.1, 1.0, 10.0],\n",
" 'model__fit_intercept': [True, False],\n",
2024-12-21 00:25:04 +04:00
"}\n",
"\n",
"# Создание объекта GridSearchCV\n",
"grid_search_lasso = GridSearchCV(\n",
" estimator=Lasso(), \n",
" param_grid=param_grid_lasso, \n",
" cv=5, \n",
" scoring='neg_mean_squared_error', \n",
" n_jobs=-1 \n",
")\n",
"\n",
"grid_search_lasso.fit(x_train, y_train)\n",
"\n",
"# Вывод лучших гиперпараметров\n",
"print(\"Лучшие гиперпараметры для Lasso:\")\n",
"print(grid_search_lasso.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2) Метод градиентного бустинга"
]
},
{
"cell_type": "code",
2024-12-21 12:33:06 +04:00
"execution_count": 2,
2024-12-21 00:25:04 +04:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие гиперпараметры для Gradient Boosting:\n",
2024-12-21 12:33:06 +04:00
"{'model__learning_rate': 0.1, 'model__max_depth': 5, 'model__max_features': 'sqrt', 'model__min_samples_leaf': 2, 'model__min_samples_split': 5, 'model__n_estimators': 100}\n"
2024-12-21 00:25:04 +04:00
]
}
],
"source": [
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
"param_grid_gb = {\n",
2024-12-21 12:33:06 +04:00
" 'model__n_estimators': [50, 100, 200],\n",
" 'model__learning_rate': [0.01, 0.1, 0.2],\n",
" 'model__max_depth': [3, 5, 7],\n",
" 'model__min_samples_split': [2, 5, 10],\n",
" 'model__min_samples_leaf': [1, 2, 4],\n",
" 'model__max_features': ['auto', 'sqrt', 'log2']\n",
2024-12-21 00:25:04 +04:00
"}\n",
"\n",
"grid_search_gb = GridSearchCV(\n",
" estimator=GradientBoostingRegressor(),\n",
" param_grid=param_grid_gb,\n",
" cv=5,\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1\n",
")\n",
"\n",
"grid_search_gb.fit(x_train, y_train)\n",
"\n",
"# Вывод лучших гиперпараметров\n",
"print(\"Лучшие гиперпараметры для Gradient Boosting:\")\n",
"print(grid_search_gb.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3) Метод k-ближайших соседей"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие гиперпараметры для k-Nearest Neighbors:\n",
"{'algorithm': 'ball_tree', 'n_neighbors': 10, 'p': 1, 'weights': 'distance'}\n"
]
}
],
"source": [
"from sklearn.neighbors import KNeighborsRegressor\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"param_grid_knn = {\n",
2024-12-21 12:33:06 +04:00
" 'model__n_neighbors': [3, 5, 7, 10],\n",
" 'model__weights': ['uniform', 'distance'],\n",
" 'model__algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
" 'model__p': [1, 2]\n",
2024-12-21 00:25:04 +04:00
"}\n",
"\n",
"grid_search_knn = GridSearchCV(\n",
" estimator=KNeighborsRegressor(),\n",
" param_grid=param_grid_knn,\n",
" cv=5,\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1\n",
")\n",
"\n",
"grid_search_knn.fit(x_train, y_train)\n",
"\n",
"# Вывод лучших гиперпараметров\n",
"print(\"Лучшие гиперпараметры для k-Nearest Neighbors:\")\n",
"print(grid_search_knn.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Предсказание на тестовой выборке"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
2024-12-21 12:33:06 +04:00
"y_pred_lasso = grid_search_lasso.predict(x_test)\n",
"y_pred_forest = grid_search_gb.predict(x_test)\n",
"y_pred_neighbors = grid_search_knn.predict(x_test)"
2024-12-21 00:25:04 +04:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценка качества модели"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.\tMSE (Mean Squared Error)\n",
"Среднее значение квадратов разностей между предсказанными и фактическими значениями. Чем меньше значение, тем лучше модель."
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Squared Error (MSE):\n",
"k-NN: \t\t\t0.213\n",
"Random Forest: \t\t0.118\n",
"Lasso: \t\t\t0.166\n",
"Gradient Boosting: \t0.113\n",
"k-Nearest Neighbors: \t0.326\n"
]
}
],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"import numpy as np\n",
"\n",
"mse1 = mean_squared_error(y_test, y_pred)\n",
"mse2 = mean_squared_error(y_test, y_pred_forest)\n",
"mse3 = mean_squared_error(y_test, y_pred_lasso)\n",
"mse4 = mean_squared_error(y_test, y_pred_gb)\n",
"mse5 = mean_squared_error(y_test, y_pred_neighbors)\n",
"\n",
"mse1_rounded = round(mse1, 3)\n",
"mse2_rounded = round(mse2, 3)\n",
"mse3_rounded = round(mse3, 3)\n",
"mse4_rounded = round(mse4, 3)\n",
"mse5_rounded = round(mse5, 3)\n",
"\n",
"print(\"Mean Squared Error (MSE):\")\n",
"print(f\"k-NN: \\t\\t\\t{mse1_rounded}\")\n",
"print(f\"Random Forest: \\t\\t{mse2_rounded}\")\n",
"print(f\"Lasso: \\t\\t\\t{mse3_rounded}\")\n",
"print(f\"Gradient Boosting: \\t{mse4_rounded}\")\n",
"print(f\"k-Nearest Neighbors: \\t{mse5_rounded}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.\tMAE\n",
"Среднее значение абсолютных разностей между предсказанными и фактическими значениями. Чем меньше значение, тем лучше модель."
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Absolute Error (MAE):\n",
"k-NN: \t\t\t0.213\n",
"Random Forest: \t\t0.238\n",
"Lasso: \t\t\t0.366\n",
"Gradient Boosting: \t0.246\n",
"k-Nearest Neighbors: \t0.485\n"
]
}
],
"source": [
"from sklearn.metrics import mean_absolute_error\n",
"\n",
"mae1 = round(mean_absolute_error(y_test, y_pred),3)\n",
"mae2 = round(mean_absolute_error(y_test, y_pred_forest),3)\n",
"mae3 = round(mean_absolute_error(y_test, y_pred_lasso),3)\n",
"mae4 = round(mean_absolute_error(y_test, y_pred_gb),3)\n",
"mae5 = round(mean_absolute_error(y_test, y_pred_neighbors),3)\n",
"print(\"Mean Absolute Error (MAE):\")\n",
"print(f\"k-NN: \\t\\t\\t{mae1}\")\n",
"print(f\"Random Forest: \\t\\t{mae2}\")\n",
"print(f\"Lasso: \\t\\t\\t{mae3}\")\n",
"print(f\"Gradient Boosting: \\t{mae4}\")\n",
"print(f\"k-Nearest Neighbors: \\t{mae5}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.\tR-squared\n",
"Мера, показывающая, насколько хорошо модель объясняет изменчивость данных. Значение находится в диапазоне от 0 до 1, где 1 — идеальное соответствие, а 0 — модель не объясняет данные."
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"R² (R-squared): 0.127933821917115\n",
"\n",
"R² (R-squared):\n",
"k-NN: \t\t\t0.128\n",
"Random Forest: \t\t0.515\n",
"Lasso: \t\t\t0.319\n",
"Gradient Boosting: \t0.537\n",
"k-Nearest Neighbors: \t-0.337\n"
]
}
],
"source": [
"from sklearn.metrics import r2_score\n",
"r2 = r2_score(y_test, y_pred)\n",
"print(f\"R² (R-squared): {r2}\")\n",
"\n",
"r2_1 = r2_score(y_test, y_pred)\n",
"r2_2 = r2_score(y_test, y_pred_forest)\n",
"r2_3 = r2_score(y_test, y_pred_lasso)\n",
"r2_4 = r2_score(y_test, y_pred_gb)\n",
"r2_5 = r2_score(y_test, y_pred_neighbors)\n",
"\n",
"r2_1_rounded = round(r2_1, 3)\n",
"r2_2_rounded = round(r2_2, 3)\n",
"r2_3_rounded = round(r2_3, 3)\n",
"r2_4_rounded = round(r2_4, 3)\n",
"r2_5_rounded = round(r2_5, 3)\n",
"\n",
"print(\"\\nR² (R-squared):\")\n",
"print(f\"k-NN: \\t\\t\\t{r2_1_rounded}\")\n",
"print(f\"Random Forest: \\t\\t{r2_2_rounded}\")\n",
"print(f\"Lasso: \\t\\t\\t{r2_3_rounded}\")\n",
"print(f\"Gradient Boosting: \\t{r2_4_rounded}\")\n",
"print(f\"k-Nearest Neighbors: \\t{r2_5_rounded}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"4.\tRMSE\n",
" Среднее отклонение предсказаний от реальных данных. Чем меньше модуль, тем лучше модель."
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Root Mean Squared Error (RMSE):\n",
"k-NN: \t\t\t0.461\n",
"Random Forest: \t\t0.344\n",
"Lasso: \t\t\t0.407\n",
"Gradient Boosting: \t0.336\n",
"k-Nearest Neighbors: \t0.571\n"
]
}
],
"source": [
"rmse1 = np.sqrt(mse1)\n",
"rmse2 = np.sqrt(mse2)\n",
"rmse3 = np.sqrt(mse3)\n",
"rmse4 = np.sqrt(mse4)\n",
"rmse5 = np.sqrt(mse5)\n",
"\n",
"rmse1_rounded = round(rmse1, 3)\n",
"rmse2_rounded = round(rmse2, 3)\n",
"rmse3_rounded = round(rmse3, 3)\n",
"rmse4_rounded = round(rmse4, 3)\n",
"rmse5_rounded = round(rmse5, 3)\n",
"\n",
"print(\"Root Mean Squared Error (RMSE):\")\n",
"print(f\"k-NN: \\t\\t\\t{rmse1_rounded}\")\n",
"print(f\"Random Forest: \\t\\t{rmse2_rounded}\")\n",
"print(f\"Lasso: \\t\\t\\t{rmse3_rounded}\")\n",
"print(f\"Gradient Boosting: \\t{rmse4_rounded}\")\n",
"print(f\"k-Nearest Neighbors: \\t{rmse5_rounded}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лучший результат градиентный бустинг и случайный лес.\n",
"Положительные результаты по всем критериям получил случайный лес. Три из четырех положительных результата у градиентного бустинга. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Значит, случайный лес наиболее точная и устойчивая стратегия обучения модели. Итоговая модель model_forest."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Также, с помощью применение важности признаков (feature importance) на Случайном лесе, мы вывели основные факторы, вызывающие депрессию:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Feature Importance\n",
"13 Have you ever had suicidal thoughts ? 0.300542\n",
"5 Academic Pressure 0.134276\n",
"0 id 0.087970\n",
"7 CGPA 0.079078\n",
"2 Age 0.066613\n",
"15 Financial Stress 0.066330\n",
"3 City 0.059293\n",
"14 Work/Study Hours 0.052275\n",
"12 Degree 0.049539\n",
"8 Study Satisfaction 0.032944\n",
"11 Dietary Habits 0.026140\n",
"10 Sleep Duration 0.024435\n",
"16 Family History of Mental Illness 0.010547\n",
"1 Gender 0.009627\n",
"4 Profession 0.000372\n",
"9 Job Satisfaction 0.000017\n",
"6 Work Pressure 0.000003\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"model_rf = RandomForestRegressor(n_estimators=100, random_state=42)\n",
"model_rf.fit(x_train, y_train)\n",
"\n",
"feature_importances = model_rf.feature_importances_\n",
"\n",
"import pandas as pd\n",
"feature_importance_df = pd.DataFrame({\n",
" 'Feature': x.columns,\n",
" 'Importance': feature_importances\n",
"}).sort_values(by='Importance', ascending=False)\n",
"\n",
"print(feature_importance_df)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scripts",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}