AIM-PIbd-32-Isaeva-A-I/lab_4/Lab4.ipynb

912 lines
161 KiB
Plaintext
Raw Normal View History

2024-12-20 23:47:13 +04:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'Gender', 'Age', 'City', 'Profession', 'Academic Pressure',\n",
" 'Work Pressure', 'CGPA', 'Study Satisfaction', 'Job Satisfaction',\n",
" 'Sleep Duration', 'Dietary Habits', 'Degree',\n",
" 'Have you ever had suicidal thoughts ?', 'Work/Study Hours',\n",
" 'Financial Stress', 'Family History of Mental Illness', 'Depression'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from matplotlib.ticker import FuncFormatter\n",
"\n",
"df = pd.read_csv(\".//csv//Student Depression Dataset.csv\")\n",
"print(df.columns)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" id Gender Age City Profession Academic Pressure \\\n",
"0 2 Male 33.0 Visakhapatnam Student 5.0 \n",
"1 8 Female 24.0 Bangalore Student 2.0 \n",
"2 26 Male 31.0 Srinagar Student 3.0 \n",
"3 30 Female 28.0 Varanasi Student 3.0 \n",
"4 32 Female 25.0 Jaipur Student 4.0 \n",
"\n",
" Work Pressure CGPA Study Satisfaction Job Satisfaction \\\n",
"0 0.0 8.97 2.0 0.0 \n",
"1 0.0 5.90 5.0 0.0 \n",
"2 0.0 7.03 5.0 0.0 \n",
"3 0.0 5.59 2.0 0.0 \n",
"4 0.0 8.13 3.0 0.0 \n",
"\n",
" Sleep Duration Dietary Habits Degree \\\n",
"0 5-6 hours Healthy B.Pharm \n",
"1 5-6 hours Moderate BSc \n",
"2 Less than 5 hours Healthy BA \n",
"3 7-8 hours Moderate BCA \n",
"4 5-6 hours Moderate M.Tech \n",
"\n",
" Have you ever had suicidal thoughts ? Work/Study Hours Financial Stress \\\n",
"0 Yes 3.0 1.0 \n",
"1 No 3.0 2.0 \n",
"2 No 9.0 1.0 \n",
"3 Yes 4.0 5.0 \n",
"4 Yes 1.0 1.0 \n",
"\n",
" Family History of Mental Illness Depression \n",
"0 No 1 \n",
"1 Yes 0 \n",
"2 Yes 0 \n",
"3 Yes 1 \n",
"4 No 0 \n"
]
}
],
"source": [
"print(df.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Бизнес-цель исследования\n",
"Разработать и внедрить систему прогнозирования уровня депрессии среди обучающихся, которая позволит выявить группы риска на ранних этапах. Результаты исследования могут быть полезны психологам, педагогам и администрации учебных заведений.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Описание набора данных для анализа\n",
"Набор данных содержит информацию о психологическом состоянии обучающихся и включает следующие поля:\n",
"- id идентификатор, число\n",
"- Gender пол, строка\n",
"- Age возраст, дробное число\n",
"- City город, строка\n",
"- Profession профессия, строка\n",
"- Academic Pressure академическое давление, дробное число (от 1.00 до 5.00)\n",
"- Work Pressure рабочее давление, дробное число (от 1.00 до 5.00)\n",
"- CGPA средний балл (GPA), дробное число\n",
"- Study Satisfaction удовлетворенность учебой, дробное число (от 1.00 до 5.00)\n",
"- Job Satisfaction удовлетворенность работой, дробное число (от 1.00 до 5.00)\n",
"- Sleep Duration продолжительность сна, строка\n",
"- Dietary Habits пищевые привычки, строка\n",
"- Degree степень (образование), строка\n",
"- Have you ever had suicidal thoughts? Были ли у вас когда-либо суицидальные мысли? строка (yes/no)\n",
"- Work/Study Hours часы работы/учебы, дробное число\n",
"- Financial Stress финансовый стресс, дробное число (от 1.00 до 5.00)\n",
"- Family History of Mental Illness семейный анамнез психических заболеваний, строка (yes/no)\n",
"- Depression депрессия, булевое значение (1/0)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Обработка данных"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"id 0\n",
"Gender 0\n",
"Age 0\n",
"City 0\n",
"Profession 0\n",
"Academic Pressure 0\n",
"Work Pressure 0\n",
"CGPA 0\n",
"Study Satisfaction 0\n",
"Job Satisfaction 0\n",
"Sleep Duration 0\n",
"Dietary Habits 0\n",
"Degree 0\n",
"Have you ever had suicidal thoughts ? 0\n",
"Work/Study Hours 0\n",
"Financial Stress 3\n",
"Family History of Mental Illness 0\n",
"Depression 0\n",
"dtype: int64"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.isnull().sum()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df.dropna(subset=['Financial Stress'], inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABdEAAAPdCAYAAABlRyFLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVhU5f//8RejLAoCgoKoKIak4YKmue9hRqaYpmWLotjqkkt9yj7llrmVqeVSGYmlZakplrnghrnmknxyyZSsLBEXFMQFlDm/P/wx3yZAQcEj+HxcF5fOfc6cec2Avg/vuec+DoZhGAIAAAAAAAAAANlYzA4AAAAAAAAAAMDtiiY6AAAAAAAAAAC5oIkOAAAAAAAAAEAuaKIDAAAAAAAAAJALmugAAAAAAAAAAOSCJjoAAAAAAAAAALmgiQ4AAAAAAAAAQC5oogMAAAAAAAAAkAua6AAAAAAAAAAA5IImOnAHcnBw0KhRo8yOYWfHjh1q1qyZXF1d5eDgoD179pgdCQBwm6KOFSwzX8+AgABFRESY8tgAgKKN84GbExAQoIcfftjsGECRQRMdKEDR0dFycHCw+/Lx8VHbtm21YsUKs+PdtP3792vUqFH6/fffC/S4ly9fVvfu3ZWcnKwpU6bo888/V9WqVa97v++//14ODg6qWLGirFZrgWYCgDsRdezGUMduDwEBAdl+dlu2bKklS5aYHQ0AihTOB25MXs8HfvzxRzk4OGjKlCnZtoWHh8vBwUFz5szJtq1Vq1aqVKlSgWa+UW3atLH7+fDy8tJ9992nTz/9lHMaFFslzQ4AFEdjxoxRtWrVZBiGkpKSFB0drYceekjffvttkX6nd//+/Ro9erTatGmjgICAAjtuQkKC/vjjD82ePVv9+vXL8/3mz5+vgIAA/f7771q3bp1CQ0MLLBMA3MmoY/lDHbtxBw8elMVScPN66tWrp2HDhkmSjh07po8++khdu3bVrFmz9PzzzxfY4wDAnYDzgfzJ6/nAvffeq9KlS2vTpk0aMmSI3bYtW7aoZMmS2rx5s/r06WMbz8jI0I4dO9SpU6cCy3uzKleurPHjx0uSTp48qc8++0yRkZH69ddfNWHCBJPTAQWPJjpQCMLCwtSwYUPb7cjISPn6+urLL78s0icbheXEiROSJE9Pzzzf5/z584qJidH48eM1Z84czZ8//45sPgBAYaCO5Q917MY5OzsX6PEqVaqkp556yna7V69eql69uqZMmZJrE/3KlSuyWq1ycnIq0CyF6fz583J1dTU7BoBijvOB/Mnr+UDJkiXVuHFjbd682W784MGDOnXqlJ544glt2rTJbtuuXbt06dIltWjR4qZzXrhwQaVLl77p43h4eNjV3Oeee041atTQ9OnT9dZbb8nR0THbfaxWqzIyMuTi4nLTj3+rXLp0SU5OTgX6pj+KJn4CgFvA09NTpUqVUsmS9u9bnT9/XsOGDZO/v7+cnZ1Vo0YNvfvuuzIMQ5J08eJF1axZUzVr1tTFixdt90tOTpafn5+aNWumzMxMSVJERITc3Nz022+/qUOHDnJ1dVXFihU1ZswY2/Gu5aefflJYWJjc3d3l5uam+++/X9u2bbNtj46OVvfu3SVJbdu2tX1sa8OGDdc87rp169SyZUu5urrK09NT4eHhOnDggG17RESEWrduLUnq3r27HBwc1KZNm+vmXbJkiS5evKju3bvr8ccf1zfffKNLly5l2+/ixYsaNGiQypUrpzJlyqhz5876+++/c1w/7++//1bfvn3l6+srZ2dn1apVS59++ul1swBAcUcdM6+OXbp0SaNGjdLdd98tFxcX+fn5qWvXrkpISLDt8+6776pZs2by9vZWqVKl1KBBAy1atCjbsdLT0zVkyBCVL1/eVhP/+uuvHPPlpSZu2LBBDg4O+vrrrzV69GhVqlRJZcqU0aOPPqqUlBSlp6dr8ODB8vHxkZubm/r06aP09HS7Y+S0JvrZs2c1ZMgQBQQEyNnZWZUrV1avXr106tSp676u/1ahQgXdc889OnLkiCTp999/l4ODg959911NnTpVgYGBcnZ21v79+yVJv/zyix599FF5eXnJxcVFDRs21LJly+yOefnyZY0ePVpBQUFycXGRt7e3WrRoodjYWNs+x48fV58+fVS5cmU5OzvLz89P4eHhdksH5LaW779fk6xlFeLi4vTiiy/Kx8dHlStXtm1fsWKF7We0TJky6tixo/bt25fv1woArofzgYI7H2jRooWSkpJ0+PBh29jmzZvl7u6uZ5991tZQ/+e2rPtlmTlzpmrVqiVnZ2dVrFhR/fv319mzZ+0ep02bNqpdu7Z27dqlVq1aqXTp0nr99ddzzTV37lyVLFlSr7zyyjVfj5yULl1aTZo00fnz53Xy5ElJV2vdgAEDNH/+fFvWlStXSsr7798ffPCBatWqpdKlS6ts2bJq2LChvvjiC9v2c+fOafDgwbbzBh8fH7Vv3167d++27ZPbNVjatGlj933KOrdZsGCB3njjDVWqVEmlS5dWamqqJGn79u168MEH5eHhodKlS6t169bZ3gxB8cVMdKAQpKSk6NSpUzIMQydOnNAHH3ygtLQ0u3dpDcNQ586dtX79ekVGRqpevXpatWqVXnnlFf3999+aMmWKSpUqpblz56p58+b673//q/fee0+S1L9/f6WkpCg6OlolSpSwHTMzM1MPPvigmjRpokmTJmnlypUaOXKkrly5ojFjxuSad9++fWrZsqXc3d31n//8R46Ojvroo4/Upk0bxcXFqXHjxmrVqpUGDRqk999/X6+//rruueceSbL9mZM1a9YoLCxMd911l0aNGqWLFy/qgw8+UPPmzbV7924FBAToueeeU6VKlTRu3DgNGjRI9913n3x9fa/7Gs+fP19t27ZVhQoV9Pjjj+u1117Tt99+azshyhIREaGvv/5aTz/9tJo0aaK4uDh17Ngx2/GSkpLUpEkTW5EvX768VqxYocjISKWmpmrw4MHXzQQAxQV17Cqz61hmZqYefvhhrV27Vo8//rheeuklnTt3TrGxsdq7d68CAwMlSdOmTVPnzp315JNPKiMjQwsWLFD37t313Xff2dW8fv36ad68eXriiSfUrFkzrVu3rkBq4vjx41WqVCm99tprOnz4sD744AM5OjrKYrHozJkzGjVqlLZt26bo6GhVq1ZNI0aMyPV1SUtLU8uWLXXgwAH17dtX9957r06dOqVly5bpr7/+Urly5a772v7T5cuXdfToUXl7e9uNz5kzR5cuXdKzzz4rZ2dneXl5ad++fWrevLkqVaqk1157Ta6urvr666/VpUsXLV68WI888ogkadSoURo/frz69eunRo0aKTU1VTt37tTu3bvVvn17SVK3bt20b98+DRw4UAEBATpx4oRiY2P1559/3vDSAS+++KLKly+vESNG6Pz585Kkzz//XL1791aHDh00ceJEXbhwQbNmzVKLFi30008/FegyBQDuPJwPXFUY5wNZzfBNmzapevXqkq42yps0aaLGjRvL0dFRW7ZsUefOnW3bypQpo5CQEElXa9Ho0aMVGhqqF154QQcPHtSsWbO0Y8cObd682W4W+OnTpxUWFqbHH39cTz31VK65Pv74Yz3//PN6/fXXNXbs2FyzX8tvv/2mEiVK2M3GX7dunb7++msNGDBA5cqVU0BAQJ7PNWbPnq1Bgwbp0Ucf1UsvvaRLly7pf//7n7Zv364nnnhCkvT8889r0aJFGjBggIKDg3X69Glt2rRJBw4c0L333ntDz+Ott96Sk5OTXn75ZaWnp8vJyUnr1q1TWFiYGjRooJEjR8pisWjOnDlq166dfvjhBzVq1OiGHgtFiAGgwMyZM8eQlO3L2dnZiI6Ottt36dKlhiRj7NixduOPPvqo4eDgYBw+fNg2Nnz4cMNisRgbN240Fi5caEgypk6dane/3r17G5KMgQMH2sasVqvRsWNHw8nJyTh58qRtXJIxcuRI2+0uXboYTk5ORkJCgm3s2LFjRpkyZYxWrVrZxrIee/369Xl6PerVq2f4+PgYp0+fto3Fx8cbFovF6NW
"text/plain": [
"<Figure size 1500x1000 with 9 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"features = ['Age', 'Academic Pressure', 'Work Pressure', 'CGPA', 'Study Satisfaction', \n",
" 'Job Satisfaction', 'Work/Study Hours', 'Financial Stress', 'Depression']\n",
"\n",
"plt.figure(figsize=(15, 10))\n",
"for i, feature in enumerate(features, 1):\n",
" plt.subplot(3, 3, i)\n",
" sns.boxplot(y=df[feature], color='skyblue')\n",
" plt.title(f'Boxplot of {feature}')\n",
" plt.ylabel(feature)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В Age много выбросов. Сбалансируем данные"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAq4AAAH9CAYAAADbDf7CAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAmsklEQVR4nO3dfZTWdZ3/8ddwM4MKDKIywDqYqes9aWzJaJEISWjenNBy0xLD7AZ1FbZ1UXe92QqyDKzILVfFo5AdS3OtUMEU16O0huFNbqxQJid0RNEZQBkQrt8fu87PWUEhgYsPPh7nXOd0fa7v9bneA53j83z5Xt+pqVQqlQAAwDauU7UHAACAjSFcAQAognAFAKAIwhUAgCIIVwAAiiBcAQAognAFAKAIwhUAgCIIVwAAiiBcAbaimpqaXHrppdUeo4OHH344hx9+eHbaaafU1NRk/vz51R4JYL2EK7BdmDZtWmpqajo8+vTpk6FDh2bmzJnVHu8de/LJJ3PppZfm6aef3qz7rlmzJieffHKWLVuWyZMn58Ybb8wee+zxtu/75S9/mZqamvTv3z/r1q3brDMBbEiXag8AsDldfvnl2XPPPVOpVNLc3Jxp06blmGOOyR133JGPf/zj1R7vL/bkk0/msssuy5FHHpn3vOc9m23fRYsW5U9/+lOuueaanHnmmRv9vunTp+c973lPnn766fzqV7/K8OHDN9tMABvijCuwXRk5cmROO+20fOYzn8nf//3f5z/+4z/StWvX/OhHP6r2aNuk559/PknSq1evjX7PypUrc/vtt2fcuHE59NBDM3369C00HUBHwhXYrvXq1Ss77LBDunTp+A9MK1euzPjx49PY2Ji6urrsu++++da3vpVKpZIkefXVV7Pffvtlv/32y6uvvtr+vmXLlqVfv345/PDDs3bt2iTJ6NGj07179/zhD3/IiBEjstNOO6V///65/PLL2/d7K7/97W8zcuTI9OzZM927d8+wYcMyd+7c9tenTZuWk08+OUkydOjQ9ksh7rvvvrfc91e/+lU+/OEPZ6eddkqvXr1ywgkn5L/+67/aXx89enQ+8pGPJElOPvnk1NTU5Mgjj3zbeW+77ba8+uqrOfnkk3PKKafk1ltvzapVq9503Kuvvppzzz03u+66a3r06JHjjz8+f/7zn9d7ne+f//znfO5zn0tDQ0Pq6upy4IEH5rrrrnvbWYB3F+EKbFdaWlrywgsvZOnSpfnd736XL33pS1mxYkVOO+209mMqlUqOP/74TJ48OR/72Mfy7W9/O/vuu2++8pWvZNy4cUmSHXbYITfccEMWLlyYiy66qP29Y8eOTUtLS6ZNm5bOnTu3r69duzYf+9jH0tDQkCuuuCKDBg3KJZdckksuueQt5/3d736XD3/4w3n00UfzD//wD/mnf/qn/PGPf8yRRx6ZX//610mSIUOG5Nxzz02SXHjhhbnxxhtz4403Zv/999/gvrNnz86IESPy/PPP59JLL824cePy4IMP5ogjjmi/TvYLX/hCLrzwwiTJueeemxtvvLHDz7oh06dPz9ChQ9O3b9+ccsopWb58ee644443HTd69Oh897vfzTHHHJNvfOMb2WGHHXLssce+6bjm5uYMHjw4s2fPztlnn52rrroqe++9d8aMGZMpU6a87TzAu0gFYDtw/fXXV5K86VFXV1eZNm1ah2N/9rOfVZJUvvrVr3ZYP+mkkyo1NTWVhQsXtq9NmDCh0qlTp8r9999fueWWWypJKlOmTOnwvtNPP72SpHLOOee0r61bt65y7LHHVmpraytLly5tX09SueSSS9qfn3jiiZXa2trKokWL2teWLFlS6dGjR2XIkCHta69/9r333rtRfx6HHHJIpU+fPpUXX3yxfe3RRx+tdOrUqfLZz362fe3ee++tJKnccsstG7Vvc3NzpUuXLpVrrrmmfe3www+vnHDCCR2OmzdvXiVJ5bzzzuuwPnr06Df9GYwZM6bSr1+/ygsvvNDh2FNOOaVSX19feeWVVzZqNmD754wrsF2ZOnVqZs2alVmzZuWmm27K0KFDc+aZZ+bWW29tP+aXv/xlOnfu3H4W83Xjx49PpVLpcBeCSy+9NAceeGBOP/30fPnLX85HPvKRN73vdWeffXb7/66pqcnZZ5+d1atXZ/bs2es9fu3atbn77rtz4okn5r3vfW/7er9+/fLpT386DzzwQFpbWzf5z+DZZ5/N/PnzM3r06PTu3bt9feDAgfnoRz+aX/7yl5u85+tuvvnmdOrUKaNGjWpf+9u//dvMnDkzL730UvvanXfemST58pe/3OH955xzTofnlUolP/3pT3PcccelUqnkhRdeaH+MGDEiLS0teeSRR/7ieYHti3AFtisf/OAHM3z48AwfPjynnnpqfvGLX+SAAw5oj8gk+dOf/pT+/funR48eHd77+j+9/+lPf2pfq62tzXXXXZc//vGPWb58ea6//vrU1NS86XM7derUIT6T5K//+q+TZIO3sFq6dGleeeWV7Lvvvm96bf/998+6deuyePHijf/h/9fr829o3xdeeCErV67c5H2T5KabbsoHP/jBvPjii1m4cGEWLlyYQw89NKtXr84tt9zSYYZOnTplzz337PD+vffeu8PzpUuX5uWXX84Pf/jD7Lbbbh0eZ5xxRpL//wUyALfDArZrnTp1ytChQ3PVVVflqaeeyoEHHrjJe9x1111JklWrVuWpp556U4y9Wzz11FN5+OGHkyT77LPPm16fPn16zjrrrE3a8/V7wJ522mk5/fTT13vMwIEDN3FSYHslXIHt3muvvZYkWbFiRZJkjz32yOzZs7N8+fIOZ11///vft7/+usceeyyXX355zjjjjMyfPz9nnnlmHn/88dTX13f4jHXr1uUPf/hD+1nWJPnv//7vJNngfVd322237LjjjlmwYMGbXvv973+fTp06pbGxMUnWe5Z3Q16ff0P77rrrrtlpp502er/XTZ8+PV27ds2NN97Y4YtpSfLAAw/kO9/5Tp555pkMGDAge+yxR9atW5c//vGPHSJ34cKFHd632267pUePHlm7dq17wQJvy6UCwHZtzZo1ufvuu1NbW9t+KcAxxxyTtWvX5nvf+16HYydPnpyampqMHDmy/b2jR49O//79c9VVV2XatGlpbm7O+eefv97PeuN+lUol3/ve99K1a9cMGzZsvcd37tw5Rx99dG6//fYOlxM0NzdnxowZ+dCHPpSePXsmSXtovvzyy2/7M/fr1y+HHHJIbrjhhg7HP/HEE7n77rtzzDHHvO0e6zN9+vR8+MMfzqc+9amcdNJJHR5f+cpXkqT9frkjRoxIknz/+9/vsMd3v/vdDs87d+6cUaNG5ac//WmeeOKJN33m0qVL/6JZge2TM67AdmXmzJntZ06ff/75zJgxI0899VT+8R//sT0CjzvuuAwdOjQXXXRRnn766bzvfe/L3Xffndtvvz3nnXde9tprryTJV7/61cyfPz/33HNPevTokYEDB+af//mfc/HFF+ekk07qEIDdunXLnXfemdNPPz2HHXZYZs6cmV/84he58MILs9tuu21w3q9+9auZNWtWPvShD+XLX/5yunTpkh/84Adpa2vLFVdc0X7cIYccks6dO+cb3/hGWlpaUldXl6OOOip9+vRZ777f/OY3M3LkyDQ1NWXMmDF59dVX893vfjf19fVvuofqxvj1r3+dhQsXdvgC2hv91V/9Vd7//vdn+vTpueCCCzJo0KCMGjUqU6ZMyYsvvpjBgwdnzpw57Weh33gGedKkSbn33ntz2GGH5fOf/3wOOOCALFu2LI888khmz56dZcuWbfK8wHaqujc1ANg81nc7rG7dulUOOeSQytVXX11Zt25dh+OXL19eOf/88yv9+/evdO3atbLPPvtUvvnNb7YfN2/evEqXLl063OKqUqlUXnvttcoHPvCBSv/+/SsvvfRSpVL5n9th7bTTTpVFixZVjj766MqOO+5YaWhoqFxyySWVtWvXdnh//s+toCqVSuWRRx6pjBgxotK9e/fKjjvuWBk6dGjlwQcffNPPeM0111Te+973Vjp
"text/plain": [
"<Figure size 800x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"Q1 = df['Age'].quantile(0.25)\n",
"Q3 = df['Age'].quantile(0.75)\n",
"IQR = Q3 - Q1\n",
"\n",
"threshold = 1.5 * IQR\n",
"outliers = (df['Age'] < (Q1 - threshold)) | (df['Age'] > (Q3 + threshold))\n",
"\n",
"median_rating = df['Age'].median()\n",
"df.loc[outliers, 'Age'] = median_rating\n",
"\n",
"plt.figure(figsize=(8, 6))\n",
"sns.boxplot(y=df['Age'], color='skyblue')\n",
"plt.title('Boxplot of Age')\n",
"plt.ylabel('Age')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Конструирование признаков с помощью меток"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.preprocessing import LabelEncoder\n",
"\n",
"le = LabelEncoder()\n",
"df['Gender'] = le.fit_transform(df['Gender'])\n",
"df['City'] = le.fit_transform(df['City'])\n",
"df['Dietary Habits'] = le.fit_transform(df['Dietary Habits'])\n",
"df['Degree'] = le.fit_transform(df['Degree'])\n",
"df['Have you ever had suicidal thoughts ?'] = le.fit_transform(df['Have you ever had suicidal thoughts ?'])\n",
"df['Sleep Duration'] = le.fit_transform(df['Sleep Duration'])\n",
"df['Profession'] = le.fit_transform(df['Profession'])\n",
"df['Study Satisfaction'] = le.fit_transform(df['Study Satisfaction'])\n",
"df['Family History of Mental Illness'] = le.fit_transform(df['Family History of Mental Illness'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"разделение на признаки и целевую переменную"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"x = df.drop('Depression', axis=1)\n",
"y = df['Depression']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1) Метод регрессии Лассо\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие гиперпараметры для Lasso:\n",
"{'alpha': 0.01, 'fit_intercept': False}\n"
]
}
],
"source": [
"from sklearn.linear_model import Lasso\n",
"\n",
"param_grid_lasso = {\n",
" 'alpha': [0.01, 0.1, 1.0, 10.0],\n",
" 'fit_intercept': [True, False],\n",
"}\n",
"\n",
"# Создание объекта GridSearchCV\n",
"grid_search_lasso = GridSearchCV(\n",
" estimator=Lasso(), \n",
" param_grid=param_grid_lasso, \n",
" cv=5, \n",
" scoring='neg_mean_squared_error', \n",
" n_jobs=-1 \n",
")\n",
"\n",
"grid_search_lasso.fit(x_train, y_train)\n",
"\n",
"# Вывод лучших гиперпараметров\n",
"print(\"Лучшие гиперпараметры для Lasso:\")\n",
"print(grid_search_lasso.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2) Метод градиентного бустинга"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py:540: FitFailedWarning: \n",
"1215 fits failed out of a total of 3645.\n",
"The score on these train-test partitions for these parameters will be set to nan.\n",
"If these failures are not expected, you can try to debug them by setting error_score='raise'.\n",
"\n",
"Below are more details about the failures:\n",
"--------------------------------------------------------------------------------\n",
"978 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n",
" estimator._validate_params()\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n",
" validate_parameter_constraints(\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n",
" raise InvalidParameterError(\n",
"sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'sqrt', 'log2'} or None. Got 'auto' instead.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"237 fits failed with the following error:\n",
"Traceback (most recent call last):\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_validation.py\", line 888, in _fit_and_score\n",
" estimator.fit(X_train, y_train, **fit_params)\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 1466, in wrapper\n",
" estimator._validate_params()\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\base.py\", line 666, in _validate_params\n",
" validate_parameter_constraints(\n",
" File \"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\utils\\_param_validation.py\", line 95, in validate_parameter_constraints\n",
" raise InvalidParameterError(\n",
"sklearn.utils._param_validation.InvalidParameterError: The 'max_features' parameter of GradientBoostingRegressor must be an int in the range [1, inf), a float in the range (0.0, 1.0], a str among {'log2', 'sqrt'} or None. Got 'auto' instead.\n",
"\n",
" warnings.warn(some_fits_failed_message, FitFailedWarning)\n",
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n",
"e:\\AIM1.5\\Scripts\\Lib\\site-packages\\sklearn\\model_selection\\_search.py:1103: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.18767441 -0.15799837 -0.13080278\n",
" -0.18762913 -0.15792709 -0.13056114 -0.18792038 -0.15737146 -0.130218\n",
" -0.18725961 -0.157967 -0.13047453 -0.18766583 -0.15779565 -0.13094863\n",
" -0.18798705 -0.15693978 -0.13061215 -0.18766317 -0.15746848 -0.13072918\n",
" -0.18864158 -0.15666133 -0.13095037 -0.18817206 -0.15805489 -0.13086126\n",
" -0.18707465 -0.15864932 -0.13104947 -0.18818902 -0.15828572 -0.13063871\n",
" -0.18701628 -0.15853864 -0.13019458 -0.18740927 -0.15836397 -0.13065455\n",
" -0.18768748 -0.15828297 -0.1309458 -0.18845004 -0.15696395 -0.13023062\n",
" -0.18754854 -0.15899615 -0.13061707 -0.18831427 -0.15819939 -0.13096524\n",
" -0.18662963 -0.15815869 -0.13089186 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.1758914 -0.1442684 -0.12093344 -0.1758927 -0.14423731 -0.12084543\n",
" -0.17573339 -0.14419842 -0.12076166 -0.17512045 -0.14435454 -0.1207299\n",
" -0.17669645 -0.14397965 -0.12087019 -0.17605424 -0.1438664 -0.12091068\n",
" -0.17582192 -0.1443651 -0.12097165 -0.17588422 -0.14421003 -0.12081764\n",
" -0.17522742 -0.14424357 -0.12086484 -0.17530986 -0.14433713 -0.12091757\n",
" -0.17565647 -0.14408902 -0.12075918 -0.17561884 -0.14426355 -0.12094066\n",
" -0.17522371 -0.1439869 -0.12099023 -0.17619772 -0.14396131 -0.12079667\n",
" -0.17710789 -0.1448419 -0.12087822 -0.17608534 -0.14416684 -0.12087865\n",
" -0.1754675 -0.1442258 -0.12068226 -0.17611334 -0.14433552 -0.12093556\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.16938321 -0.13763002 -0.11703902\n",
" -0.16953091 -0.13736586 -0.11695779 -0.16881837 -0.1375676 -0.11694438\n",
" -0.16927898 -0.13748177 -0.11689982 -0.16921265 -0.13757375 -0.11682524\n",
" -0.16915872 -0.13727377 -0.11694336 -0.16939766 -0.13734972 -0.1167447\n",
" -0.16924214 -0.1373768 -0.11674816 -0.16918278 -0.13746085 -0.1169816\n",
" -0.16927003 -0.13740063 -0.1169564 -0.16916501 -0.13752074 -0.11687641\n",
" -0.16928973 -0.13751536 -0.11697948 -0.16934836 -0.13727436 -0.11693615\n",
" -0.16912453 -0.13748699 -0.11693425 -0.1692788 -0.13750784 -0.11694655\n",
" -0.16919354 -0.13747437 -0.11708782 -0.16940009 -0.13757749 -0.11700586\n",
" -0.1692801 -0.13725384 -0.11684394 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.11606052 -0.1140225 -0.11403709 -0.11627212 -0.1139982 -0.11402075\n",
" -0.11613561 -0.11407941 -0.11420487 -0.11666225 -0.11462523 -0.11431901\n",
" -0.11604817 -0.11456211 -0.11392092 -0.11609343 -0.11394228 -0.11414071\n",
" -0.11611685 -0.11420178 -0.11405459 -0.11594404 -0.11408614 -0.11391662\n",
" -0.11590886 -0.11396465 -0.11389125 -0.11616694 -0.11441846 -0.11417015\n",
" -0.11617368 -0.11429765 -0.1139636 -0.11616763 -0.11433984 -0.11412121\n",
" -0.11625618 -0.11402999 -0.11419791 -0.11613603 -0.114206 -0.11423922\n",
" -0.1160801 -0.11431896 -0.11416734 -0.11608923 -0.11455498 -0.11417448\n",
" -0.11605165 -0.11427773 -0.11392205 -0.11606243 -0.11408421 -0.11395292\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.11281447 -0.11245904 -0.11308822\n",
" -0.11256366 -0.11230094 -0.1130767 -0.11282651 -0.1121034 -0.11283479\n",
" -0.11260704 -0.1125136 -0.11288977 -0.11278304 -0.11242278 -0.11268564\n",
" -0.11263359 -0.11236227 -0.11329411 -0.11231603 -0.1124533 -0.11278826\n",
" -0.11291545 -0.11241223 -0.11250702 -0.11246481 -0.11228665 -0.11348916\n",
" -0.11250694 -0.11250274 -0.11298019 -0.11277323 -0.11248601 -0.11301753\n",
" -0.11259486 -0.1124685 -0.11285441 -0.11274424 -0.11232891 -0.11316456\n",
" -0.11274575 -0.11256149 -0.11252293 -0.11293524 -0.11261757 -0.11305628\n",
" -0.11253063 -0.11237109 -0.11278518 -0.1124074 -0.11276905 -0.11296684\n",
" -0.11258689 -0.11228467 -0.11331342 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.11292265 -0.11395193 -0.11564599 -0.11244356 -0.11338947 -0.1148266\n",
" -0.11295702 -0.11353862 -0.11510521 -0.11244347 -0.11387967 -0.11512396\n",
" -0.11269802 -0.11364442 -0.1151339 -0.11238356 -0.11364301 -0.11496543\n",
" -0.11229193 -0.11340926 -0.11550744 -0.11215818 -0.11367944 -0.11552889\n",
" -0.11240305 -0.11352309 -0.115412 -0.1128402 -0.11338749 -0.1153551\n",
" -0.11250042 -0.11347275 -0.11548445 -0.11271132 -0.11377527 -0.11558066\n",
" -0.11318598 -0.11325792 -0.11499103 -0.11253099 -0.1129829 -0.11530949\n",
" -0.11239074 -0.11329625 -0.11544761 -0.11262484 -0.11323392 -0.1151936\n",
" -0.11253889 -0.11382403 -0.11511129 -0.11250854 -0.11339898 -0.11536332\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.11542253 -0.11498664 -0.11428517\n",
" -0.11503783 -0.11473447 -0.11458687 -0.11483866 -0.1154254 -0.11479037\n",
" -0.11533015 -0.11515195 -0.11460571 -0.11563491 -0.11433835 -0.11437413\n",
" -0.11510849 -0.11472156 -0.11516494 -0.11545009 -0.115001 -0.11479743\n",
" -0.11461761 -0.11537461 -0.11497109 -0.1155148 -0.11567353 -0.11431184\n",
" -0.11546067 -0.11462564 -0.11450721 -0.11511 -0.11487988 -0.11466523\n",
" -0.11585756 -0.11462611 -0.11433121 -0.11538152 -0.11463425 -0.11527088\n",
" -0.11509145 -0.11493588 -0.11484324 -0.11528905 -0.11426327 -0.11476508\n",
" -0.11499562 -0.11451299 -0.11466765 -0.11525918 -0.11469718 -0.11476983\n",
" -0.11467865 -0.1145067 -0.11479425 nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" -0.11352917 -0.1145882 -0.11643688 -0.11418115 -0.11442858 -0.11635549\n",
" -0.11408502 -0.11458383 -0.1163013 -0.1135842 -0.11453566 -0.11575264\n",
" -0.11341863 -0.11481638 -0.11635685 -0.1132144 -0.11438018 -0.11666005\n",
" -0.11311482 -0.11500883 -0.11594984 -0.11409228 -0.11464061 -0.1158012\n",
" -0.11389399 -0.11454081 -0.1157428 -0.11333869 -0.11438896 -0.11676006\n",
" -0.11382523 -0.11443669 -0.11606569 -0.11424726 -0.11464652 -0.11608159\n",
" -0.11396605 -0.11473188 -0.1167532 -0.1136805 -0.11455875 -0.11615814\n",
" -0.11372286 -0.11442829 -0.11590895 -0.1136509 -0.11368863 -0.11660073\n",
" -0.1136605 -0.1141187 -0.11613806 -0.11326355 -0.11427399 -0.11676148\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan nan nan nan\n",
" nan nan nan -0.11573534 -0.11897501 -0.1226239\n",
" -0.1162633 -0.11939573 -0.12255715 -0.11636411 -0.11878021 -0.12306277\n",
" -0.11535113 -0.11813967 -0.1230085 -0.11594119 -0.11812955 -0.12217928\n",
" -0.11523023 -0.11843291 -0.12228252 -0.1159457 -0.11840108 -0.12181337\n",
" -0.11600134 -0.11790484 -0.12203724 -0.11579998 -0.11787918 -0.12317219\n",
" -0.11578704 -0.11837798 -0.12379234 -0.1155279 -0.11865384 -0.12319867\n",
" -0.11597008 -0.11886814 -0.12291788 -0.1162282 -0.11918752 -0.12363613\n",
" -0.11571473 -0.11805225 -0.12250506 -0.11640247 -0.11823175 -0.1226976\n",
" -0.11571549 -0.11813327 -0.12229009 -0.11621545 -0.11793769 -0.1229533\n",
" -0.11528287 -0.1183919 -0.12121653]\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие гиперпараметры для Gradient Boosting:\n",
"{'learning_rate': 0.1, 'max_depth': 5, 'max_features': 'sqrt', 'min_samples_leaf': 1, 'min_samples_split': 10, 'n_estimators': 100}\n"
]
}
],
"source": [
"\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"\n",
"param_grid_gb = {\n",
" 'n_estimators': [50, 100, 200],\n",
" 'learning_rate': [0.01, 0.1, 0.2],\n",
" 'max_depth': [3, 5, 7],\n",
" 'min_samples_split': [2, 5, 10],\n",
" 'min_samples_leaf': [1, 2, 4],\n",
" 'max_features': ['auto', 'sqrt', 'log2']\n",
"}\n",
"\n",
"grid_search_gb = GridSearchCV(\n",
" estimator=GradientBoostingRegressor(),\n",
" param_grid=param_grid_gb,\n",
" cv=5,\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1\n",
")\n",
"\n",
"grid_search_gb.fit(x_train, y_train)\n",
"\n",
"# Вывод лучших гиперпараметров\n",
"print(\"Лучшие гиперпараметры для Gradient Boosting:\")\n",
"print(grid_search_gb.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3) Метод k-ближайших соседей"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Лучшие гиперпараметры для k-Nearest Neighbors:\n",
"{'algorithm': 'ball_tree', 'n_neighbors': 10, 'p': 1, 'weights': 'distance'}\n"
]
}
],
"source": [
"from sklearn.neighbors import KNeighborsRegressor\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"param_grid_knn = {\n",
" 'n_neighbors': [3, 5, 7, 10],\n",
" 'weights': ['uniform', 'distance'],\n",
" 'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],\n",
" 'p': [1, 2]\n",
"}\n",
"\n",
"grid_search_knn = GridSearchCV(\n",
" estimator=KNeighborsRegressor(),\n",
" param_grid=param_grid_knn,\n",
" cv=5,\n",
" scoring='neg_mean_squared_error',\n",
" n_jobs=-1\n",
")\n",
"\n",
"grid_search_knn.fit(x_train, y_train)\n",
"\n",
"# Вывод лучших гиперпараметров\n",
"print(\"Лучшие гиперпараметры для k-Nearest Neighbors:\")\n",
"print(grid_search_knn.best_params_)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Предсказание на тестовой выборке"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"y_pred = model.predict(x_test)\n",
"y_pred_forest = model_forest.predict(x_test)\n",
"y_pred_lasso = model_lasso.predict(x_test)\n",
"y_pred_gb = model_gb.predict(x_test)\n",
"y_pred_neighbors = model_knn.predict(x_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Оценка качества модели"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1.\tMSE (Mean Squared Error)\n",
"Среднее значение квадратов разностей между предсказанными и фактическими значениями. Чем меньше значение, тем лучше модель."
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Squared Error (MSE):\n",
"k-NN: \t\t\t0.213\n",
"Random Forest: \t\t0.118\n",
"Lasso: \t\t\t0.166\n",
"Gradient Boosting: \t0.113\n",
"k-Nearest Neighbors: \t0.326\n"
]
}
],
"source": [
"from sklearn.metrics import mean_squared_error\n",
"import numpy as np\n",
"\n",
"mse1 = mean_squared_error(y_test, y_pred)\n",
"mse2 = mean_squared_error(y_test, y_pred_forest)\n",
"mse3 = mean_squared_error(y_test, y_pred_lasso)\n",
"mse4 = mean_squared_error(y_test, y_pred_gb)\n",
"mse5 = mean_squared_error(y_test, y_pred_neighbors)\n",
"\n",
"mse1_rounded = round(mse1, 3)\n",
"mse2_rounded = round(mse2, 3)\n",
"mse3_rounded = round(mse3, 3)\n",
"mse4_rounded = round(mse4, 3)\n",
"mse5_rounded = round(mse5, 3)\n",
"\n",
"print(\"Mean Squared Error (MSE):\")\n",
"print(f\"k-NN: \\t\\t\\t{mse1_rounded}\")\n",
"print(f\"Random Forest: \\t\\t{mse2_rounded}\")\n",
"print(f\"Lasso: \\t\\t\\t{mse3_rounded}\")\n",
"print(f\"Gradient Boosting: \\t{mse4_rounded}\")\n",
"print(f\"k-Nearest Neighbors: \\t{mse5_rounded}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2.\tMAE\n",
"Среднее значение абсолютных разностей между предсказанными и фактическими значениями. Чем меньше значение, тем лучше модель."
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean Absolute Error (MAE):\n",
"k-NN: \t\t\t0.213\n",
"Random Forest: \t\t0.238\n",
"Lasso: \t\t\t0.366\n",
"Gradient Boosting: \t0.246\n",
"k-Nearest Neighbors: \t0.485\n"
]
}
],
"source": [
"from sklearn.metrics import mean_absolute_error\n",
"\n",
"mae1 = round(mean_absolute_error(y_test, y_pred),3)\n",
"mae2 = round(mean_absolute_error(y_test, y_pred_forest),3)\n",
"mae3 = round(mean_absolute_error(y_test, y_pred_lasso),3)\n",
"mae4 = round(mean_absolute_error(y_test, y_pred_gb),3)\n",
"mae5 = round(mean_absolute_error(y_test, y_pred_neighbors),3)\n",
"print(\"Mean Absolute Error (MAE):\")\n",
"print(f\"k-NN: \\t\\t\\t{mae1}\")\n",
"print(f\"Random Forest: \\t\\t{mae2}\")\n",
"print(f\"Lasso: \\t\\t\\t{mae3}\")\n",
"print(f\"Gradient Boosting: \\t{mae4}\")\n",
"print(f\"k-Nearest Neighbors: \\t{mae5}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3.\tR-squared\n",
"Мера, показывающая, насколько хорошо модель объясняет изменчивость данных. Значение находится в диапазоне от 0 до 1, где 1 — идеальное соответствие, а 0 — модель не объясняет данные."
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"R² (R-squared): 0.127933821917115\n",
"\n",
"R² (R-squared):\n",
"k-NN: \t\t\t0.128\n",
"Random Forest: \t\t0.515\n",
"Lasso: \t\t\t0.319\n",
"Gradient Boosting: \t0.537\n",
"k-Nearest Neighbors: \t-0.337\n"
]
}
],
"source": [
"from sklearn.metrics import r2_score\n",
"r2 = r2_score(y_test, y_pred)\n",
"print(f\"R² (R-squared): {r2}\")\n",
"\n",
"r2_1 = r2_score(y_test, y_pred)\n",
"r2_2 = r2_score(y_test, y_pred_forest)\n",
"r2_3 = r2_score(y_test, y_pred_lasso)\n",
"r2_4 = r2_score(y_test, y_pred_gb)\n",
"r2_5 = r2_score(y_test, y_pred_neighbors)\n",
"\n",
"r2_1_rounded = round(r2_1, 3)\n",
"r2_2_rounded = round(r2_2, 3)\n",
"r2_3_rounded = round(r2_3, 3)\n",
"r2_4_rounded = round(r2_4, 3)\n",
"r2_5_rounded = round(r2_5, 3)\n",
"\n",
"print(\"\\nR² (R-squared):\")\n",
"print(f\"k-NN: \\t\\t\\t{r2_1_rounded}\")\n",
"print(f\"Random Forest: \\t\\t{r2_2_rounded}\")\n",
"print(f\"Lasso: \\t\\t\\t{r2_3_rounded}\")\n",
"print(f\"Gradient Boosting: \\t{r2_4_rounded}\")\n",
"print(f\"k-Nearest Neighbors: \\t{r2_5_rounded}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"4.\tRMSE\n",
" Среднее отклонение предсказаний от реальных данных. Чем меньше модуль, тем лучше модель."
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Root Mean Squared Error (RMSE):\n",
"k-NN: \t\t\t0.461\n",
"Random Forest: \t\t0.344\n",
"Lasso: \t\t\t0.407\n",
"Gradient Boosting: \t0.336\n",
"k-Nearest Neighbors: \t0.571\n"
]
}
],
"source": [
"rmse1 = np.sqrt(mse1)\n",
"rmse2 = np.sqrt(mse2)\n",
"rmse3 = np.sqrt(mse3)\n",
"rmse4 = np.sqrt(mse4)\n",
"rmse5 = np.sqrt(mse5)\n",
"\n",
"rmse1_rounded = round(rmse1, 3)\n",
"rmse2_rounded = round(rmse2, 3)\n",
"rmse3_rounded = round(rmse3, 3)\n",
"rmse4_rounded = round(rmse4, 3)\n",
"rmse5_rounded = round(rmse5, 3)\n",
"\n",
"print(\"Root Mean Squared Error (RMSE):\")\n",
"print(f\"k-NN: \\t\\t\\t{rmse1_rounded}\")\n",
"print(f\"Random Forest: \\t\\t{rmse2_rounded}\")\n",
"print(f\"Lasso: \\t\\t\\t{rmse3_rounded}\")\n",
"print(f\"Gradient Boosting: \\t{rmse4_rounded}\")\n",
"print(f\"k-Nearest Neighbors: \\t{rmse5_rounded}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Лучший результат градиентный бустинг и случайный лес.\n",
"Положительные результаты по всем критериям получил случайный лес. Три из четырех положительных результата у градиентного бустинга. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Значит, случайный лес наиболее точная и устойчивая стратегия обучения модели. Итоговая модель model_forest."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Также, с помощью применение важности признаков (feature importance) на Случайном лесе, мы вывели основные факторы, вызывающие депрессию:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Feature Importance\n",
"13 Have you ever had suicidal thoughts ? 0.300542\n",
"5 Academic Pressure 0.134276\n",
"0 id 0.087970\n",
"7 CGPA 0.079078\n",
"2 Age 0.066613\n",
"15 Financial Stress 0.066330\n",
"3 City 0.059293\n",
"14 Work/Study Hours 0.052275\n",
"12 Degree 0.049539\n",
"8 Study Satisfaction 0.032944\n",
"11 Dietary Habits 0.026140\n",
"10 Sleep Duration 0.024435\n",
"16 Family History of Mental Illness 0.010547\n",
"1 Gender 0.009627\n",
"4 Profession 0.000372\n",
"9 Job Satisfaction 0.000017\n",
"6 Work Pressure 0.000003\n"
]
}
],
"source": [
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"model_rf = RandomForestRegressor(n_estimators=100, random_state=42)\n",
"model_rf.fit(x_train, y_train)\n",
"\n",
"feature_importances = model_rf.feature_importances_\n",
"\n",
"import pandas as pd\n",
"feature_importance_df = pd.DataFrame({\n",
" 'Feature': x.columns,\n",
" 'Importance': feature_importances\n",
"}).sort_values(by='Importance', ascending=False)\n",
"\n",
"print(feature_importance_df)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scripts",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}