AIM-PIbd-31-Belianin-N-N/laboratory_4/lab4.ipynb

5713 lines
823 KiB
Plaintext
Raw Permalink Normal View History

2024-11-14 01:37:01 +04:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Начинаем работу... \n",
"\n",
"Датасет: Продажи домов в округе Кинг "
]
},
{
"cell_type": "code",
"execution_count": 144,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Index(['id', 'date', 'price', 'bedrooms', 'bathrooms', 'sqft_living',\n",
" 'sqft_lot', 'floors', 'waterfront', 'view', 'condition', 'grade',\n",
" 'sqft_above', 'sqft_basement', 'yr_built', 'yr_renovated', 'zipcode',\n",
" 'lat', 'long', 'sqft_living15', 'sqft_lot15'],\n",
" dtype='object')\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"# Установим параметры для вывода\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"random_state = 42\n",
"\n",
"# Подключим датафрейм и выгрузим данные\n",
"df = pd.read_csv(\".//static//csv//kc_house_data.csv\")\n",
"print(df.columns)"
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>grade</th>\n",
" <th>sqft_above</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>7129300520</td>\n",
" <td>20141013T000000</td>\n",
" <td>221900.0</td>\n",
" <td>3</td>\n",
" <td>1.00</td>\n",
" <td>1180</td>\n",
" <td>5650</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>7</td>\n",
" <td>1180</td>\n",
" <td>0</td>\n",
" <td>1955</td>\n",
" <td>0</td>\n",
" <td>98178</td>\n",
" <td>47.5112</td>\n",
" <td>-122.257</td>\n",
" <td>1340</td>\n",
" <td>5650</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>6414100192</td>\n",
" <td>20141209T000000</td>\n",
" <td>538000.0</td>\n",
" <td>3</td>\n",
" <td>2.25</td>\n",
" <td>2570</td>\n",
" <td>7242</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>7</td>\n",
" <td>2170</td>\n",
" <td>400</td>\n",
" <td>1951</td>\n",
" <td>1991</td>\n",
" <td>98125</td>\n",
" <td>47.7210</td>\n",
" <td>-122.319</td>\n",
" <td>1690</td>\n",
" <td>7639</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>5631500400</td>\n",
" <td>20150225T000000</td>\n",
" <td>180000.0</td>\n",
" <td>2</td>\n",
" <td>1.00</td>\n",
" <td>770</td>\n",
" <td>10000</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>6</td>\n",
" <td>770</td>\n",
" <td>0</td>\n",
" <td>1933</td>\n",
" <td>0</td>\n",
" <td>98028</td>\n",
" <td>47.7379</td>\n",
" <td>-122.233</td>\n",
" <td>2720</td>\n",
" <td>8062</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2487200875</td>\n",
" <td>20141209T000000</td>\n",
" <td>604000.0</td>\n",
" <td>4</td>\n",
" <td>3.00</td>\n",
" <td>1960</td>\n",
" <td>5000</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>7</td>\n",
" <td>1050</td>\n",
" <td>910</td>\n",
" <td>1965</td>\n",
" <td>0</td>\n",
" <td>98136</td>\n",
" <td>47.5208</td>\n",
" <td>-122.393</td>\n",
" <td>1360</td>\n",
" <td>5000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1954400510</td>\n",
" <td>20150218T000000</td>\n",
" <td>510000.0</td>\n",
" <td>3</td>\n",
" <td>2.00</td>\n",
" <td>1680</td>\n",
" <td>8080</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>8</td>\n",
" <td>1680</td>\n",
" <td>0</td>\n",
" <td>1987</td>\n",
" <td>0</td>\n",
" <td>98074</td>\n",
" <td>47.6168</td>\n",
" <td>-122.045</td>\n",
" <td>1800</td>\n",
" <td>7503</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 21 columns</p>\n",
"</div>"
],
"text/plain": [
" id date price bedrooms bathrooms sqft_living \\\n",
"0 7129300520 20141013T000000 221900.0 3 1.00 1180 \n",
"1 6414100192 20141209T000000 538000.0 3 2.25 2570 \n",
"2 5631500400 20150225T000000 180000.0 2 1.00 770 \n",
"3 2487200875 20141209T000000 604000.0 4 3.00 1960 \n",
"4 1954400510 20150218T000000 510000.0 3 2.00 1680 \n",
"\n",
" sqft_lot floors waterfront view ... grade sqft_above sqft_basement \\\n",
"0 5650 1.0 0 0 ... 7 1180 0 \n",
"1 7242 2.0 0 0 ... 7 2170 400 \n",
"2 10000 1.0 0 0 ... 6 770 0 \n",
"3 5000 1.0 0 0 ... 7 1050 910 \n",
"4 8080 1.0 0 0 ... 8 1680 0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"0 1955 0 98178 47.5112 -122.257 1340 \n",
"1 1951 1991 98125 47.7210 -122.319 1690 \n",
"2 1933 0 98028 47.7379 -122.233 2720 \n",
"3 1965 0 98136 47.5208 -122.393 1360 \n",
"4 1987 0 98074 47.6168 -122.045 1800 \n",
"\n",
" sqft_lot15 \n",
"0 5650 \n",
"1 7639 \n",
"2 8062 \n",
"3 5000 \n",
"4 7503 \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 145,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>condition</th>\n",
" <th>grade</th>\n",
" <th>sqft_above</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>count</th>\n",
" <td>2.161300e+04</td>\n",
" <td>2.161300e+04</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>2.161300e+04</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" <td>21613.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>mean</th>\n",
" <td>4.580302e+09</td>\n",
" <td>5.400881e+05</td>\n",
" <td>3.370842</td>\n",
" <td>2.114757</td>\n",
" <td>2079.899736</td>\n",
" <td>1.510697e+04</td>\n",
" <td>1.494309</td>\n",
" <td>0.007542</td>\n",
" <td>0.234303</td>\n",
" <td>3.409430</td>\n",
" <td>7.656873</td>\n",
" <td>1788.390691</td>\n",
" <td>291.509045</td>\n",
" <td>1971.005136</td>\n",
" <td>84.402258</td>\n",
" <td>98077.939805</td>\n",
" <td>47.560053</td>\n",
" <td>-122.213896</td>\n",
" <td>1986.552492</td>\n",
" <td>12768.455652</td>\n",
" </tr>\n",
" <tr>\n",
" <th>std</th>\n",
" <td>2.876566e+09</td>\n",
" <td>3.671272e+05</td>\n",
" <td>0.930062</td>\n",
" <td>0.770163</td>\n",
" <td>918.440897</td>\n",
" <td>4.142051e+04</td>\n",
" <td>0.539989</td>\n",
" <td>0.086517</td>\n",
" <td>0.766318</td>\n",
" <td>0.650743</td>\n",
" <td>1.175459</td>\n",
" <td>828.090978</td>\n",
" <td>442.575043</td>\n",
" <td>29.373411</td>\n",
" <td>401.679240</td>\n",
" <td>53.505026</td>\n",
" <td>0.138564</td>\n",
" <td>0.140828</td>\n",
" <td>685.391304</td>\n",
" <td>27304.179631</td>\n",
" </tr>\n",
" <tr>\n",
" <th>min</th>\n",
" <td>1.000102e+06</td>\n",
" <td>7.500000e+04</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>290.000000</td>\n",
" <td>5.200000e+02</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>290.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1900.000000</td>\n",
" <td>0.000000</td>\n",
" <td>98001.000000</td>\n",
" <td>47.155900</td>\n",
" <td>-122.519000</td>\n",
" <td>399.000000</td>\n",
" <td>651.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25%</th>\n",
" <td>2.123049e+09</td>\n",
" <td>3.219500e+05</td>\n",
" <td>3.000000</td>\n",
" <td>1.750000</td>\n",
" <td>1427.000000</td>\n",
" <td>5.040000e+03</td>\n",
" <td>1.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>3.000000</td>\n",
" <td>7.000000</td>\n",
" <td>1190.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1951.000000</td>\n",
" <td>0.000000</td>\n",
" <td>98033.000000</td>\n",
" <td>47.471000</td>\n",
" <td>-122.328000</td>\n",
" <td>1490.000000</td>\n",
" <td>5100.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>50%</th>\n",
" <td>3.904930e+09</td>\n",
" <td>4.500000e+05</td>\n",
" <td>3.000000</td>\n",
" <td>2.250000</td>\n",
" <td>1910.000000</td>\n",
" <td>7.618000e+03</td>\n",
" <td>1.500000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>3.000000</td>\n",
" <td>7.000000</td>\n",
" <td>1560.000000</td>\n",
" <td>0.000000</td>\n",
" <td>1975.000000</td>\n",
" <td>0.000000</td>\n",
" <td>98065.000000</td>\n",
" <td>47.571800</td>\n",
" <td>-122.230000</td>\n",
" <td>1840.000000</td>\n",
" <td>7620.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75%</th>\n",
" <td>7.308900e+09</td>\n",
" <td>6.450000e+05</td>\n",
" <td>4.000000</td>\n",
" <td>2.500000</td>\n",
" <td>2550.000000</td>\n",
" <td>1.068800e+04</td>\n",
" <td>2.000000</td>\n",
" <td>0.000000</td>\n",
" <td>0.000000</td>\n",
" <td>4.000000</td>\n",
" <td>8.000000</td>\n",
" <td>2210.000000</td>\n",
" <td>560.000000</td>\n",
" <td>1997.000000</td>\n",
" <td>0.000000</td>\n",
" <td>98118.000000</td>\n",
" <td>47.678000</td>\n",
" <td>-122.125000</td>\n",
" <td>2360.000000</td>\n",
" <td>10083.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>max</th>\n",
" <td>9.900000e+09</td>\n",
" <td>7.700000e+06</td>\n",
" <td>33.000000</td>\n",
" <td>8.000000</td>\n",
" <td>13540.000000</td>\n",
" <td>1.651359e+06</td>\n",
" <td>3.500000</td>\n",
" <td>1.000000</td>\n",
" <td>4.000000</td>\n",
" <td>5.000000</td>\n",
" <td>13.000000</td>\n",
" <td>9410.000000</td>\n",
" <td>4820.000000</td>\n",
" <td>2015.000000</td>\n",
" <td>2015.000000</td>\n",
" <td>98199.000000</td>\n",
" <td>47.777600</td>\n",
" <td>-121.315000</td>\n",
" <td>6210.000000</td>\n",
" <td>871200.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" id price bedrooms bathrooms sqft_living \\\n",
"count 2.161300e+04 2.161300e+04 21613.000000 21613.000000 21613.000000 \n",
"mean 4.580302e+09 5.400881e+05 3.370842 2.114757 2079.899736 \n",
"std 2.876566e+09 3.671272e+05 0.930062 0.770163 918.440897 \n",
"min 1.000102e+06 7.500000e+04 0.000000 0.000000 290.000000 \n",
"25% 2.123049e+09 3.219500e+05 3.000000 1.750000 1427.000000 \n",
"50% 3.904930e+09 4.500000e+05 3.000000 2.250000 1910.000000 \n",
"75% 7.308900e+09 6.450000e+05 4.000000 2.500000 2550.000000 \n",
"max 9.900000e+09 7.700000e+06 33.000000 8.000000 13540.000000 \n",
"\n",
" sqft_lot floors waterfront view condition \\\n",
"count 2.161300e+04 21613.000000 21613.000000 21613.000000 21613.000000 \n",
"mean 1.510697e+04 1.494309 0.007542 0.234303 3.409430 \n",
"std 4.142051e+04 0.539989 0.086517 0.766318 0.650743 \n",
"min 5.200000e+02 1.000000 0.000000 0.000000 1.000000 \n",
"25% 5.040000e+03 1.000000 0.000000 0.000000 3.000000 \n",
"50% 7.618000e+03 1.500000 0.000000 0.000000 3.000000 \n",
"75% 1.068800e+04 2.000000 0.000000 0.000000 4.000000 \n",
"max 1.651359e+06 3.500000 1.000000 4.000000 5.000000 \n",
"\n",
" grade sqft_above sqft_basement yr_built yr_renovated \\\n",
"count 21613.000000 21613.000000 21613.000000 21613.000000 21613.000000 \n",
"mean 7.656873 1788.390691 291.509045 1971.005136 84.402258 \n",
"std 1.175459 828.090978 442.575043 29.373411 401.679240 \n",
"min 1.000000 290.000000 0.000000 1900.000000 0.000000 \n",
"25% 7.000000 1190.000000 0.000000 1951.000000 0.000000 \n",
"50% 7.000000 1560.000000 0.000000 1975.000000 0.000000 \n",
"75% 8.000000 2210.000000 560.000000 1997.000000 0.000000 \n",
"max 13.000000 9410.000000 4820.000000 2015.000000 2015.000000 \n",
"\n",
" zipcode lat long sqft_living15 sqft_lot15 \n",
"count 21613.000000 21613.000000 21613.000000 21613.000000 21613.000000 \n",
"mean 98077.939805 47.560053 -122.213896 1986.552492 12768.455652 \n",
"std 53.505026 0.138564 0.140828 685.391304 27304.179631 \n",
"min 98001.000000 47.155900 -122.519000 399.000000 651.000000 \n",
"25% 98033.000000 47.471000 -122.328000 1490.000000 5100.000000 \n",
"50% 98065.000000 47.571800 -122.230000 1840.000000 7620.000000 \n",
"75% 98118.000000 47.678000 -122.125000 2360.000000 10083.000000 \n",
"max 98199.000000 47.777600 -121.315000 6210.000000 871200.000000 "
]
},
"execution_count": 146,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"id 0\n",
"date 0\n",
"price 0\n",
"bedrooms 0\n",
"bathrooms 0\n",
"sqft_living 0\n",
"sqft_lot 0\n",
"floors 0\n",
"waterfront 0\n",
"view 0\n",
"condition 0\n",
"grade 0\n",
"sqft_above 0\n",
"sqft_basement 0\n",
"yr_built 0\n",
"yr_renovated 0\n",
"zipcode 0\n",
"lat 0\n",
"long 0\n",
"sqft_living15 0\n",
"sqft_lot15 0\n",
"dtype: int64\n",
"id False\n",
"date False\n",
"price False\n",
"bedrooms False\n",
"bathrooms False\n",
"sqft_living False\n",
"sqft_lot False\n",
"floors False\n",
"waterfront False\n",
"view False\n",
"condition False\n",
"grade False\n",
"sqft_above False\n",
"sqft_basement False\n",
"yr_built False\n",
"yr_renovated False\n",
"zipcode False\n",
"lat False\n",
"long False\n",
"sqft_living15 False\n",
"sqft_lot15 False\n",
"dtype: bool\n"
]
}
],
"source": [
"# Процент пропущенных значений признаков\n",
"for i in df.columns:\n",
" null_rate = df[i].isnull().sum() / len(df) * 100\n",
" if null_rate > 0:\n",
" print(f'{i} Процент пустых значений: %{null_rate:.2f}')\n",
"\n",
"print(df.isnull().sum())\n",
"\n",
"print(df.isnull().any())"
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"id int64\n",
"date object\n",
"price float64\n",
"bedrooms int64\n",
"bathrooms float64\n",
"sqft_living int64\n",
"sqft_lot int64\n",
"floors float64\n",
"waterfront int64\n",
"view int64\n",
"condition int64\n",
"grade int64\n",
"sqft_above int64\n",
"sqft_basement int64\n",
"yr_built int64\n",
"yr_renovated int64\n",
"zipcode int64\n",
"lat float64\n",
"long float64\n",
"sqft_living15 int64\n",
"sqft_lot15 int64\n",
"dtype: object"
]
},
"execution_count": 148,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Проверка типов столбцов\n",
"df.dtypes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Выбор бизнес-целей \n",
"Для датасета недвижимости предлагаются две бизнес-цели:\n",
"\n",
"*Задача регрессии* предсказание цены дома (price). Это может помочь риэлторам и аналитикам определить справедливую рыночную стоимость недвижимости. \n",
"\n",
"*Задача классификации* определение вероятности того, что цена дома будет выше/ниже медианы рынка. Классифицировать дома по ценовым категориям (например, низкая, средняя, высокая цена). Это может помочь определить, какие дома популярны у покупателей.\n",
"\n",
"## Определение достижимого уровня качества модели \n",
"Для регрессии и классификации мы выберем метрики: \n",
"\n",
"Для регрессии будем использовать метрики MAE (средняя абсолютная ошибка) и R^2 (коэффициент детерминации), стремясь к MAE ниже 10% от средней цены. А классификация будте ориентироваться на метрики accuracy и F1-score при целевом значении accuracy около 80%.\n",
"\n",
"## Ориентир для каждой задачи\n",
"Для регрессии ориентиром будет медианная цена (price.median()), так как это стабильное значение. Для классификации ориентируемся на среднюю вероятность предсказания класса выше медианы.\n",
"\n",
"## Анализ алгоритмов машинного обучения \n",
"Рассмотрим для задачи регрессии:\n",
"\n",
"*Линейная регрессия:* подходит для простых линейных зависимостей. \n",
"*Дерево решений:* учитывает нелинейные зависимости, может учесть сложные закономерности. \n",
"*Случайный лес:* ансамблевый метод, обобщающий данные и эффективно обрабатывающий выбросы. \n",
"\n",
"Для задачи классификации: \n",
"\n",
"*Логистическая регрессия:* простая модель, подходящая для бинарной классификации. \n",
"*Метод опорных векторов (SVM):* работает хорошо на данных с четкими разделениями. \n",
"*Градиентный бустинг:* подходит для сложных и высокоразмерных данных, обеспечивает высокую точность. \n",
"\n",
"## Выбор моделей \n",
"Выбираем по три модели для каждой задачи:\n",
"\n",
"*Регрессия:* Линейная регрессия, Дерево решений, Случайный лес. \n",
"*Классификация:* Логистическая регрессия, Метод опорных векторов (SVM), Градиентный бустинг. \n",
"\n",
"\n",
"## Построение конвейера и визуализации \n",
"Теперь напишем код для загрузки данных, анализа и подготовки моделей с визуализацией результатов.\n",
"\n",
"\n",
"# Начнём с задачи классификации\n",
"\n",
"Целевой признак --> above_median_price\n",
"\n",
"Формируем выборки. Разделяем набор данных на обучающую и тестовые выборки (80/20) для задачи классификации"
]
},
{
"cell_type": "code",
"execution_count": 149,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>above_median_price</th>\n",
" <th>price_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20962</th>\n",
" <td>1278000210</td>\n",
" <td>20150311T000000</td>\n",
" <td>110000.0</td>\n",
" <td>2</td>\n",
" <td>1.00</td>\n",
" <td>828</td>\n",
" <td>4524</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1968</td>\n",
" <td>2007</td>\n",
" <td>98001</td>\n",
" <td>47.2655</td>\n",
" <td>-122.244</td>\n",
" <td>828</td>\n",
" <td>5402</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12284</th>\n",
" <td>2193300390</td>\n",
" <td>20140923T000000</td>\n",
" <td>624000.0</td>\n",
" <td>4</td>\n",
" <td>3.25</td>\n",
" <td>2810</td>\n",
" <td>11250</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1130</td>\n",
" <td>1980</td>\n",
" <td>0</td>\n",
" <td>98052</td>\n",
" <td>47.6920</td>\n",
" <td>-122.099</td>\n",
" <td>2110</td>\n",
" <td>11250</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7343</th>\n",
" <td>4289900005</td>\n",
" <td>20141230T000000</td>\n",
" <td>1535000.0</td>\n",
" <td>4</td>\n",
" <td>3.25</td>\n",
" <td>2850</td>\n",
" <td>4100</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>...</td>\n",
" <td>1030</td>\n",
" <td>1908</td>\n",
" <td>2003</td>\n",
" <td>98122</td>\n",
" <td>47.6147</td>\n",
" <td>-122.285</td>\n",
" <td>2130</td>\n",
" <td>4200</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14247</th>\n",
" <td>316000145</td>\n",
" <td>20150325T000000</td>\n",
" <td>235000.0</td>\n",
" <td>4</td>\n",
" <td>1.00</td>\n",
" <td>1360</td>\n",
" <td>7132</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1941</td>\n",
" <td>0</td>\n",
" <td>98168</td>\n",
" <td>47.5054</td>\n",
" <td>-122.301</td>\n",
" <td>1280</td>\n",
" <td>7175</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16670</th>\n",
" <td>629400480</td>\n",
" <td>20140619T000000</td>\n",
" <td>775000.0</td>\n",
" <td>4</td>\n",
" <td>2.75</td>\n",
" <td>3010</td>\n",
" <td>15992</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1996</td>\n",
" <td>0</td>\n",
" <td>98075</td>\n",
" <td>47.5895</td>\n",
" <td>-121.994</td>\n",
" <td>3330</td>\n",
" <td>12333</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>88</th>\n",
" <td>1332700270</td>\n",
" <td>20140519T000000</td>\n",
" <td>215000.0</td>\n",
" <td>2</td>\n",
" <td>2.25</td>\n",
" <td>1610</td>\n",
" <td>2040</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1979</td>\n",
" <td>0</td>\n",
" <td>98056</td>\n",
" <td>47.5180</td>\n",
" <td>-122.194</td>\n",
" <td>1950</td>\n",
" <td>2025</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15031</th>\n",
" <td>7129303070</td>\n",
" <td>20140820T000000</td>\n",
" <td>735000.0</td>\n",
" <td>4</td>\n",
" <td>2.75</td>\n",
" <td>3040</td>\n",
" <td>2415</td>\n",
" <td>2.0</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1966</td>\n",
" <td>0</td>\n",
" <td>98118</td>\n",
" <td>47.5188</td>\n",
" <td>-122.256</td>\n",
" <td>2620</td>\n",
" <td>2433</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5234</th>\n",
" <td>2432000130</td>\n",
" <td>20150414T000000</td>\n",
" <td>675000.0</td>\n",
" <td>3</td>\n",
" <td>1.75</td>\n",
" <td>1660</td>\n",
" <td>9549</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1956</td>\n",
" <td>0</td>\n",
" <td>98033</td>\n",
" <td>47.6503</td>\n",
" <td>-122.198</td>\n",
" <td>2090</td>\n",
" <td>9549</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19980</th>\n",
" <td>774100475</td>\n",
" <td>20140627T000000</td>\n",
" <td>415000.0</td>\n",
" <td>3</td>\n",
" <td>2.75</td>\n",
" <td>2600</td>\n",
" <td>64626</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2009</td>\n",
" <td>0</td>\n",
" <td>98014</td>\n",
" <td>47.7185</td>\n",
" <td>-121.405</td>\n",
" <td>1740</td>\n",
" <td>64626</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3671</th>\n",
" <td>8847400115</td>\n",
" <td>20140723T000000</td>\n",
" <td>590000.0</td>\n",
" <td>3</td>\n",
" <td>2.00</td>\n",
" <td>2420</td>\n",
" <td>208652</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2005</td>\n",
" <td>0</td>\n",
" <td>98010</td>\n",
" <td>47.3666</td>\n",
" <td>-121.978</td>\n",
" <td>3180</td>\n",
" <td>212137</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>17290 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" id date price bedrooms bathrooms \\\n",
"20962 1278000210 20150311T000000 110000.0 2 1.00 \n",
"12284 2193300390 20140923T000000 624000.0 4 3.25 \n",
"7343 4289900005 20141230T000000 1535000.0 4 3.25 \n",
"14247 316000145 20150325T000000 235000.0 4 1.00 \n",
"16670 629400480 20140619T000000 775000.0 4 2.75 \n",
"... ... ... ... ... ... \n",
"88 1332700270 20140519T000000 215000.0 2 2.25 \n",
"15031 7129303070 20140820T000000 735000.0 4 2.75 \n",
"5234 2432000130 20150414T000000 675000.0 3 1.75 \n",
"19980 774100475 20140627T000000 415000.0 3 2.75 \n",
"3671 8847400115 20140723T000000 590000.0 3 2.00 \n",
"\n",
" sqft_living sqft_lot floors waterfront view ... sqft_basement \\\n",
"20962 828 4524 1.0 0 0 ... 0 \n",
"12284 2810 11250 1.0 0 0 ... 1130 \n",
"7343 2850 4100 2.0 0 3 ... 1030 \n",
"14247 1360 7132 1.5 0 0 ... 0 \n",
"16670 3010 15992 2.0 0 0 ... 0 \n",
"... ... ... ... ... ... ... ... \n",
"88 1610 2040 2.0 0 0 ... 0 \n",
"15031 3040 2415 2.0 1 4 ... 0 \n",
"5234 1660 9549 1.0 0 0 ... 0 \n",
"19980 2600 64626 1.5 0 0 ... 0 \n",
"3671 2420 208652 1.5 0 0 ... 0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"20962 1968 2007 98001 47.2655 -122.244 828 \n",
"12284 1980 0 98052 47.6920 -122.099 2110 \n",
"7343 1908 2003 98122 47.6147 -122.285 2130 \n",
"14247 1941 0 98168 47.5054 -122.301 1280 \n",
"16670 1996 0 98075 47.5895 -121.994 3330 \n",
"... ... ... ... ... ... ... \n",
"88 1979 0 98056 47.5180 -122.194 1950 \n",
"15031 1966 0 98118 47.5188 -122.256 2620 \n",
"5234 1956 0 98033 47.6503 -122.198 2090 \n",
"19980 2009 0 98014 47.7185 -121.405 1740 \n",
"3671 2005 0 98010 47.3666 -121.978 3180 \n",
"\n",
" sqft_lot15 above_median_price price_category \n",
"20962 5402 0 0 \n",
"12284 11250 1 1 \n",
"7343 4200 1 2 \n",
"14247 7175 0 0 \n",
"16670 12333 1 2 \n",
"... ... ... ... \n",
"88 2025 0 0 \n",
"15031 2433 1 2 \n",
"5234 9549 1 1 \n",
"19980 64626 0 1 \n",
"3671 212137 1 1 \n",
"\n",
"[17290 rows x 23 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_median_price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20962</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12284</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7343</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14247</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16670</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>88</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15031</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5234</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19980</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3671</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>17290 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_median_price\n",
"20962 0\n",
"12284 1\n",
"7343 1\n",
"14247 0\n",
"16670 1\n",
"... ...\n",
"88 0\n",
"15031 1\n",
"5234 1\n",
"19980 0\n",
"3671 1\n",
"\n",
"[17290 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>above_median_price</th>\n",
" <th>price_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>11592</th>\n",
" <td>2028701000</td>\n",
" <td>20140529T000000</td>\n",
" <td>635200.0</td>\n",
" <td>4</td>\n",
" <td>1.75</td>\n",
" <td>1640</td>\n",
" <td>4240</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>720</td>\n",
" <td>1921</td>\n",
" <td>0</td>\n",
" <td>98117</td>\n",
" <td>47.6766</td>\n",
" <td>-122.368</td>\n",
" <td>1300</td>\n",
" <td>4240</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8984</th>\n",
" <td>9406500530</td>\n",
" <td>20140912T000000</td>\n",
" <td>249000.0</td>\n",
" <td>2</td>\n",
" <td>2.00</td>\n",
" <td>1090</td>\n",
" <td>1357</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1990</td>\n",
" <td>0</td>\n",
" <td>98028</td>\n",
" <td>47.7526</td>\n",
" <td>-122.244</td>\n",
" <td>1078</td>\n",
" <td>1318</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8280</th>\n",
" <td>8097000330</td>\n",
" <td>20140721T000000</td>\n",
" <td>359950.0</td>\n",
" <td>3</td>\n",
" <td>2.75</td>\n",
" <td>2540</td>\n",
" <td>8604</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1991</td>\n",
" <td>0</td>\n",
" <td>98092</td>\n",
" <td>47.3209</td>\n",
" <td>-122.185</td>\n",
" <td>2260</td>\n",
" <td>7438</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>792</th>\n",
" <td>8081020370</td>\n",
" <td>20140709T000000</td>\n",
" <td>1355000.0</td>\n",
" <td>4</td>\n",
" <td>3.50</td>\n",
" <td>3550</td>\n",
" <td>11000</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>1290</td>\n",
" <td>1999</td>\n",
" <td>0</td>\n",
" <td>98006</td>\n",
" <td>47.5506</td>\n",
" <td>-122.134</td>\n",
" <td>4100</td>\n",
" <td>10012</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10371</th>\n",
" <td>7518507580</td>\n",
" <td>20150502T000000</td>\n",
" <td>581000.0</td>\n",
" <td>2</td>\n",
" <td>1.00</td>\n",
" <td>1170</td>\n",
" <td>4080</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1909</td>\n",
" <td>0</td>\n",
" <td>98117</td>\n",
" <td>47.6784</td>\n",
" <td>-122.386</td>\n",
" <td>1560</td>\n",
" <td>4586</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16733</th>\n",
" <td>7212650950</td>\n",
" <td>20140708T000000</td>\n",
" <td>336000.0</td>\n",
" <td>4</td>\n",
" <td>2.50</td>\n",
" <td>2530</td>\n",
" <td>8169</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1993</td>\n",
" <td>0</td>\n",
" <td>98003</td>\n",
" <td>47.2634</td>\n",
" <td>-122.312</td>\n",
" <td>2220</td>\n",
" <td>8013</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13151</th>\n",
" <td>4365200620</td>\n",
" <td>20150312T000000</td>\n",
" <td>394000.0</td>\n",
" <td>3</td>\n",
" <td>1.00</td>\n",
" <td>1450</td>\n",
" <td>7930</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>300</td>\n",
" <td>1923</td>\n",
" <td>0</td>\n",
" <td>98126</td>\n",
" <td>47.5212</td>\n",
" <td>-122.371</td>\n",
" <td>1040</td>\n",
" <td>7740</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11667</th>\n",
" <td>4083304355</td>\n",
" <td>20150318T000000</td>\n",
" <td>675000.0</td>\n",
" <td>4</td>\n",
" <td>1.75</td>\n",
" <td>1530</td>\n",
" <td>3615</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1913</td>\n",
" <td>0</td>\n",
" <td>98103</td>\n",
" <td>47.6529</td>\n",
" <td>-122.334</td>\n",
" <td>1650</td>\n",
" <td>4200</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3683</th>\n",
" <td>2891100820</td>\n",
" <td>20140825T000000</td>\n",
" <td>213500.0</td>\n",
" <td>3</td>\n",
" <td>1.00</td>\n",
" <td>1220</td>\n",
" <td>6000</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1968</td>\n",
" <td>0</td>\n",
" <td>98002</td>\n",
" <td>47.3245</td>\n",
" <td>-122.209</td>\n",
" <td>1420</td>\n",
" <td>6000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12059</th>\n",
" <td>952000640</td>\n",
" <td>20141027T000000</td>\n",
" <td>715000.0</td>\n",
" <td>3</td>\n",
" <td>1.50</td>\n",
" <td>1670</td>\n",
" <td>5060</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1925</td>\n",
" <td>0</td>\n",
" <td>98126</td>\n",
" <td>47.5671</td>\n",
" <td>-122.379</td>\n",
" <td>1670</td>\n",
" <td>5118</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4323 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" id date price bedrooms bathrooms \\\n",
"11592 2028701000 20140529T000000 635200.0 4 1.75 \n",
"8984 9406500530 20140912T000000 249000.0 2 2.00 \n",
"8280 8097000330 20140721T000000 359950.0 3 2.75 \n",
"792 8081020370 20140709T000000 1355000.0 4 3.50 \n",
"10371 7518507580 20150502T000000 581000.0 2 1.00 \n",
"... ... ... ... ... ... \n",
"16733 7212650950 20140708T000000 336000.0 4 2.50 \n",
"13151 4365200620 20150312T000000 394000.0 3 1.00 \n",
"11667 4083304355 20150318T000000 675000.0 4 1.75 \n",
"3683 2891100820 20140825T000000 213500.0 3 1.00 \n",
"12059 952000640 20141027T000000 715000.0 3 1.50 \n",
"\n",
" sqft_living sqft_lot floors waterfront view ... sqft_basement \\\n",
"11592 1640 4240 1.0 0 0 ... 720 \n",
"8984 1090 1357 2.0 0 0 ... 0 \n",
"8280 2540 8604 2.0 0 0 ... 0 \n",
"792 3550 11000 1.0 0 2 ... 1290 \n",
"10371 1170 4080 1.0 0 0 ... 0 \n",
"... ... ... ... ... ... ... ... \n",
"16733 2530 8169 2.0 0 0 ... 0 \n",
"13151 1450 7930 1.0 0 0 ... 300 \n",
"11667 1530 3615 1.5 0 0 ... 0 \n",
"3683 1220 6000 1.0 0 0 ... 0 \n",
"12059 1670 5060 2.0 0 2 ... 0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"11592 1921 0 98117 47.6766 -122.368 1300 \n",
"8984 1990 0 98028 47.7526 -122.244 1078 \n",
"8280 1991 0 98092 47.3209 -122.185 2260 \n",
"792 1999 0 98006 47.5506 -122.134 4100 \n",
"10371 1909 0 98117 47.6784 -122.386 1560 \n",
"... ... ... ... ... ... ... \n",
"16733 1993 0 98003 47.2634 -122.312 2220 \n",
"13151 1923 0 98126 47.5212 -122.371 1040 \n",
"11667 1913 0 98103 47.6529 -122.334 1650 \n",
"3683 1968 0 98002 47.3245 -122.209 1420 \n",
"12059 1925 0 98126 47.5671 -122.379 1670 \n",
"\n",
" sqft_lot15 above_median_price price_category \n",
"11592 4240 1 1 \n",
"8984 1318 0 0 \n",
"8280 7438 0 1 \n",
"792 10012 1 2 \n",
"10371 4586 1 1 \n",
"... ... ... ... \n",
"16733 8013 0 1 \n",
"13151 7740 0 1 \n",
"11667 4200 1 1 \n",
"3683 6000 0 0 \n",
"12059 5118 1 2 \n",
"\n",
"[4323 rows x 23 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>above_median_price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>11592</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8984</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8280</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>792</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10371</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16733</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13151</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11667</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3683</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12059</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4323 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" above_median_price\n",
"11592 1\n",
"8984 0\n",
"8280 0\n",
"792 1\n",
"10371 1\n",
"... ...\n",
"16733 0\n",
"13151 0\n",
"11667 1\n",
"3683 0\n",
"12059 1\n",
"\n",
"[4323 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"id int64\n",
"date object\n",
"price float64\n",
"bedrooms int64\n",
"bathrooms float64\n",
"sqft_living int64\n",
"sqft_lot int64\n",
"floors float64\n",
"waterfront int64\n",
"view int64\n",
"condition int64\n",
"grade int64\n",
"sqft_above int64\n",
"sqft_basement int64\n",
"yr_built int64\n",
"yr_renovated int64\n",
"zipcode int64\n",
"lat float64\n",
"long float64\n",
"sqft_living15 int64\n",
"sqft_lot15 int64\n",
"above_median_price int64\n",
"price_category category\n",
"dtype: object\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1sAAAIjCAYAAAD1OgEdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB1RklEQVR4nO3deXwTdf7H8ffk7F2gpS3lFpBLQCmK9UBEFBFdXVldFRVBRVdQgfVY1huvxRMPvFYFXeGn4K2oiCh4oSJaBURERItACwV6t0mbzO+PNmlDy1Xapklez8cjD5qZycwnacR58/3OZwzTNE0BAAAAABqVJdgFAAAAAEA4ImwBAAAAQBMgbAEAAABAEyBsAQAAAEATIGwBAAAAQBMgbAEAAABAEyBsAQAAAEATIGwBAAAAQBMgbAEAAABAEyBsAQAAAEATIGwBaJHmzJkjwzD8j6ioKB166KGaNGmScnNzg10eAADAPtmCXQAA7M306dPVtWtXlZeX6/PPP9eTTz6p9957T6tXr1ZMTEywywMAANgjwhaAFm3kyJEaNGiQJOmyyy5TUlKSHnroIb311ls6//zzg1wdAADAnjGNEEBIGTZsmCRp48aNkqSdO3fquuuuU79+/RQXF6eEhASNHDlSP/zwQ53XlpeX6/bbb9ehhx6qqKgotWvXTmeffbY2bNggSfr9998Dpi7u/hg6dKh/X0uXLpVhGHrllVf073//W2lpaYqNjdVf/vIXbdq0qc6xv/76a5166qlKTExUTEyMTjjhBH3xxRf1vsehQ4fWe/zbb7+9zrYvvfSSMjIyFB0drTZt2ui8886r9/h7e2+1eb1ezZw5U3379lVUVJRSU1N1xRVXaNeuXQHbdenSRaeffnqd40yaNKnOPuur/f7776/zmUqSy+XSbbfdpu7du8vpdKpjx4664YYb5HK56v2sahs6dKgOO+ywOssfeOABGYah33//PWB5fn6+Jk+erI4dO8rpdKp79+6aMWOGvF6vfxvf5/bAAw/U2e9hhx1Wp/76GIahSZMm1Vl++umnq0uXLnVqPeaYY5SUlKTo6GhlZGTo1Vdf3ecxpAN//++//76OP/54xcbGKj4+XqNGjdKaNWvq3XeXLl3q/e7MmTPHv83u3zG73a4uXbro+uuvl9vt9m/nmyJcux6v16v+/fvX2ecFF1ygpKQkrV+/fq+vX7BggSwWi55++mn/sksuuaTO57tp0yZFR0fXeb3v/U2ePLnOex8xYoQMw6jzfd+2bZsuvfRSpaamKioqSgMGDNALL7xQ5/Ver1ePPPKI+vXrp6ioKLVt21annnqqvv32W0na6985tf8b8f2ds7/fBwAtAyNbAEKKLxglJSVJkn777Te9+eabOuecc9S1a1fl5ubq6aef1gknnKCffvpJ6enpkiSPx6PTTz9dS5Ys0Xnnnadrr71WRUVFWrx4sVavXq1u3br5j3H++efrtNNOCzjutGnT6q3n7rvvlmEYuvHGG7Vt2zbNnDlTw4cPV1ZWlqKjoyVJH3/8sUaOHKmMjAzddtttslgsmj17toYNG6bPPvtMRx11VJ39dujQQffee68kqbi4WP/4xz/qPfYtt9yic889V5dddpm2b9+uxx57TEOGDNH333+vVq1a1XnNhAkTdPzxx0uSXn/9db3xxhsB66+44grNmTNH48aN0zXXXKONGzfq8ccf1/fff68vvvhCdru93s/hQOTn5/vfW21er1d/+ctf9Pnnn2vChAnq3bu3Vq1apYcffli//PKL3nzzzYM+tk9paalOOOEEbd68WVdccYU6deqkL7/8UtOmTdPWrVs1c+bMRjvWgXjkkUf0l7/8RWPGjJHb7dbLL7+sc845R++++65GjRrVaMf53//+p7Fjx2rEiBGaMWOGSktL9eSTT+q4447T999/XyekSNLhhx+uf/7zn5Kq/rHj1ltvrXffvu+Yy+XSokWL9MADDygqKkp33nnnXutZtWpVneXPP/+8hg0bplGjRunrr79W69at62zzzTffaOzYsZoyZYquuOKKvb7vW2+9VeXl5fWui4qK0ty5c3X//ff7v+d//vmnlixZoqioqIBty8rKNHToUP3666+aNGmSunbtqgULFuiSSy5Rfn6+rr32Wv+2l156qebMmaORI0fqsssuU2VlpT777DN99dVXGjRokP73v//5t/3ss8/0zDPP6OGHH1ZycrIkKTU1da/vCUALZwJACzR79mxTkvnRRx+Z27dvNzdt2mS+/PLLZlJSkhkdHW3++eefpmmaZnl5uenxeAJeu3HjRtPpdJrTp0/3L3v++edNSeZDDz1U51her9f/Oknm/fffX2ebvn37mieccIL/+SeffGJKMtu3b28WFhb6l8+fP9+UZD7yyCP+fffo0cMcMWKE/zimaZqlpaVm165dzZNPPrnOsY455hjzsMMO8z/fvn27Kcm87bbb/Mt+//1302q1mnfffXfAa1etWmXabLY6y9evX29KMl944QX/sttuu82s/b+Bzz77zJRkzp07N+C1H3zwQZ3lnTt3NkeNGlWn9okTJ5q7/69l99pvuOEGMyUlxczIyAj4TP/3v/+ZFovF/OyzzwJe/9RTT5mSzC+++KLO8Wo74YQTzL59+9ZZfv/995uSzI0bN/qX3XnnnWZsbKz5yy+/BGz7r3/9y7RarWZ2drZpmgf2ndgTSebEiRPrLB81apTZuXPngGWlpaUBz91ut3nYYYeZw4YN2+dx9vf9FxUVma1atTIvv/zygO1ycnLMxMTEOstN0zTT09PN008/3f98xYoVpiRz9uzZ/mW+z6r2Mt9rTzvtNP9z33/bvnrKy8vNTp06mSNHjqz39bm5uWaXLl3ME0880XS73QGvz87ONtPS0sy//OUvdf4eGDt2bMDnu3r1atNisfiPU/v70LlzZ/Pkk082k5OTzVdffdW//M477zSPOeaYOt/3mTNnmpLMl156yb/M7XabmZmZZlxcnP/vhI8//tiUZF5zzTV1PtPafx/s6bOpzfd3zoIFC+qsA9ByMY0QQIs2fPhwtW3bVh07dtR5552nuLg4vfHGG2rfvr0kyel0ymKp+qvM4/Fox44diouLU8+ePfXdd9/59/Paa68pOTlZV199dZ1j7D7t7UBcfPHFio+P9z//29/+pnbt2um9996TJGVlZWn9+vW64IILtGPHDuXl5SkvL08lJSU66aST9OmnnwZMW5Oqpjvu/i/pu3v99dfl9Xp17rnn+veZl5entLQ09ejRQ5988knA9r5pXE6nc4/7XLBggRITE3XyyScH7DMjI0NxcXF19llRURGwXV5e3h5HDXw2b96sxx57TLfccovi4uLqHL93797q1atXwD59U0d3P/7BWLBggY4//ni1bt064FjDhw+Xx+PRp59+GrB9aWlpnffq8Xj2+3jl5eV1Xl9RUVFnO99oqCTt2rVLBQUFOv744wO+y3vj8XjqHKe0tDRgm8WLFys/P1/nn39+wHZWq1WDBw+u93Pen++kT3FxsfLy8rR582Y988wzysnJ0UknnbTH7WfNmqUdO3botttuq3d9SkqKFi5cqK+//lpXXXVVwHHOOOMMJScna968ef6/B/Zk2rRpGjhwoM4555x61zscDo0ZM0azZ8/2L/ON8u7uvffeU1paWsB1o3a7Xddcc42Ki4u1bNkySVV/7xiGUe97a+jfO0VFRcrLy1N+fn6DXg+geTGNEECLNmvWLB166KGy2WxKTU1Vz549A06qfNdDPPHEE9q4cWPACbBvqqFUNf2wZ8+estka96+9Hj16BDw3DEPdu3f3Xw/iu9Zk7Nixe9xHQUFBwPSovLy8Ovvd3fr162Wa5h632326n+/EbPeAs/s+CwoKlJKSUu/6bdu2BTz/8MMP1bZt273WubvbbrtN6enpuuKKK+pce7J+/XqtXbt2j/vc/fgHY/369frxxx/3+1i33XZbvSfM+zvF67nnntNzzz1XZ3nnzp0Dnr/77ru66667lJWVFXCd2v6emP/888/7/J34vpO+ELu7hISEgOcej0f5+flKTEzcrxquvvrqgH/UGDd
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0EAAAIjCAYAAADFthA8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAACGn0lEQVR4nO3de1yTdf8/8NcYZ0UOngamxjwWukQrQwqtTPOUnTy1UtOyG6gsu++7zOGiILO0M1BqYQmlWZnZbZn6TdcBz9IU8zzTFDyEHAQ5uF2/P/xtbTJg4OBi1/V6Ph48ZNf13va+2MDrvc/nen8UgiAIICIiIiIikgkvsRMgIiIiIiJqTiyCiIiIiIhIVlgEERERERGRrLAIIiIiIiIiWWERREREREREssIiiIiIiIiIZIVFEBERERERyQqLICIiIiIikhUWQUREREREJCssgoiIiIiISFZYBBERERERkaywCCKSsQ8++ADDhw9Hx44d4ePjA5VKhcGDB+PTTz+FxWIROz0iIiKiJqEQBEEQOwkiEkdMTAzCw8Nxxx13oE2bNigqKsKWLVuwfPlyTJgwAZ9//rnYKRIRERG5HYsgIhmrrq6Gj49Pje1PPfUU3n//fZhMJlx77bXNnxgRERFRE+J0OCIZc1YAAbAVPl5e//yJWL16NUaNGoWIiAj4+fmhW7dueOWVV2A2mx3uO2TIECgUCttXu3btMGrUKOzdu9chTqFQ4KWXXnLY9sYbb0ChUGDIkCEO2ysqKvDSSy+hZ8+e8Pf3R3h4OO6//34cOXIEAHDs2DEoFAosXbrU4X6JiYlQKBSYOnWqbdvSpUuhUCjg6+uLs2fPOsTn5OTY8t6xY4fDvpUrV2LAgAEICAhAu3bt8PDDD+PkyZM1fnb79+/H+PHj0b59ewQEBKBXr16YM2cOAOCll15y+Nk4+9q0aZPt59inT58aj++K2u67YMECKBQKHDt2zGF7UVERnnnmGXTu3Bl+fn7o3r075s+f7zAl0vozXrBgQY3H7dOnj8NrtmnTJigUCnz55Ze15jh16lSXC+z09HRERUXBz88PERERSExMRFFRkcPx1vdzrcuQIUNqvOdSU1Ph5eWFzz77zGG7q+8DALXmYv/zd/X3wPreudK1117r8P4GXHs9AcBiseCdd95B37594e/vj/bt2+Puu++2vffr+5la87O+3tYvPz8/9OzZE/PmzYP956x//vknEhIS0KtXLwQEBKBt27YYN25cjfdjberLt76cr3y/LViwAIMGDULbtm0REBCAAQMG1Pqetf7dqO1n0JCffUN+l4io6XiLnQARia+oqAiXLl1CaWkpdu7ciQULFmDixIno0qWLLWbp0qVo3bo1Zs2ahdatW+P//u//MHfuXJSUlOCNN95weLzevXtjzpw5EAQBR44cwZtvvomRI0fi+PHjdeYwb968GtvNZjNGjx6NjRs3YuLEiZg5cyZKS0uxfv167N27F926dXP6eIcPH8bixYtrfT6lUomsrCw8++yztm2ZmZnw9/dHRUWFQ+zSpUvx6KOP4qabbsK8efNw+vRpvPPOO/j111+xe/duhISEAACMRiNuu+02+Pj4YMaMGbj22mtx5MgRrFmzBqmpqbj//vvRvXt32+M+++yzuO666zBjxgzbtuuuu67WnJtCeXk5Bg8ejJMnT+KJJ55Aly5d8Ntvv2H27NnIz8/H22+/3az5XOmll15CcnIyhg4divj4eBw4cAAZGRnYvn07fv31V/j4+GDOnDl47LHHAADnzp3Ds88+ixkzZuC2225r1HNmZmZCp9Nh4cKFeOihh2zbXX0f2Lvvvvtw//33AwB+/vlnLFq0qM7nru33wFUNeT2nT5+OpUuXYsSIEXjsscdw6dIl/Pzzz9iyZQtuvPFGLFu2zBZrzf2tt95Cu3btAAAdO3Z0eO4XX3wR1113HS5evIgVK1bgxRdfRIcOHTB9+nQAwPbt2/Hbb79h4sSJuOaaa3Ds2DFkZGRgyJAh2LdvHwIDA+s8tvrytbrrrrswefJkh/suXLgQ58+fd9j2zjvv4J577oFWq0VVVRWWL1+OcePG4bvvvsOoUaOc5mB//KmpqY3+2RNRCyAQkez16tVLAGD7mjx5slBdXe0QU15eXuN+TzzxhBAYGChUVFTYtg0ePFgYPHiwQ9yLL74oABDOnDlj2wZA0Ov1ttv//e9/hQ4dOggDBgxwuP/HH38sABDefPPNGs9vsVgEQRAEk8kkABAyMzNt+8aPHy/06dNH6Ny5szBlyhTb9szMTAGAMGnSJKFv37627WVlZUKbNm2Ehx56SAAgbN++XRAEQaiqqhI6dOgg9OnTR7h48aIt/rvvvhMACHPnzrVti4uLE4KCgoQ///zTaZ5X6tq1q0Nu9gYPHixERUU53Vef2u77xhtvCAAEk8lk2/bKK68IrVq1Eg4ePOgQ+8ILLwhKpVI4fvy4IAj//IzfeOONGo8bFRXl8Jr99NNPAgBh5cqVteY4ZcoUoWvXrnUex5kzZwRfX19h2LBhgtlstm1///33BQDCxx9/XOM+zt4L9bF/z/7vf/8TvL29heeee84hpiHvA0EQhOrqagGAkJycbNtmfe/Z//xd/T1ITk4WANR4L135HnL19fy///s/AYDw9NNP1/h5OHu/Osvdyvp6//TTT7ZtFRUVgpeXl5CQkGDb5uxvSE5OjgBA+PTTT2vss+dqvgCExMTEGjGjRo2q8X67Mp+qqiqhT58+wh133FHj/osXLxYAOPxuX/m3ril+l4io6UhmOpzBYMCYMWMQEREBhUKBb775psGPIQgCFixYgJ49e8LPzw+dOnWq8UkPkRRlZmZi/fr1yM7OxvTp05Gdne0wOgEAAQEBtu9LS0tx7tw53HbbbSgvL8f+/fsdYqurq3Hu3DmcPXsWOTk5WLVqFTQaje0T1CudPHkS7733HpKSktC6dWuHfV999RXatWuHp556qsb9apvmtHPnTqxcuRLz5s1zmNJn75FHHsH+/fttU2m++uorBAcH484773SI27FjB86cOYOEhAT4+/vbto8aNQq9e/fG//73PwDA2bNnYTAYMG3aNIcRtLryrI/ZbMa5c+dw7tw5VFVVNeox6rNy5UrcdtttCA0NtT3XuXPnMHToUJjNZhgMBof48vJyh7hz587VmBJpZX2f2E9da4gNGzagqqoKzzzzjMPr+Pjjj6NNmza2n727bNu2DePHj8cDDzxQY3TT1feBlfX18vPzc/n56/o96NChAwDgr7/+qvMxXH09v/rqKygUCuj1+hqP0dj3a3FxMc6dO4fjx4/j9ddfh8ViwR133GHbb/83pLq6Gn///Te6d++OkJAQ7Nq1q87Hbop87fM5f/48iouLcdtttznNxZXXsyl/l4jI/SQzHa6srAw33HADpk2bZpt60FAzZ87Ejz/+iAULFqBv374oLCxEYWGhmzMlanliYmJs3z/00ENQq9WYM2cOpk+fjtjYWABAXl4edDod/u///g8lJSUO9y8uLna4/dtvv6F9+/a22z169MA333xT68mKXq9HREQEnnjiiRpz8o8cOYJevXrB29v1P1cvvPACbrvtNowePRpPPvmk05j27dtj1KhR+Pjjj3HjjTfi448/xpQpU2oUTX/++ScAoFevXjUeo3fv3vjll18AAEePHgWARl/H48z+/fttP0cvLy90794der3eYYrW1Tp06BCMRqPD62XvzJkzDrf1er3TE9Erp0YBwLRp02zft27dGmPGjMFbb73lNNaZ2n72vr6+UKvVtv3ucPLkSYwaNQplZWX4+++/a7xXXX0fWFkLvyuLmbrU9XsQExMDhUKB2bNnIyUlxfa4V17n4+rreeTIEURERCAsLMzl/Opz77332r738vKCTqfDAw88YNt28eJFzJs3D5mZmTh58qTD9UJX/g25UlPk+9133yElJQW5ubmorKy0bXf2d8qV17Mpf5eIyP0kUwSNGDECI0aMqHV/ZWUl5syZg88//xxFRUXo06cP5s+fb7sA8Y8//kBGRgb27t1r+08uMjKyOVInanEefPBBzJkzB1u3bkV
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"from typing import Tuple\n",
"import pandas as pd\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"\n",
"# Создание целевого признака\n",
"median_price = df['price'].median()\n",
"df['above_median_price'] = np.where(df['price'] > median_price, 1, 0)\n",
"\n",
"# Разделение на признаки и целевую переменную\n",
"X = df.drop(columns=['id', 'date', 'price', 'above_median_price'])\n",
"y = df['above_median_price']\n",
"\n",
"# Примерная категоризация\n",
"df['price_category'] = pd.cut(df['price'], bins=[0, 300000, 700000, np.inf], labels=[0, 1, 2])\n",
"\n",
"# Выбор признаков и целевых переменных\n",
"X = df.drop(columns=['id', 'date', 'price', 'price_category'])\n",
"\n",
"\n",
"def split_stratified_into_train_val_test(\n",
" df_input,\n",
" stratify_colname=\"y\",\n",
" frac_train=0.6,\n",
" frac_val=0.15,\n",
" frac_test=0.25,\n",
" random_state=None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if frac_train + frac_val + frac_test != 1.0:\n",
" raise ValueError(\n",
" \"fractions %f, %f, %f do not add up to 1.0\"\n",
" % (frac_train, frac_val, frac_test)\n",
" )\n",
" \n",
" if stratify_colname not in df_input.columns:\n",
" raise ValueError(\"%s is not a column in the dataframe\" % (stratify_colname))\n",
" X = df_input # Contains all columns.\n",
" y = df_input[\n",
" [stratify_colname]\n",
" ] # Dataframe of just the column on which to stratify.\n",
" \n",
" # Split original dataframe into train and temp dataframes.\n",
" df_train, df_temp, y_train, y_temp = train_test_split(\n",
" X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state\n",
" )\n",
"\n",
" if frac_val <= 0:\n",
" assert len(df_input) == len(df_train) + len(df_temp)\n",
" return df_train, pd.DataFrame(), df_temp, y_train, pd.DataFrame(), y_temp\n",
" # Split the temp dataframe into val and test dataframes.\n",
" relative_frac_test = frac_test / (frac_val + frac_test)\n",
"\n",
" df_val, df_test, y_val, y_test = train_test_split(\n",
" df_temp,\n",
" y_temp,\n",
" stratify=y_temp,\n",
" test_size=relative_frac_test,\n",
" random_state=random_state,\n",
" )\n",
"\n",
" assert len(df_input) == len(df_train) + len(df_val) + len(df_test)\n",
" return df_train, df_val, df_test, y_train, y_val, y_test\n",
"\n",
"X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n",
" df, stratify_colname=\"above_median_price\", frac_train=0.80, frac_val=0, frac_test=0.20, random_state=42\n",
")\n",
"\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)\n",
"\n",
"\n",
"# Проверка преобразования\n",
"print(df.dtypes)\n",
"\n",
"# Визуализация распределения цен\n",
"plt.figure(figsize=(10, 6))\n",
"sns.histplot(df['price'], bins=50, kde=True)\n",
"plt.title('Распределение цен на недвижимость')\n",
"plt.xlabel('Цена')\n",
"plt.ylabel('Частота')\n",
"plt.show()\n",
"\n",
"# Визуализация зависимости между ценой и количеством спален\n",
"plt.figure(figsize=(10, 6))\n",
"sns.boxplot(x='bedrooms', y='price', data=df)\n",
"plt.title('Зависимость цены от количества спален')\n",
"plt.xlabel('Количество спален')\n",
"plt.ylabel('Цена')\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Построение конвейеров предобработки \n",
"Создадим пайплайн для числовых и категориальных данных. \n",
"\n",
"preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n",
"\n",
"preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n",
"\n",
"features_preprocessing -- трансформер для предобработки признаков\n",
"\n",
"features_engineering -- трансформер для конструирования признаков\n",
"\n",
"drop_columns -- трансформер для удаления колонок\n",
"\n",
"pipeline_end -- основной конвейер предобработки данных и конструирования признаков"
]
},
{
"cell_type": "code",
"execution_count": 150,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.discriminant_analysis import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"pipeline_end = StandardScaler()\n",
"\n",
"\n",
"# Построение конвейеров предобработки\n",
"\n",
"class HouseFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" def fit(self, X, y=None):\n",
" return self\n",
" def transform(self, X, y=None):\n",
" # Создание новых признаков\n",
" X = X.copy()\n",
" X[\"Living_area_to_Lot_ratio\"] = X[\"sqft_living\"] / X[\"sqft_lot\"]\n",
" return X\n",
" def get_feature_names_out(self, features_in):\n",
" # Добавление имен новых признаков\n",
" new_features = [\"Living_area_to_Lot_ratio\"]\n",
" return np.append(features_in, new_features, axis=0)\n",
"\n",
"\n",
"# Обработка числовых данных. Числовой конвейр: заполнение пропущенных значений медианой и стандартизация\n",
"preprocessing_num_class = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='median')),\n",
" ('scaler', StandardScaler())\n",
"])\n",
"\n",
"preprocessing_cat_class = Pipeline(steps=[\n",
" ('imputer', SimpleImputer(strategy='most_frequent')),\n",
" ('onehot', OneHotEncoder(handle_unknown='ignore'))\n",
"])\n",
"\n",
"columns_to_drop = [\"date\"]\n",
"numeric_columns = [\"sqft_living\", \"sqft_lot\", \"above_median_price\"]\n",
"cat_columns = []\n",
"\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"prepocessing_num\", preprocessing_num_class, numeric_columns),\n",
" (\"prepocessing_cat\", preprocessing_cat_class, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" ('preprocessing_cat', preprocessing_cat_class, [\"price_category\"]),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"pipeline_end = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"custom_features\", HouseFeatures()),\n",
" (\"drop_columns\", drop_columns),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Демонстрация работы конвейра для предобработки данных при классификации**"
]
},
{
"cell_type": "code",
"execution_count": 151,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>above_median_price</th>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>price_category</th>\n",
" <th>Living_area_to_Lot_ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>20962</th>\n",
" <td>-1.360742</td>\n",
" <td>-0.262132</td>\n",
" <td>-0.994693</td>\n",
" <td>1278000210</td>\n",
" <td>110000.0</td>\n",
" <td>2</td>\n",
" <td>1.00</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1968</td>\n",
" <td>2007</td>\n",
" <td>98001</td>\n",
" <td>47.2655</td>\n",
" <td>-122.244</td>\n",
" <td>828</td>\n",
" <td>5402</td>\n",
" <td>0</td>\n",
" <td>5.191063</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12284</th>\n",
" <td>0.794390</td>\n",
" <td>-0.094121</td>\n",
" <td>1.005335</td>\n",
" <td>2193300390</td>\n",
" <td>624000.0</td>\n",
" <td>4</td>\n",
" <td>3.25</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1130</td>\n",
" <td>1980</td>\n",
" <td>0</td>\n",
" <td>98052</td>\n",
" <td>47.6920</td>\n",
" <td>-122.099</td>\n",
" <td>2110</td>\n",
" <td>11250</td>\n",
" <td>1</td>\n",
" <td>-8.440052</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7343</th>\n",
" <td>0.837884</td>\n",
" <td>-0.272723</td>\n",
" <td>1.005335</td>\n",
" <td>4289900005</td>\n",
" <td>1535000.0</td>\n",
" <td>4</td>\n",
" <td>3.25</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>...</td>\n",
" <td>1030</td>\n",
" <td>1908</td>\n",
" <td>2003</td>\n",
" <td>98122</td>\n",
" <td>47.6147</td>\n",
" <td>-122.285</td>\n",
" <td>2130</td>\n",
" <td>4200</td>\n",
" <td>2</td>\n",
" <td>-3.072292</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14247</th>\n",
" <td>-0.782270</td>\n",
" <td>-0.196986</td>\n",
" <td>-0.994693</td>\n",
" <td>316000145</td>\n",
" <td>235000.0</td>\n",
" <td>4</td>\n",
" <td>1.00</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1941</td>\n",
" <td>0</td>\n",
" <td>98168</td>\n",
" <td>47.5054</td>\n",
" <td>-122.301</td>\n",
" <td>1280</td>\n",
" <td>7175</td>\n",
" <td>0</td>\n",
" <td>3.971201</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16670</th>\n",
" <td>1.011860</td>\n",
" <td>0.024330</td>\n",
" <td>1.005335</td>\n",
" <td>629400480</td>\n",
" <td>775000.0</td>\n",
" <td>4</td>\n",
" <td>2.75</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1996</td>\n",
" <td>0</td>\n",
" <td>98075</td>\n",
" <td>47.5895</td>\n",
" <td>-121.994</td>\n",
" <td>3330</td>\n",
" <td>12333</td>\n",
" <td>2</td>\n",
" <td>41.589045</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>88</th>\n",
" <td>-0.510432</td>\n",
" <td>-0.324180</td>\n",
" <td>-0.994693</td>\n",
" <td>1332700270</td>\n",
" <td>215000.0</td>\n",
" <td>2</td>\n",
" <td>2.25</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1979</td>\n",
" <td>0</td>\n",
" <td>98056</td>\n",
" <td>47.5180</td>\n",
" <td>-122.194</td>\n",
" <td>1950</td>\n",
" <td>2025</td>\n",
" <td>0</td>\n",
" <td>1.574534</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15031</th>\n",
" <td>1.044481</td>\n",
" <td>-0.314813</td>\n",
" <td>1.005335</td>\n",
" <td>7129303070</td>\n",
" <td>735000.0</td>\n",
" <td>4</td>\n",
" <td>2.75</td>\n",
" <td>2.0</td>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1966</td>\n",
" <td>0</td>\n",
" <td>98118</td>\n",
" <td>47.5188</td>\n",
" <td>-122.256</td>\n",
" <td>2620</td>\n",
" <td>2433</td>\n",
" <td>2</td>\n",
" <td>-3.317784</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5234</th>\n",
" <td>-0.456065</td>\n",
" <td>-0.136611</td>\n",
" <td>1.005335</td>\n",
" <td>2432000130</td>\n",
" <td>675000.0</td>\n",
" <td>3</td>\n",
" <td>1.75</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1956</td>\n",
" <td>0</td>\n",
" <td>98033</td>\n",
" <td>47.6503</td>\n",
" <td>-122.198</td>\n",
" <td>2090</td>\n",
" <td>9549</td>\n",
" <td>1</td>\n",
" <td>3.338418</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19980</th>\n",
" <td>0.566046</td>\n",
" <td>1.239169</td>\n",
" <td>-0.994693</td>\n",
" <td>774100475</td>\n",
" <td>415000.0</td>\n",
" <td>3</td>\n",
" <td>2.75</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2009</td>\n",
" <td>0</td>\n",
" <td>98014</td>\n",
" <td>47.7185</td>\n",
" <td>-121.405</td>\n",
" <td>1740</td>\n",
" <td>64626</td>\n",
" <td>1</td>\n",
" <td>0.456795</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3671</th>\n",
" <td>0.370323</td>\n",
" <td>4.836825</td>\n",
" <td>1.005335</td>\n",
" <td>8847400115</td>\n",
" <td>590000.0</td>\n",
" <td>3</td>\n",
" <td>2.00</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2005</td>\n",
" <td>0</td>\n",
" <td>98010</td>\n",
" <td>47.3666</td>\n",
" <td>-121.978</td>\n",
" <td>3180</td>\n",
" <td>212137</td>\n",
" <td>1</td>\n",
" <td>0.076563</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>17290 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" sqft_living sqft_lot above_median_price id price \\\n",
"20962 -1.360742 -0.262132 -0.994693 1278000210 110000.0 \n",
"12284 0.794390 -0.094121 1.005335 2193300390 624000.0 \n",
"7343 0.837884 -0.272723 1.005335 4289900005 1535000.0 \n",
"14247 -0.782270 -0.196986 -0.994693 316000145 235000.0 \n",
"16670 1.011860 0.024330 1.005335 629400480 775000.0 \n",
"... ... ... ... ... ... \n",
"88 -0.510432 -0.324180 -0.994693 1332700270 215000.0 \n",
"15031 1.044481 -0.314813 1.005335 7129303070 735000.0 \n",
"5234 -0.456065 -0.136611 1.005335 2432000130 675000.0 \n",
"19980 0.566046 1.239169 -0.994693 774100475 415000.0 \n",
"3671 0.370323 4.836825 1.005335 8847400115 590000.0 \n",
"\n",
" bedrooms bathrooms floors waterfront view ... sqft_basement \\\n",
"20962 2 1.00 1.0 0 0 ... 0 \n",
"12284 4 3.25 1.0 0 0 ... 1130 \n",
"7343 4 3.25 2.0 0 3 ... 1030 \n",
"14247 4 1.00 1.5 0 0 ... 0 \n",
"16670 4 2.75 2.0 0 0 ... 0 \n",
"... ... ... ... ... ... ... ... \n",
"88 2 2.25 2.0 0 0 ... 0 \n",
"15031 4 2.75 2.0 1 4 ... 0 \n",
"5234 3 1.75 1.0 0 0 ... 0 \n",
"19980 3 2.75 1.5 0 0 ... 0 \n",
"3671 3 2.00 1.5 0 0 ... 0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"20962 1968 2007 98001 47.2655 -122.244 828 \n",
"12284 1980 0 98052 47.6920 -122.099 2110 \n",
"7343 1908 2003 98122 47.6147 -122.285 2130 \n",
"14247 1941 0 98168 47.5054 -122.301 1280 \n",
"16670 1996 0 98075 47.5895 -121.994 3330 \n",
"... ... ... ... ... ... ... \n",
"88 1979 0 98056 47.5180 -122.194 1950 \n",
"15031 1966 0 98118 47.5188 -122.256 2620 \n",
"5234 1956 0 98033 47.6503 -122.198 2090 \n",
"19980 2009 0 98014 47.7185 -121.405 1740 \n",
"3671 2005 0 98010 47.3666 -121.978 3180 \n",
"\n",
" sqft_lot15 price_category Living_area_to_Lot_ratio \n",
"20962 5402 0 5.191063 \n",
"12284 11250 1 -8.440052 \n",
"7343 4200 2 -3.072292 \n",
"14247 7175 0 3.971201 \n",
"16670 12333 2 41.589045 \n",
"... ... ... ... \n",
"88 2025 0 1.574534 \n",
"15031 2433 2 -3.317784 \n",
"5234 9549 1 3.338418 \n",
"19980 64626 1 0.456795 \n",
"3671 212137 1 0.076563 \n",
"\n",
"[17290 rows x 23 columns]"
]
},
"execution_count": 151,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.fit_transform(X_train)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"preprocessed_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Формирование набора моделей для классификации\n",
"\n",
"logistic -- логистическая регрессия\n",
"\n",
"ridge -- гребневая регрессия\n",
"\n",
"decision_tree -- дерево решений\n",
"\n",
"knn -- k-ближайших соседей\n",
"\n",
"naive_bayes -- наивный Байесовский классификатор\n",
"\n",
"gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n",
"\n",
"random_forest -- метод случайного леса (набор деревьев решений)\n",
"\n",
"mlp -- многослойный персептрон (нейронная сеть)"
]
},
{
"cell_type": "code",
"execution_count": 152,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree, svm\n",
"\n",
"class_models = {\n",
" \"logistic\": {\"model\": linear_model.LogisticRegression(max_iter=150)},\n",
" \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n",
" \"ridge\": {\"model\": linear_model.LogisticRegression(max_iter=150, solver='lbfgs', penalty=\"l2\", class_weight=\"balanced\")},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeClassifier(max_depth=5, min_samples_split=10, random_state=random_state)\n",
" },\n",
"\n",
" \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n",
" \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n",
" \"gradient_boosting\": {\n",
" \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n",
" },\n",
"\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestClassifier(\n",
" max_depth=5, class_weight=\"balanced\", random_state=random_state\n",
" )\n",
" },\n",
"\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPClassifier(\n",
" hidden_layer_sizes=(7,),\n",
" max_iter=200,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Обучение моделей на обучающем наборе данных и оценка на тестовом**"
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"Model: ridge\n",
"Model: decision_tree\n",
"Model: knn\n",
"Model: naive_bayes\n",
"Model: gradient_boosting\n",
"Model: random_forest\n",
"Model: mlp\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" model = class_models[model_name][\"model\"]\n",
"\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n",
"\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n",
" y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n",
"\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"probs\"] = y_test_probs\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n",
" y_train, y_train_predict, zero_division=1\n",
" )\n",
" class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n",
" y_test, y_test_predict, zero_division=1\n",
" )\n",
" class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n",
" y_train, y_train_predict\n",
" )\n",
" class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n",
" y_test, y_test_probs\n",
" )\n",
" class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n",
" class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n",
" class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n",
" y_test, y_test_predict\n",
" )\n",
" class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n",
" y_test, y_test_predict\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Сводная таблица оценок качества для использованных моделей классификации¶\n",
"Матрица неточностей**"
]
},
{
"cell_type": "code",
"execution_count": 154,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0cAAAQ9CAYAAACSpDaqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdeVwU5eMH8M8ssIDcqFyCeKAo3mIZeSeBZB5p+fNKUdTyq+aRZ6aClZiWZx5peX3T1C4rzYNUxJQsDzwISRRvDhUBQYFld35/8GVsAxYWFheYz/v7mtfXnedh5pk15uMzz8wzgiiKIoiIiIiIiGROYewGEBERERERVQXsHBEREREREYGdIyIiIiIiIgDsHBEREREREQFg54iIiIiIiAgAO0dEREREREQA2DkiIiIiIiICwM4RERERERERAHaOiIiIiIiIALBzROW0ZcsWCIKA69evV8r2r1+/DkEQsGXLFoNsLzIyEoIgIDIy0iDbIyIiqilCQ0MhCEKZ6gqCgNDQ0MptEJERsXNENcratWsN1qEiIiIiInkxNXYDiIrj6emJJ0+ewMzMTK+fW7t2LerUqYPg4GCt9V27dsWTJ0+gVCoN2EoiIqLq7/3338fs2bON3QyiKoGdI6qSBEGAhYWFwbanUCgMuj0iIqKaIDs7G1ZWVjA15T8JiQDeVkcGtHbtWrRo0QLm5uZwc3PDhAkTkJ6eXqTemjVr0KhRI1haWuL555/H8ePH0b17d3Tv3l2qU9wzR8nJyRg1ahTc3d1hbm4OV1dX9OvXT3ruqUGDBoiNjcWxY8cgCAIEQZC2WdIzR6dOncIrr7wCBwcHWFlZoXXr1li5cqVhvxgiIqIqoPDZor/++gtDhw6Fg4MDOnfuXOwzR7m5uZg6dSrq1q0LGxsb9O3bF7dv3y52u5GRkejQoQMsLCzQuHFjfP755yU+x/TVV1/B19cXlpaWcHR0xODBg3Hr1q1KOV6i8uBlAjKI0NBQhIWFwd/fH+PHj0d8fDzWrVuHP//8EydOnJBuj1u3bh0mTpyILl26YOrUqbh+/Tr69+8PBwcHuLu769zHwIEDERsbi0mTJqFBgwZITU1FREQEbt68iQYNGmDFihWYNGkSrK2tMXfuXACAs7NziduLiIjAq6++CldXV0yePBkuLi6Ii4vD3r17MXnyZMN9OURERFXIG2+8gSZNmmDRokUQRRGpqalF6owZMwZfffUVhg4dihdffBFHjhxB7969i9Q7d+4cevXqBVdXV4SFhUGtVmPhwoWoW7dukbofffQR5s2bh0GDBmHMmDG4d+8eVq9eja5du+LcuXOwt7evjMMl0o9IVA6bN28WAYiJiYliamqqqFQqxYCAAFGtVkt1PvvsMxGAuGnTJlEURTE3N1esXbu2+Nxzz4kqlUqqt2XLFhGA2K1bN2ldYmKiCEDcvHmzKIqi+PDhQxGAuHTpUp3tatGihdZ2Ch09elQEIB49elQURVHMz88XGzZsKHp6eooPHz7UqqvRaMr+RRAREVUTCxYsEAGIQ4YMKXZ9oZiYGBGA+J///Eer3tChQ0UA4oIFC6R1ffr0EWvVqiXeuXNHWnflyhXR1NRUa5vXr18XTUxMxI8++khrmxcvXhRNTU2LrCcyFt5WRxX266+/Ii8vD1OmTIFC8fQ/qbFjx8LW1hb79u0DAJw+fRoPHjzA2LFjte5tHjZsGBwcHHTuw9LSEkqlEpGRkXj48GGF23zu3DkkJiZiypQpRa5UlXU6UyIiouro7bff1ln+yy+/AADeeecdrfVTpkzR+qxWq/Hrr7+if//+cHNzk9Z7eXkhKChIq+73338PjUaDQYMG4f79+9Li4uKCJk2a4OjRoxU4IiLD4W11VGE3btwAAHh7e2utVyqVaNSokVRe+P9eXl5a9UxNTdGgQQOd+zA3N8fHH3+Md999F87OznjhhRfw6quvYsSIEXBxcdG7zVevXgUAtGzZUu+fJSIiqs4aNmyos/zGjRtQKBRo3Lix1vp/53xqaiqePHlSJNeBoll/5coViKKIJk2aFLtPfWenJaos7BxRtTFlyhT06dMHe/bswcGDBzFv3jyEh4fjyJEjaNeunbGbR0REVC1YWlo+831qNBoIgoD9+/fDxMSkSLm1tfUzbxNRcXhbHVWYp6cnACA+Pl5rfV5eHhITE6Xywv9PSEjQqpefny/NOFeaxo0b491338WhQ4dw6dIl5OXl4dNPP5XKy3pLXOHVsEuXLpWpPhERkVx4enpCo9FId1kU+nfOOzk5wcLCokiuA0WzvnHjxhBFEQ0bNoS/v3+R5YUXXjD8gRCVAztHVGH+/v5QKpVYtWoVRFGU1n/55ZfIyMiQZrfp0KEDateujY0bNyI/P1+qt3379lKfI3r8+DFycnK01jVu3Bg2NjbIzc2V1llZWRU7ffi/tW/fHg0bNsSKFSuK1P/nMRAREclN4fNCq1at0lq/YsUKrc8mJibw9/fHnj17cPfuXWl9QkIC9u/fr1V3wIABMDExQVhYWJGcFUURDx48MOAREJUfb6ujCqtbty7mzJmDsLAw9OrVC3379kV8fDzWrl2L5557DsOHDwdQ8AxSaGgoJk2ahJdeegmDBg3C9evXsWXLFjRu3FjnqM/ff/+Nnj17YtCgQfDx8YGpqSl++OEHpKSkYPDgwVI9X19frFu3Dh9++CG8vLzg5OSEl156qcj2FAoF1q1bhz59+qBt27YYNWoUXF1dcfnyZcTGxuLgwYOG/6KIiIiqgbZt22LIkCFYu3YtMjIy8OKLL+Lw4cPFjhCFhobi0KFD6NSpE8aPHw+1Wo3PPvsMLVu2RExMjFSvcePG+PDDDzFnzhzpNR42NjZITEzEDz/8gHHjxmH69OnP8CiJisfOERlEaGgo6tati88++wxTp06Fo6Mjxo0bh0WLFmk9ZDlx4kSIoohPP/0U06dPR5s2bfDTTz/hnXfegYWFRYnb9/DwwJAhQ3D48GH897//hampKZo1a4bdu3dj4MCBUr358+fjxo0bWLJkCR49eoRu3boV2zkCgMDAQBw9ehRhYWH49NNPodFo0LhxY4wdO9ZwXwwREVE1tGnTJtStWxfbt2/Hnj178NJLL2Hfvn3w8PDQqufr64v9+/dj+vTpmDdvHjw8PLBw4ULExcXh8uXLWnVnz56Npk2bYvny5QgLCwNQkO8BAQHo27fvMzs2Il0EkfcQkZFpNBrUrVsXAwYMwMaNG43dHCIiIqqg/v37IzY2FleuXDF2U4j0wmeO6JnKyckpcq/xtm3bkJaWhu7duxunUURERFRuT5480fp85coV/PLLL8x1qpY4ckTPVGRkJKZOnYo33ngDtWvXxtmzZ/Hll1+iefPmOHPmDJRKpbGbSERERHpwdXVFcHCw9G7DdevWITc3F+fOnSvxvUZEVRWfOaJnqkGDBvDw8MCqVauQlpYGR0dHjBgxAosXL2bHiIiIqBrq1asXvv76ayQnJ8Pc3Bx+fn5YtGgRO0ZULXHkiIiIiIiICHzmiIiIiIiICAA7R0RERERERAD4zFGZaDQa3L17FzY2NjpfVEpUE4miiEePHsHNzQ0KhWGvp+Tk5CAvL6/UekqlUud7sIhIfpjNJGfM5srDzlEZ3L17t8hLz4jk5tatW3B3dzfY9nJyctDQ0xrJqepS67q4uCAxMbFGnoSJqHyYzUTM5srAzlEZ2NjYAABunG0AW2veiWgMrzVtZewmyFY+VPgNv0i/B4aSl5eH5FQ1Ek57wNam5N+rzEcaeHW4hby8vBp3Aiai8mM2Gx+z2XiYzZWHnaMyKByut7VW6PwPhSqPqWBm7CbI1//ms6ys21asbQRY25S8bQ14uwwRFcVsNj5msxExmysNO0dEZFQqUQ2VjjcKqETNM2wNERERyTmb2TkiIqPSQIQGJZ+AdZURERGR4ck5m9k5IiKj0kCEWqYnYCIioqpIztnMzhERGZVK1ECl4xxbk4fuiYiIqiI5ZzM7R0RkVJr/LbrKiYiI6NmRczazc0RERqUuZeheVxkREREZnpyzmZ0jIjIqlYhShu6fXVuIiIhI3tnMzhE
"text/plain": [
"<Figure size 1200x1000 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from sklearn.metrics import ConfusionMatrixDisplay\n",
"import matplotlib.pyplot as plt\n",
"\n",
"_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n",
"for index, key in enumerate(class_models.keys()):\n",
" c_matrix = class_models[key][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
" disp.ax_.set_title(key)\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Значение 2173 в желтом квадрате представляет собой количество объектов, относимых к классу \"Less\", которые модель правильно классифицировала. Это свидетельствует о высоком уровне точности в идентификации этого класса. Значение 2150 в жёлтом нижнем правом квадрате указывает на количество правильно классифицированных объектов класса \"More\". Хотя это также является положительным результатом, мы можем заметить, что он местами ниже, чем для класса \"Less\", а местами и выше.\n",
"\n",
"Точность, полнота, верность (аккуратность), F-мера"
]
},
{
"cell_type": "code",
"execution_count": 155,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_c780a_row0_col0, #T_c780a_row0_col1, #T_c780a_row0_col2, #T_c780a_row0_col3, #T_c780a_row1_col0, #T_c780a_row1_col1, #T_c780a_row1_col2, #T_c780a_row1_col3, #T_c780a_row2_col0, #T_c780a_row2_col1, #T_c780a_row2_col2, #T_c780a_row2_col3, #T_c780a_row3_col0, #T_c780a_row3_col1, #T_c780a_row3_col2, #T_c780a_row3_col3, #T_c780a_row4_col0, #T_c780a_row4_col1, #T_c780a_row4_col2, #T_c780a_row4_col3, #T_c780a_row5_col0, #T_c780a_row5_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_c780a_row0_col4, #T_c780a_row0_col5, #T_c780a_row0_col6, #T_c780a_row0_col7, #T_c780a_row1_col4, #T_c780a_row1_col5, #T_c780a_row1_col6, #T_c780a_row1_col7, #T_c780a_row2_col4, #T_c780a_row2_col5, #T_c780a_row2_col6, #T_c780a_row2_col7, #T_c780a_row3_col4, #T_c780a_row3_col5, #T_c780a_row3_col6, #T_c780a_row3_col7, #T_c780a_row4_col4, #T_c780a_row4_col5, #T_c780a_row4_col6, #T_c780a_row4_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row5_col2 {\n",
" background-color: #6ccd5a;\n",
" color: #000000;\n",
"}\n",
"#T_c780a_row5_col3 {\n",
" background-color: #6ece58;\n",
" color: #000000;\n",
"}\n",
"#T_c780a_row5_col4 {\n",
" background-color: #c43e7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row5_col5 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row5_col6, #T_c780a_row5_col7 {\n",
" background-color: #ce4b75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row6_col0 {\n",
" background-color: #40bd72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row6_col1 {\n",
" background-color: #38b977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row6_col2 {\n",
" background-color: #7fd34e;\n",
" color: #000000;\n",
"}\n",
"#T_c780a_row6_col3 {\n",
" background-color: #75d054;\n",
" color: #000000;\n",
"}\n",
"#T_c780a_row6_col4 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row6_col5 {\n",
" background-color: #b42e8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row6_col6 {\n",
" background-color: #cc4977;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row6_col7 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row7_col0, #T_c780a_row7_col1, #T_c780a_row7_col2, #T_c780a_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_c780a_row7_col4, #T_c780a_row7_col5, #T_c780a_row7_col6, #T_c780a_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_c780a\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_c780a_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_c780a_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_c780a_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_c780a_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_c780a_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_c780a_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_c780a_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_c780a_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_c780a_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_c780a_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_c780a_row0_col2\" class=\"data row0 col2\" >0.999767</td>\n",
" <td id=\"T_c780a_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_c780a_row0_col4\" class=\"data row0 col4\" >0.999884</td>\n",
" <td id=\"T_c780a_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_c780a_row0_col6\" class=\"data row0 col6\" >0.999884</td>\n",
" <td id=\"T_c780a_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_c780a_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_c780a_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_c780a_row1_col2\" class=\"data row1 col2\" >0.999651</td>\n",
" <td id=\"T_c780a_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_c780a_row1_col4\" class=\"data row1 col4\" >0.999826</td>\n",
" <td id=\"T_c780a_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_c780a_row1_col6\" class=\"data row1 col6\" >0.999826</td>\n",
" <td id=\"T_c780a_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_c780a_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col5\" class=\"data row2 col5\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col6\" class=\"data row2 col6\" >1.000000</td>\n",
" <td id=\"T_c780a_row2_col7\" class=\"data row2 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_c780a_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col5\" class=\"data row3 col5\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col6\" class=\"data row3 col6\" >1.000000</td>\n",
" <td id=\"T_c780a_row3_col7\" class=\"data row3 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_c780a_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col5\" class=\"data row4 col5\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col6\" class=\"data row4 col6\" >1.000000</td>\n",
" <td id=\"T_c780a_row4_col7\" class=\"data row4 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row5\" class=\"row_heading level0 row5\" >naive_bayes</th>\n",
" <td id=\"T_c780a_row5_col0\" class=\"data row5 col0\" >1.000000</td>\n",
" <td id=\"T_c780a_row5_col1\" class=\"data row5 col1\" >1.000000</td>\n",
" <td id=\"T_c780a_row5_col2\" class=\"data row5 col2\" >0.786719</td>\n",
" <td id=\"T_c780a_row5_col3\" class=\"data row5 col3\" >0.793953</td>\n",
" <td id=\"T_c780a_row5_col4\" class=\"data row5 col4\" >0.893927</td>\n",
" <td id=\"T_c780a_row5_col5\" class=\"data row5 col5\" >0.897525</td>\n",
" <td id=\"T_c780a_row5_col6\" class=\"data row5 col6\" >0.880630</td>\n",
" <td id=\"T_c780a_row5_col7\" class=\"data row5 col7\" >0.885144</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_c780a_row6_col0\" class=\"data row6 col0\" >0.872486</td>\n",
" <td id=\"T_c780a_row6_col1\" class=\"data row6 col1\" >0.827473</td>\n",
" <td id=\"T_c780a_row6_col2\" class=\"data row6 col2\" >0.857774</td>\n",
" <td id=\"T_c780a_row6_col3\" class=\"data row6 col3\" >0.820930</td>\n",
" <td id=\"T_c780a_row6_col4\" class=\"data row6 col4\" >0.866917</td>\n",
" <td id=\"T_c780a_row6_col5\" class=\"data row6 col5\" >0.825815</td>\n",
" <td id=\"T_c780a_row6_col6\" class=\"data row6 col6\" >0.865068</td>\n",
" <td id=\"T_c780a_row6_col7\" class=\"data row6 col7\" >0.824189</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_c780a_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_c780a_row7_col0\" class=\"data row7 col0\" >0.687500</td>\n",
" <td id=\"T_c780a_row7_col1\" class=\"data row7 col1\" >0.615385</td>\n",
" <td id=\"T_c780a_row7_col2\" class=\"data row7 col2\" >0.002558</td>\n",
" <td id=\"T_c780a_row7_col3\" class=\"data row7 col3\" >0.003721</td>\n",
" <td id=\"T_c780a_row7_col4\" class=\"data row7 col4\" >0.503355</td>\n",
" <td id=\"T_c780a_row7_col5\" class=\"data row7 col5\" >0.503354</td>\n",
" <td id=\"T_c780a_row7_col6\" class=\"data row7 col6\" >0.005098</td>\n",
" <td id=\"T_c780a_row7_col7\" class=\"data row7 col7\" >0.007397</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x234a486d7f0>"
]
},
"execution_count": 155,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(\n",
" by=\"Accuracy_test\", ascending=False\n",
").style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Действительно, если модели, включая логистическую регрессию (есть исключения), ридж-регрессию (есть исключения), дерево решений, случайный лес и градиентный бустинг, показывают 100% точность на обучающей выборке, это может свидетельствовать о переобучении. Переобучение (overfitting) происходит, когда модель слишком хорошо подстраивается под обучающие данные, включая шум и случайные вариации, и начинает плохо работать на новых данных (например, на тестовой выборке). \n",
"\n",
"ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса"
]
},
{
"cell_type": "code",
"execution_count": 156,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_9fdc4_row0_col0, #T_9fdc4_row0_col1, #T_9fdc4_row1_col0, #T_9fdc4_row1_col1, #T_9fdc4_row2_col0, #T_9fdc4_row2_col1, #T_9fdc4_row3_col0, #T_9fdc4_row3_col1, #T_9fdc4_row4_col0, #T_9fdc4_row4_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_9fdc4_row0_col2, #T_9fdc4_row0_col3, #T_9fdc4_row0_col4, #T_9fdc4_row1_col2, #T_9fdc4_row1_col3, #T_9fdc4_row1_col4, #T_9fdc4_row2_col2, #T_9fdc4_row2_col3, #T_9fdc4_row2_col4, #T_9fdc4_row3_col2, #T_9fdc4_row3_col3, #T_9fdc4_row3_col4, #T_9fdc4_row4_col2, #T_9fdc4_row4_col3, #T_9fdc4_row4_col4, #T_9fdc4_row5_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9fdc4_row5_col0 {\n",
" background-color: #6ece58;\n",
" color: #000000;\n",
"}\n",
"#T_9fdc4_row5_col1 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_9fdc4_row5_col3 {\n",
" background-color: #c5407e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9fdc4_row5_col4 {\n",
" background-color: #c7427c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9fdc4_row6_col0 {\n",
" background-color: #4cc26c;\n",
" color: #000000;\n",
"}\n",
"#T_9fdc4_row6_col1 {\n",
" background-color: #75d054;\n",
" color: #000000;\n",
"}\n",
"#T_9fdc4_row6_col2 {\n",
" background-color: #c8437b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9fdc4_row6_col3, #T_9fdc4_row6_col4 {\n",
" background-color: #b42e8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9fdc4_row7_col0, #T_9fdc4_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_9fdc4_row7_col2, #T_9fdc4_row7_col3, #T_9fdc4_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_9fdc4\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_9fdc4_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_9fdc4_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_9fdc4_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_9fdc4_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_9fdc4_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row0\" class=\"row_heading level0 row0\" >logistic</th>\n",
" <td id=\"T_9fdc4_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row1\" class=\"row_heading level0 row1\" >ridge</th>\n",
" <td id=\"T_9fdc4_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_9fdc4_row2_col0\" class=\"data row2 col0\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row2_col1\" class=\"data row2 col1\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row2_col3\" class=\"data row2 col3\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row2_col4\" class=\"data row2 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row3\" class=\"row_heading level0 row3\" >gradient_boosting</th>\n",
" <td id=\"T_9fdc4_row3_col0\" class=\"data row3 col0\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row3_col1\" class=\"data row3 col1\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row3_col4\" class=\"data row3 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row4\" class=\"row_heading level0 row4\" >random_forest</th>\n",
" <td id=\"T_9fdc4_row4_col0\" class=\"data row4 col0\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row4_col1\" class=\"data row4 col1\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row4_col2\" class=\"data row4 col2\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row4_col3\" class=\"data row4 col3\" >1.000000</td>\n",
" <td id=\"T_9fdc4_row4_col4\" class=\"data row4 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row5\" class=\"row_heading level0 row5\" >naive_bayes</th>\n",
" <td id=\"T_9fdc4_row5_col0\" class=\"data row5 col0\" >0.897525</td>\n",
" <td id=\"T_9fdc4_row5_col1\" class=\"data row5 col1\" >0.885144</td>\n",
" <td id=\"T_9fdc4_row5_col2\" class=\"data row5 col2\" >0.999566</td>\n",
" <td id=\"T_9fdc4_row5_col3\" class=\"data row5 col3\" >0.794820</td>\n",
" <td id=\"T_9fdc4_row5_col4\" class=\"data row5 col4\" >0.812098</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_9fdc4_row6_col0\" class=\"data row6 col0\" >0.825815</td>\n",
" <td id=\"T_9fdc4_row6_col1\" class=\"data row6 col1\" >0.824189</td>\n",
" <td id=\"T_9fdc4_row6_col2\" class=\"data row6 col2\" >0.910823</td>\n",
" <td id=\"T_9fdc4_row6_col3\" class=\"data row6 col3\" >0.651606</td>\n",
" <td id=\"T_9fdc4_row6_col4\" class=\"data row6 col4\" >0.651627</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_9fdc4_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_9fdc4_row7_col0\" class=\"data row7 col0\" >0.503354</td>\n",
" <td id=\"T_9fdc4_row7_col1\" class=\"data row7 col1\" >0.007397</td>\n",
" <td id=\"T_9fdc4_row7_col2\" class=\"data row7 col2\" >0.497071</td>\n",
" <td id=\"T_9fdc4_row7_col3\" class=\"data row7 col3\" >0.001427</td>\n",
" <td id=\"T_9fdc4_row7_col4\" class=\"data row7 col4\" >0.012966</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2349e152690>"
]
},
"execution_count": 156,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"]\n",
"class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 157,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'logistic'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n",
"\n",
"display(best_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Вывод данных с ошибкой предсказания для оценки**"
]
},
{
"cell_type": "code",
"execution_count": 158,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Error items count: 0'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>Predicted</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>above_median_price</th>\n",
" <th>price_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"<p>0 rows × 24 columns</p>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: [id, Predicted, date, price, bedrooms, bathrooms, sqft_living, sqft_lot, floors, waterfront, view, condition, grade, sqft_above, sqft_basement, yr_built, yr_renovated, zipcode, lat, long, sqft_living15, sqft_lot15, above_median_price, price_category]\n",
"Index: []\n",
"\n",
"[0 rows x 24 columns]"
]
},
"execution_count": 158,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessing_result = pipeline_end.transform(X_test)\n",
"preprocessed_df = pd.DataFrame(\n",
" preprocessing_result,\n",
" columns=pipeline_end.get_feature_names_out(),\n",
")\n",
"\n",
"y_pred = class_models[best_model][\"preds\"]\n",
"\n",
"error_index = y_test[y_test[\"above_median_price\"] != y_pred].index.tolist()\n",
"display(f\"Error items count: {len(error_index)}\")\n",
"\n",
"error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n",
"error_df = X_test.loc[error_index].copy()\n",
"error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n",
"error_df.sort_index()"
]
},
{
"cell_type": "code",
"execution_count": 159,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>above_median_price</th>\n",
" <th>price_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6863</th>\n",
" <td>1124000050</td>\n",
" <td>20140729T000000</td>\n",
" <td>461000.0</td>\n",
" <td>4</td>\n",
" <td>1.0</td>\n",
" <td>1260</td>\n",
" <td>8505</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1951</td>\n",
" <td>0</td>\n",
" <td>98177</td>\n",
" <td>47.7181</td>\n",
" <td>-122.371</td>\n",
" <td>1480</td>\n",
" <td>8100</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" id date price bedrooms bathrooms sqft_living \\\n",
"6863 1124000050 20140729T000000 461000.0 4 1.0 1260 \n",
"\n",
" sqft_lot floors waterfront view ... sqft_basement yr_built yr_renovated \\\n",
"6863 8505 1.5 0 0 ... 0 1951 0 \n",
"\n",
" zipcode lat long sqft_living15 sqft_lot15 above_median_price \\\n",
"6863 98177 47.7181 -122.371 1480 8100 1 \n",
"\n",
" price_category \n",
"6863 1 \n",
"\n",
"[1 rows x 23 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>above_median_price</th>\n",
" <th>id</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>price_category</th>\n",
" <th>Living_area_to_Lot_ratio</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6863</th>\n",
" <td>-0.891006</td>\n",
" <td>-0.162689</td>\n",
" <td>1.005335</td>\n",
" <td>1.124000e+09</td>\n",
" <td>461000.0</td>\n",
" <td>4.0</td>\n",
" <td>1.0</td>\n",
" <td>1.5</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>...</td>\n",
" <td>0.0</td>\n",
" <td>1951.0</td>\n",
" <td>0.0</td>\n",
" <td>98177.0</td>\n",
" <td>47.7181</td>\n",
" <td>-122.371</td>\n",
" <td>1480.0</td>\n",
" <td>8100.0</td>\n",
" <td>1.0</td>\n",
" <td>5.476729</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" sqft_living sqft_lot above_median_price id price \\\n",
"6863 -0.891006 -0.162689 1.005335 1.124000e+09 461000.0 \n",
"\n",
" bedrooms bathrooms floors waterfront view ... sqft_basement \\\n",
"6863 4.0 1.0 1.5 0.0 0.0 ... 0.0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"6863 1951.0 0.0 98177.0 47.7181 -122.371 1480.0 \n",
"\n",
" sqft_lot15 price_category Living_area_to_Lot_ratio \n",
"6863 8100.0 1.0 5.476729 \n",
"\n",
"[1 rows x 23 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'predicted: 1 (proba: [0. 1.])'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'real: 1'"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model = class_models[best_model][\"pipeline\"]\n",
"\n",
"example_id = 6863\n",
"test = pd.DataFrame(X_test.loc[example_id, :]).T\n",
"test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n",
"display(test)\n",
"display(test_preprocessed)\n",
"result_proba = model.predict_proba(test)[0]\n",
"result = model.predict(test)[0]\n",
"real = int(y_test.loc[example_id].values[0])\n",
"display(f\"predicted: {result} (proba: {result_proba})\")\n",
"display(f\"real: {real}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Подбор гиперпараметров методом поиска по сетке**"
]
},
{
"cell_type": "code",
"execution_count": 160,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\MII\\laboratory\\mai\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'gini',\n",
" 'model__max_depth': 5,\n",
" 'model__max_features': 'sqrt',\n",
" 'model__n_estimators': 10}"
]
},
"execution_count": 160,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"optimized_model_type = \"random_forest\"\n",
"\n",
"random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n",
"\n",
"param_grid = {\n",
" \"model__n_estimators\": [10, 50, 100],\n",
" \"model__max_features\": [\"sqrt\", \"log2\"],\n",
" \"model__max_depth\": [5, 7, 10],\n",
" \"model__criterion\": [\"gini\", \"entropy\"],\n",
"}\n",
"\n",
"gs_optomizer = GridSearchCV(\n",
" estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n",
")\n",
"gs_optomizer.fit(X_train, y_train.values.ravel())\n",
"gs_optomizer.best_params_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Обучение модели с новыми гиперпараметрами"
]
},
{
"cell_type": "code",
"execution_count": 161,
"metadata": {},
"outputs": [],
"source": [
"optimized_model = ensemble.RandomForestClassifier(\n",
" random_state=random_state,\n",
" criterion=\"gini\",\n",
" max_depth=5,\n",
" max_features=\"log2\",\n",
" n_estimators=10,\n",
")\n",
"\n",
"result = {}\n",
"\n",
"result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n",
"result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n",
"result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n",
"result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n",
"\n",
"result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n",
"result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n",
"result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n",
"result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n",
"result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n",
"result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n",
"result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n",
"result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n",
"result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n",
"result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n",
"result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n",
"result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Формирование данных для оценки старой и новой версии модели**"
]
},
{
"cell_type": "code",
"execution_count": 162,
"metadata": {},
"outputs": [],
"source": [
"optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=class_models[optimized_model_type]\n",
")\n",
"optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n",
" data=result\n",
")\n",
"optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n",
"optimized_metrics = optimized_metrics.set_index(\"Name\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Оценка параметров старой и новой модели**"
]
},
{
"cell_type": "code",
"execution_count": 163,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_1e498_row0_col0, #T_1e498_row0_col1, #T_1e498_row0_col2, #T_1e498_row0_col3, #T_1e498_row1_col0, #T_1e498_row1_col1, #T_1e498_row1_col2, #T_1e498_row1_col3 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_1e498_row0_col4, #T_1e498_row0_col5, #T_1e498_row0_col6, #T_1e498_row0_col7, #T_1e498_row1_col4, #T_1e498_row1_col5, #T_1e498_row1_col6, #T_1e498_row1_col7 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_1e498\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_1e498_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_1e498_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_1e498_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_1e498_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_1e498_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_1e498_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_1e498_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_1e498_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" <th class=\"blank col5\" >&nbsp;</th>\n",
" <th class=\"blank col6\" >&nbsp;</th>\n",
" <th class=\"blank col7\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_1e498_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_1e498_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col5\" class=\"data row0 col5\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_1e498_row0_col7\" class=\"data row0 col7\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_1e498_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_1e498_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col5\" class=\"data row1 col5\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col6\" class=\"data row1 col6\" >1.000000</td>\n",
" <td id=\"T_1e498_row1_col7\" class=\"data row1 col7\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x234c7e2ab40>"
]
},
"execution_count": 163,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" \"Accuracy_train\",\n",
" \"Accuracy_test\",\n",
" \"F1_train\",\n",
" \"F1_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Precision_train\",\n",
" \"Precision_test\",\n",
" \"Recall_train\",\n",
" \"Recall_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Как для обучающей (Precision_train), так и для тестовой (Precision_test) выборки обе модели достигли идеальных значений 1.000000. Это указывает на то, что модели очень точно классифицируют положительные образцы, не пропуская их."
]
},
{
"cell_type": "code",
"execution_count": 164,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_1a980_row0_col0, #T_1a980_row0_col1, #T_1a980_row1_col0, #T_1a980_row1_col1 {\n",
" background-color: #440154;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_1a980_row0_col2, #T_1a980_row0_col3, #T_1a980_row0_col4, #T_1a980_row1_col2, #T_1a980_row1_col3, #T_1a980_row1_col4 {\n",
" background-color: #0d0887;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_1a980\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_1a980_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_1a980_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_1a980_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_1a980_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_1a980_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" <tr>\n",
" <th class=\"index_name level0\" >Name</th>\n",
" <th class=\"blank col0\" >&nbsp;</th>\n",
" <th class=\"blank col1\" >&nbsp;</th>\n",
" <th class=\"blank col2\" >&nbsp;</th>\n",
" <th class=\"blank col3\" >&nbsp;</th>\n",
" <th class=\"blank col4\" >&nbsp;</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_1a980_level0_row0\" class=\"row_heading level0 row0\" >Old</th>\n",
" <td id=\"T_1a980_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_1a980_row0_col1\" class=\"data row0 col1\" >1.000000</td>\n",
" <td id=\"T_1a980_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_1a980_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_1a980_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_1a980_level0_row1\" class=\"row_heading level0 row1\" >New</th>\n",
" <td id=\"T_1a980_row1_col0\" class=\"data row1 col0\" >1.000000</td>\n",
" <td id=\"T_1a980_row1_col1\" class=\"data row1 col1\" >1.000000</td>\n",
" <td id=\"T_1a980_row1_col2\" class=\"data row1 col2\" >1.000000</td>\n",
" <td id=\"T_1a980_row1_col3\" class=\"data row1 col3\" >1.000000</td>\n",
" <td id=\"T_1a980_row1_col4\" class=\"data row1 col4\" >1.000000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x234a466bbc0>"
]
},
"execution_count": 164,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimized_metrics[\n",
" [\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" \"ROC_AUC_test\",\n",
" \"Cohen_kappa_test\",\n",
" \"MCC_test\",\n",
" ]\n",
"].style.background_gradient(\n",
" cmap=\"plasma\",\n",
" low=0.3,\n",
" high=1,\n",
" subset=[\n",
" \"ROC_AUC_test\",\n",
" \"MCC_test\",\n",
" \"Cohen_kappa_test\",\n",
" ],\n",
").background_gradient(\n",
" cmap=\"viridis\",\n",
" low=1,\n",
" high=0.3,\n",
" subset=[\n",
" \"Accuracy_test\",\n",
" \"F1_test\",\n",
" ],\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Оба варианта модели продемонстрировали безупречную точность классификации, достигнув значения 1.000000. Это свидетельствует о том, что модели точно классифицировали все тестовые примеры, не допустив никаких ошибок в предсказаниях."
]
},
{
"cell_type": "code",
"execution_count": 165,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA2kAAAGsCAYAAABHMu+IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABTp0lEQVR4nO3deXwU9f3H8fcmkAPIJgRIQiSES4FouCumKIdiAlKUgvUnN4KgNmgFOcQqBmkJxWoVpeDFYYWC9aCKigIaEIhW0IAiRC4BhQQUSAhIrt3fH5jVJRyzZJKd3byej8c8fuzM7Dff4Ufz9jvfz3zH5nQ6nQIAAAAAWEKAtzsAAAAAAPgFgzQAAAAAsBAGaQAAAABgIQzSAAAAAMBCGKQBAAAAgIUwSAMAAAAAC2GQBgAAAAAWwiANAAAAACykhrc7AAA4t9OnT6uoqMi09oKCghQSEmJaewAAeIJcM45BGgBY0OnTp9U0vo5yDpea1mZMTIz27t3rt4EGALAucs0zDNIAwIKKioqUc7hUezfHyx5W8cr0/BMONe24T0VFRX4ZZgAAayPXPMMgDQAszB4WYEqYAQBgBeSaMQzSAMDCSp0OlTrNaQcAAG8j14xhkAYAFuaQUw5VPM3MaAMAgIoi14xhrhEAAAAALISZNACwMIccMqOgw5xWAACoGHLNGAZpAGBhpU6nSp0VL+kwow0AACqKXDOGckcAAAAAsBBm0gDAwnjAGgDgT8g1YxikAYCFOeRUKWEGAPAT5JoxlDsCAAAAgIUwkwYAFkZZCADAn5BrxjCTBgAAAAAWwkwaAFgYSxUDAPwJuWYMgzQAsDDHz5sZ7QAA4G3kmjGUOwIAAACAhTCTBgAWVmrSUsVmtAEAQEWRa8YwSAMACyt1ntnMaAcAAG8j14yh3BEAAAAALISZNACwMB6wBgD4E3LNGAZpAGBhDtlUKpsp7QAA4G3kmjGUOwIAAACAhTCTBgAW5nCe2cxoBwAAbyPXjGEmDQAAAAAshJk0ALCwUpNq981oAwCAiiLXjGGQBgAWRpgBAPwJuWYM5Y4AgHLS09P1m9/8RmFhYYqKilK/fv2UnZ3tds7p06eVmpqqevXqqU6dOhowYIByc3Pdztm/f7/69OmjWrVqKSoqShMnTlRJSYnbORkZGerQoYOCg4PVokULLVy4sLIvDwBQzfharjFIAwALczhtpm2eWLt2rVJTU/XJJ59o1apVKi4uVnJysk6ePOk6Z9y4cXr77bf1n//8R2vXrtXBgwfVv39/1/HS0lL16dNHRUVF2rhxoxYtWqSFCxdq6tSprnP27t2rPn36qEePHsrKytL999+vO++8U++//37F//IAAJZDrhljczqdfr42CgD4nvz8fIWHh2vtV5epTljF76cVnHCo21XfKy8vT3a73ePvHzlyRFFRUVq7dq26du2qvLw8NWjQQEuWLNGtt94qSdqxY4dat26tzMxMXXPNNXrvvff0u9/9TgcPHlR0dLQkad68eZo8ebKOHDmioKAgTZ48We+8846++uor18+6/fbbdfz4ca1cubLC1w0AsAZyzbNcYyYNAKqR/Px8t62wsNDQ9/Ly8iRJkZGRkqTNmzeruLhYPXv2dJ3TqlUrNW7cWJmZmZKkzMxMJSYmuoJMklJSUpSfn69t27a5zvl1G2XnlLUBAMCF+GuuMUgDAAsrVYBpmyTFxcUpPDzctaWnp1+0Dw6HQ/fff7+6dOmiq666SpKUk5OjoKAgRUREuJ0bHR2tnJwc1zm/DrKy42XHLnROfn6+fvrpJ8//wgAAlkauGcPqjgBgYc5LqLs/XzuSdODAAbeykODg4It+NzU1VV999ZXWr19f4X4AAKo3cs0YZtIAoBqx2+1u28XCbOzYsVqxYoU++ugjNWrUyLU/JiZGRUVFOn78uNv5ubm5iomJcZ1z9qpYZZ8vdo7dbldoaOglXSMAoPrw11xjkAYAFlb2PhkzNk84nU6NHTtWb775pj788EM1bdrU7XjHjh1Vs2ZNrVmzxrUvOztb+/fvV1JSkiQpKSlJX375pQ4fPuw6Z9WqVbLb7UpISHCd8+s2ys4pawMA4F/INWModwQACyt1BqjUWfH7aaUeruObmpqqJUuW6L///a/CwsJctfbh4eEKDQ1VeHi4Ro0apfHjxysyMlJ2u1333nuvkpKSdM0110iSkpOTlZCQoKFDh2rWrFnKycnRww8/rNTUVNedzrvvvlvPPvusJk2apJEjR+rDDz/Uq6++qnfeeafC1wwAsB5yzRiW4AcACypbqvi9rU1V24Slik+ecKh3m72Glyq22c59h3LBggUaMWKEpDMv/XzggQf073//W4WFhUpJSdE///lPV8mHJO3bt0/33HOPMjIyVLt2bQ0fPlwzZ85UjRq/3CPMyMjQuHHj9PXXX6tRo0Z65JFHXD8DAOAfyLURHl0fgzQAsKCyMHtnazPVDguscHsnT5SqT5s9l/w+GQAAKoJc8wzPpAEAAACAhfBMGgBY2KU8HH2+dgAA8DZyzRgGaQBgYeY9YE1lOwDA+8g1Yyh3BAAAAAALYSYNACzMIZscJpR0mNEGAAAVRa4ZwyANACzMoQCVmlD04JB/l4UAAHwDuWYM5Y4AAAAAYCHMpAGAhfGANQDAn5BrxjBIAwALcyhADspCAAB+glwzhnJHAAAAALAQZtIAwMJKnTaVOk146acJbQAAUFHkmjHMpAEAAACAhTCTBgAWVmrSUsWlfl67DwDwDeSaMQzSAMDCHM4AOUxYBcvh56tgAQB8A7lmDOWOAAAAAGAhzKQBgIVRFgIA8CfkmjEM0gDAwhwyZwUrR8W7AgBAhZFrxlDuCAAAAAAWwkwaAFiYQwFymHA/zYw2AACoKHLNGAZpAGBhpc4AlZqwCpYZbQAAUFHkmjH+fXUAAAAA4GOYSQMAC3PIJofMeMC64m0AAFBR5JoxDNIAwMIoCwEA+BNyzRj/vjoAAAAA8DHMpAGAhZn30k/uyQEAvI9cM8a/rw4AAAAAfAwzaQY4HA4dPHhQYWFhstn8+yFFABXndDp14sQJxcbGKiCgYvfCHE6bHE4THrA2oQ34D3INgCfItarHIM2AgwcPKi4uztvdAOBjDhw4oEaNGlWoDYdJZSH+/tJPeIZcA3ApyLWqwyDNgLCwMEnSvs+byF7Hv/9BwHO/vyLR212AxZSoWOv1rut3B2A15BouhFzD2ci1qscgzYCyUhB7nQDZwwgzuKthq+ntLsBqnGf+jxllZA5ngBwmLDNsRhvwH+QaLoRcQznkWpVjkAYAFlYqm0pNeGGnGW0AAFBR5Jox/j0EBQAAAAAfw0waAFgYZSEAAH9CrhnDIA0ALKxU5pR0lFa8KwAAVBi5Zox/D0EBAAAAwMcwSAMACysrCzFj88S6devUt29fxcbGymazafny5W7HbTbbObfHH3/cdU6TJk3KHZ85c6ZbO1u3btV1112nkJAQxcXFadasWZf8dwUAsD5v5ZrkW9nGIA0AUM7JkyfVtm1bzZkz55zHDx065LbNnz9fNptNAwYMcDvvscceczvv3nvvdR3Lz89XcnKy4uPjtXnzZj3++ONKS0vT888/X6nXBgConnwp23gmDQAsrNQZoFITHo72tI3evXurd+/e5z0eExPj9vm///2vevTooWbNmrntDwsLK3dumcWLF6uoqEjz589XUFCQrrzySmVlZenJJ5/UmDFjPOovAMA3eCvXJN/KNmbSAMDCnLLJYcLm/Pkh7fz8fLetsLCwwn3Mzc3VO++8o1GjRpU7NnPmTNWrV0/t27fX448/rpKSEtexzMxMde3aVUFBQa59KSkpys7O1rFjxyrcLwCA9fhCrknezzYGaQBQjcTFxSk8PNy1paenV7jNRYsWKSwsTP3793fbf99992np0qX66KOPdNddd2nGjBmaNGmS63hOTo6io6PdvlP2OScnp8L9AgD4v8rINcn72Ua5IwBYmNllIQcOHJDdbnftDw4OrnDb8+fP1+DBgxUSEuK2f/z48a4/t2nTRkFBQbrrrruUnp5uys8FAPgeX8g1yfvZxiANACzM4bT
"text/plain": [
"<Figure size 1000x400 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n",
")\n",
"\n",
"for index in range(0, len(optimized_metrics)):\n",
" c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n",
" disp = ConfusionMatrixDisplay(\n",
" confusion_matrix=c_matrix, display_labels=[\"Less\", \"More\"]\n",
" ).plot(ax=ax.flat[index])\n",
"\n",
"plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"В желтом квадрате мы видим значение 2173, что обозначает количество правильно классифицированных объектов, отнесенных к классу \"Less\". Это свидетельствует о том, что модель успешно идентифицирует объекты этого класса, минимизируя количество ложных положительных срабатываний.\n",
"\n",
"В правом нижнем жёлтом квадрате значение 2150 указывает на количество правильно классифицированных объектов, отнесенных к классу \"More\". Это также является показателем высокой точности модели в определении объектов данного класса."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Задача регресии: предсказание цены дома (price).\n",
"\n",
"Описание: Оценить, какая будет цена дома (price) на основе исторических данных о характеристиках домов, таких как площадь. Целевая переменная: Цена дома (price). (среднее значение)"
]
},
{
"cell_type": "code",
"execution_count": 166,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Среднее значение поля: 2079.8997362698374\n",
" id date price bedrooms bathrooms sqft_living \\\n",
"0 7129300520 20141013T000000 221900.0 3 1.00 1180 \n",
"1 6414100192 20141209T000000 538000.0 3 2.25 2570 \n",
"2 5631500400 20150225T000000 180000.0 2 1.00 770 \n",
"3 2487200875 20141209T000000 604000.0 4 3.00 1960 \n",
"4 1954400510 20150218T000000 510000.0 3 2.00 1680 \n",
"\n",
" sqft_lot floors waterfront view ... yr_built yr_renovated zipcode \\\n",
"0 5650 1.0 0 0 ... 1955 0 98178 \n",
"1 7242 2.0 0 0 ... 1951 1991 98125 \n",
"2 10000 1.0 0 0 ... 1933 0 98028 \n",
"3 5000 1.0 0 0 ... 1965 0 98136 \n",
"4 8080 1.0 0 0 ... 1987 0 98074 \n",
"\n",
" lat long sqft_living15 sqft_lot15 above_median_price \\\n",
"0 47.5112 -122.257 1340 5650 0 \n",
"1 47.7210 -122.319 1690 7639 1 \n",
"2 47.7379 -122.233 2720 8062 0 \n",
"3 47.5208 -122.393 1360 5000 1 \n",
"4 47.6168 -122.045 1800 7503 1 \n",
"\n",
" price_category average_price \n",
"0 0 0 \n",
"1 1 1 \n",
"2 0 0 \n",
"3 1 0 \n",
"4 1 0 \n",
"\n",
"[5 rows x 24 columns]\n",
"Статистическое описание DataFrame:\n",
" id price bedrooms bathrooms sqft_living \\\n",
"count 2.161300e+04 2.161300e+04 21613.000000 21613.000000 21613.000000 \n",
"mean 4.580302e+09 5.400881e+05 3.370842 2.114757 2079.899736 \n",
"std 2.876566e+09 3.671272e+05 0.930062 0.770163 918.440897 \n",
"min 1.000102e+06 7.500000e+04 0.000000 0.000000 290.000000 \n",
"25% 2.123049e+09 3.219500e+05 3.000000 1.750000 1427.000000 \n",
"50% 3.904930e+09 4.500000e+05 3.000000 2.250000 1910.000000 \n",
"75% 7.308900e+09 6.450000e+05 4.000000 2.500000 2550.000000 \n",
"max 9.900000e+09 7.700000e+06 33.000000 8.000000 13540.000000 \n",
"\n",
" sqft_lot floors waterfront view condition \\\n",
"count 2.161300e+04 21613.000000 21613.000000 21613.000000 21613.000000 \n",
"mean 1.510697e+04 1.494309 0.007542 0.234303 3.409430 \n",
"std 4.142051e+04 0.539989 0.086517 0.766318 0.650743 \n",
"min 5.200000e+02 1.000000 0.000000 0.000000 1.000000 \n",
"25% 5.040000e+03 1.000000 0.000000 0.000000 3.000000 \n",
"50% 7.618000e+03 1.500000 0.000000 0.000000 3.000000 \n",
"75% 1.068800e+04 2.000000 0.000000 0.000000 4.000000 \n",
"max 1.651359e+06 3.500000 1.000000 4.000000 5.000000 \n",
"\n",
" ... sqft_basement yr_built yr_renovated zipcode \\\n",
"count ... 21613.000000 21613.000000 21613.000000 21613.000000 \n",
"mean ... 291.509045 1971.005136 84.402258 98077.939805 \n",
"std ... 442.575043 29.373411 401.679240 53.505026 \n",
"min ... 0.000000 1900.000000 0.000000 98001.000000 \n",
"25% ... 0.000000 1951.000000 0.000000 98033.000000 \n",
"50% ... 0.000000 1975.000000 0.000000 98065.000000 \n",
"75% ... 560.000000 1997.000000 0.000000 98118.000000 \n",
"max ... 4820.000000 2015.000000 2015.000000 98199.000000 \n",
"\n",
" lat long sqft_living15 sqft_lot15 \\\n",
"count 21613.000000 21613.000000 21613.000000 21613.000000 \n",
"mean 47.560053 -122.213896 1986.552492 12768.455652 \n",
"std 0.138564 0.140828 685.391304 27304.179631 \n",
"min 47.155900 -122.519000 399.000000 651.000000 \n",
"25% 47.471000 -122.328000 1490.000000 5100.000000 \n",
"50% 47.571800 -122.230000 1840.000000 7620.000000 \n",
"75% 47.678000 -122.125000 2360.000000 10083.000000 \n",
"max 47.777600 -121.315000 6210.000000 871200.000000 \n",
"\n",
" above_median_price average_price \n",
"count 21613.000000 21613.00000 \n",
"mean 0.497340 0.42752 \n",
"std 0.500004 0.49473 \n",
"min 0.000000 0.00000 \n",
"25% 0.000000 0.00000 \n",
"50% 0.000000 0.00000 \n",
"75% 1.000000 1.00000 \n",
"max 1.000000 1.00000 \n",
"\n",
"[8 rows x 22 columns]\n"
]
}
],
"source": [
"import pandas as pd\n",
"from sklearn import set_config\n",
"\n",
"set_config(transform_output=\"pandas\")\n",
"\n",
"# Опция для настройки генерации случайных чисел (если это нужно для других частей кода)\n",
"random_state = 42\n",
"\n",
"# Вычисление среднего значения поля \"Close\"\n",
"average_price = df['sqft_living'].mean()\n",
"print(f\"Среднее значение поля: {average_price}\")\n",
"\n",
"# Создание новой колонки, указывающей, выше или ниже среднего значение цена закрытия\n",
"df['average_price'] = (df['sqft_living'] > average_price).astype(int)\n",
"\n",
"# Удаление последней строки, где нет значения для следующего дня\n",
"df.dropna(inplace=True)\n",
"\n",
"# Вывод DataFrame с новой колонкой\n",
"print(df.head())\n",
"\n",
"# Примерный анализ данных\n",
"print(\"Статистическое описание DataFrame:\")\n",
"print(df.describe())"
]
},
{
"cell_type": "code",
"execution_count": 167,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'X_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>above_median_price</th>\n",
" <th>price_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6325</th>\n",
" <td>5467910190</td>\n",
" <td>20140527T000000</td>\n",
" <td>325000.0</td>\n",
" <td>3</td>\n",
" <td>1.75</td>\n",
" <td>1780</td>\n",
" <td>13095</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1983</td>\n",
" <td>0</td>\n",
" <td>98042</td>\n",
" <td>47.3670</td>\n",
" <td>-122.152</td>\n",
" <td>2750</td>\n",
" <td>13095</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13473</th>\n",
" <td>9331800580</td>\n",
" <td>20150310T000000</td>\n",
" <td>257000.0</td>\n",
" <td>2</td>\n",
" <td>1.00</td>\n",
" <td>1000</td>\n",
" <td>3700</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>200</td>\n",
" <td>1929</td>\n",
" <td>0</td>\n",
" <td>98118</td>\n",
" <td>47.5520</td>\n",
" <td>-122.290</td>\n",
" <td>1270</td>\n",
" <td>5000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17614</th>\n",
" <td>2407000405</td>\n",
" <td>20150226T000000</td>\n",
" <td>228500.0</td>\n",
" <td>3</td>\n",
" <td>1.00</td>\n",
" <td>1080</td>\n",
" <td>7486</td>\n",
" <td>1.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>90</td>\n",
" <td>1942</td>\n",
" <td>0</td>\n",
" <td>98146</td>\n",
" <td>47.4838</td>\n",
" <td>-122.335</td>\n",
" <td>1170</td>\n",
" <td>7800</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16970</th>\n",
" <td>5466700290</td>\n",
" <td>20150108T000000</td>\n",
" <td>288000.0</td>\n",
" <td>3</td>\n",
" <td>2.25</td>\n",
" <td>2090</td>\n",
" <td>7500</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>810</td>\n",
" <td>1977</td>\n",
" <td>0</td>\n",
" <td>98031</td>\n",
" <td>47.3951</td>\n",
" <td>-122.172</td>\n",
" <td>1800</td>\n",
" <td>7350</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20868</th>\n",
" <td>3026059361</td>\n",
" <td>20150417T000000</td>\n",
" <td>479000.0</td>\n",
" <td>2</td>\n",
" <td>2.50</td>\n",
" <td>1741</td>\n",
" <td>1439</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>295</td>\n",
" <td>2007</td>\n",
" <td>0</td>\n",
" <td>98034</td>\n",
" <td>47.7043</td>\n",
" <td>-122.209</td>\n",
" <td>2090</td>\n",
" <td>10454</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11964</th>\n",
" <td>5272200045</td>\n",
" <td>20141113T000000</td>\n",
" <td>378000.0</td>\n",
" <td>3</td>\n",
" <td>1.50</td>\n",
" <td>1000</td>\n",
" <td>6914</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1947</td>\n",
" <td>0</td>\n",
" <td>98125</td>\n",
" <td>47.7144</td>\n",
" <td>-122.319</td>\n",
" <td>1000</td>\n",
" <td>6947</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21575</th>\n",
" <td>9578500790</td>\n",
" <td>20141111T000000</td>\n",
" <td>399950.0</td>\n",
" <td>3</td>\n",
" <td>2.50</td>\n",
" <td>3087</td>\n",
" <td>5002</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2014</td>\n",
" <td>0</td>\n",
" <td>98023</td>\n",
" <td>47.2974</td>\n",
" <td>-122.349</td>\n",
" <td>2927</td>\n",
" <td>5183</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5390</th>\n",
" <td>7202350480</td>\n",
" <td>20140930T000000</td>\n",
" <td>575000.0</td>\n",
" <td>3</td>\n",
" <td>2.50</td>\n",
" <td>2120</td>\n",
" <td>4780</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2004</td>\n",
" <td>0</td>\n",
" <td>98053</td>\n",
" <td>47.6810</td>\n",
" <td>-122.032</td>\n",
" <td>1690</td>\n",
" <td>2650</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>1723049033</td>\n",
" <td>20140620T000000</td>\n",
" <td>245000.0</td>\n",
" <td>1</td>\n",
" <td>0.75</td>\n",
" <td>380</td>\n",
" <td>15000</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1963</td>\n",
" <td>0</td>\n",
" <td>98168</td>\n",
" <td>47.4810</td>\n",
" <td>-122.323</td>\n",
" <td>1170</td>\n",
" <td>15000</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15795</th>\n",
" <td>6147650280</td>\n",
" <td>20150325T000000</td>\n",
" <td>315000.0</td>\n",
" <td>4</td>\n",
" <td>2.50</td>\n",
" <td>3130</td>\n",
" <td>5999</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2006</td>\n",
" <td>0</td>\n",
" <td>98042</td>\n",
" <td>47.3837</td>\n",
" <td>-122.099</td>\n",
" <td>3020</td>\n",
" <td>5997</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>17290 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" id date price bedrooms bathrooms \\\n",
"6325 5467910190 20140527T000000 325000.0 3 1.75 \n",
"13473 9331800580 20150310T000000 257000.0 2 1.00 \n",
"17614 2407000405 20150226T000000 228500.0 3 1.00 \n",
"16970 5466700290 20150108T000000 288000.0 3 2.25 \n",
"20868 3026059361 20150417T000000 479000.0 2 2.50 \n",
"... ... ... ... ... ... \n",
"11964 5272200045 20141113T000000 378000.0 3 1.50 \n",
"21575 9578500790 20141111T000000 399950.0 3 2.50 \n",
"5390 7202350480 20140930T000000 575000.0 3 2.50 \n",
"860 1723049033 20140620T000000 245000.0 1 0.75 \n",
"15795 6147650280 20150325T000000 315000.0 4 2.50 \n",
"\n",
" sqft_living sqft_lot floors waterfront view ... sqft_basement \\\n",
"6325 1780 13095 1.0 0 0 ... 0 \n",
"13473 1000 3700 1.0 0 0 ... 200 \n",
"17614 1080 7486 1.5 0 0 ... 90 \n",
"16970 2090 7500 1.0 0 0 ... 810 \n",
"20868 1741 1439 2.0 0 0 ... 295 \n",
"... ... ... ... ... ... ... ... \n",
"11964 1000 6914 1.0 0 0 ... 0 \n",
"21575 3087 5002 2.0 0 0 ... 0 \n",
"5390 2120 4780 2.0 0 0 ... 0 \n",
"860 380 15000 1.0 0 0 ... 0 \n",
"15795 3130 5999 2.0 0 0 ... 0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"6325 1983 0 98042 47.3670 -122.152 2750 \n",
"13473 1929 0 98118 47.5520 -122.290 1270 \n",
"17614 1942 0 98146 47.4838 -122.335 1170 \n",
"16970 1977 0 98031 47.3951 -122.172 1800 \n",
"20868 2007 0 98034 47.7043 -122.209 2090 \n",
"... ... ... ... ... ... ... \n",
"11964 1947 0 98125 47.7144 -122.319 1000 \n",
"21575 2014 0 98023 47.2974 -122.349 2927 \n",
"5390 2004 0 98053 47.6810 -122.032 1690 \n",
"860 1963 0 98168 47.4810 -122.323 1170 \n",
"15795 2006 0 98042 47.3837 -122.099 3020 \n",
"\n",
" sqft_lot15 above_median_price price_category \n",
"6325 13095 0 1 \n",
"13473 5000 0 0 \n",
"17614 7800 0 0 \n",
"16970 7350 0 0 \n",
"20868 10454 1 1 \n",
"... ... ... ... \n",
"11964 6947 0 1 \n",
"21575 5183 0 1 \n",
"5390 2650 1 1 \n",
"860 15000 0 0 \n",
"15795 5997 0 1 \n",
"\n",
"[17290 rows x 23 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_train'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>average_price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6325</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13473</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17614</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16970</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20868</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11964</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21575</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5390</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>860</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15795</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>17290 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" average_price\n",
"6325 0\n",
"13473 0\n",
"17614 0\n",
"16970 1\n",
"20868 0\n",
"... ...\n",
"11964 0\n",
"21575 1\n",
"5390 1\n",
"860 0\n",
"15795 1\n",
"\n",
"[17290 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'X_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>id</th>\n",
" <th>date</th>\n",
" <th>price</th>\n",
" <th>bedrooms</th>\n",
" <th>bathrooms</th>\n",
" <th>sqft_living</th>\n",
" <th>sqft_lot</th>\n",
" <th>floors</th>\n",
" <th>waterfront</th>\n",
" <th>view</th>\n",
" <th>...</th>\n",
" <th>sqft_basement</th>\n",
" <th>yr_built</th>\n",
" <th>yr_renovated</th>\n",
" <th>zipcode</th>\n",
" <th>lat</th>\n",
" <th>long</th>\n",
" <th>sqft_living15</th>\n",
" <th>sqft_lot15</th>\n",
" <th>above_median_price</th>\n",
" <th>price_category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>735</th>\n",
" <td>2591820310</td>\n",
" <td>20141006T000000</td>\n",
" <td>365000.0</td>\n",
" <td>4</td>\n",
" <td>2.25</td>\n",
" <td>2070</td>\n",
" <td>8893</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1986</td>\n",
" <td>0</td>\n",
" <td>98058</td>\n",
" <td>47.4388</td>\n",
" <td>-122.162</td>\n",
" <td>2390</td>\n",
" <td>7700</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2830</th>\n",
" <td>7974200820</td>\n",
" <td>20140821T000000</td>\n",
" <td>865000.0</td>\n",
" <td>5</td>\n",
" <td>3.00</td>\n",
" <td>2900</td>\n",
" <td>6730</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>1070</td>\n",
" <td>1977</td>\n",
" <td>0</td>\n",
" <td>98115</td>\n",
" <td>47.6784</td>\n",
" <td>-122.285</td>\n",
" <td>2370</td>\n",
" <td>6283</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4106</th>\n",
" <td>7701450110</td>\n",
" <td>20140815T000000</td>\n",
" <td>1038000.0</td>\n",
" <td>4</td>\n",
" <td>2.50</td>\n",
" <td>3770</td>\n",
" <td>10893</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1997</td>\n",
" <td>0</td>\n",
" <td>98006</td>\n",
" <td>47.5646</td>\n",
" <td>-122.129</td>\n",
" <td>3710</td>\n",
" <td>9685</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16218</th>\n",
" <td>9522300010</td>\n",
" <td>20150331T000000</td>\n",
" <td>1490000.0</td>\n",
" <td>3</td>\n",
" <td>3.50</td>\n",
" <td>4560</td>\n",
" <td>14608</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1990</td>\n",
" <td>0</td>\n",
" <td>98034</td>\n",
" <td>47.6995</td>\n",
" <td>-122.228</td>\n",
" <td>4050</td>\n",
" <td>14226</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19964</th>\n",
" <td>9510861140</td>\n",
" <td>20140714T000000</td>\n",
" <td>711000.0</td>\n",
" <td>3</td>\n",
" <td>2.50</td>\n",
" <td>2550</td>\n",
" <td>5376</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2004</td>\n",
" <td>0</td>\n",
" <td>98052</td>\n",
" <td>47.6647</td>\n",
" <td>-122.083</td>\n",
" <td>2250</td>\n",
" <td>4050</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13674</th>\n",
" <td>6163900333</td>\n",
" <td>20141110T000000</td>\n",
" <td>338000.0</td>\n",
" <td>3</td>\n",
" <td>1.75</td>\n",
" <td>1250</td>\n",
" <td>7710</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1947</td>\n",
" <td>0</td>\n",
" <td>98155</td>\n",
" <td>47.7623</td>\n",
" <td>-122.317</td>\n",
" <td>1340</td>\n",
" <td>7710</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20377</th>\n",
" <td>3528960020</td>\n",
" <td>20140708T000000</td>\n",
" <td>673000.0</td>\n",
" <td>3</td>\n",
" <td>2.75</td>\n",
" <td>2830</td>\n",
" <td>3496</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2012</td>\n",
" <td>0</td>\n",
" <td>98029</td>\n",
" <td>47.5606</td>\n",
" <td>-122.011</td>\n",
" <td>2160</td>\n",
" <td>3501</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8805</th>\n",
" <td>1687000220</td>\n",
" <td>20141016T000000</td>\n",
" <td>285000.0</td>\n",
" <td>4</td>\n",
" <td>2.50</td>\n",
" <td>2434</td>\n",
" <td>4400</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2007</td>\n",
" <td>0</td>\n",
" <td>98001</td>\n",
" <td>47.2874</td>\n",
" <td>-122.283</td>\n",
" <td>2434</td>\n",
" <td>4400</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10168</th>\n",
" <td>4141400030</td>\n",
" <td>20141201T000000</td>\n",
" <td>605000.0</td>\n",
" <td>4</td>\n",
" <td>1.75</td>\n",
" <td>2250</td>\n",
" <td>10108</td>\n",
" <td>1.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>1967</td>\n",
" <td>0</td>\n",
" <td>98008</td>\n",
" <td>47.5922</td>\n",
" <td>-122.118</td>\n",
" <td>2050</td>\n",
" <td>9750</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2522</th>\n",
" <td>1822500160</td>\n",
" <td>20141212T000000</td>\n",
" <td>356500.0</td>\n",
" <td>4</td>\n",
" <td>2.50</td>\n",
" <td>2570</td>\n",
" <td>11473</td>\n",
" <td>2.0</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>...</td>\n",
" <td>0</td>\n",
" <td>2008</td>\n",
" <td>0</td>\n",
" <td>98003</td>\n",
" <td>47.2809</td>\n",
" <td>-122.296</td>\n",
" <td>2430</td>\n",
" <td>5997</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4323 rows × 23 columns</p>\n",
"</div>"
],
"text/plain": [
" id date price bedrooms bathrooms \\\n",
"735 2591820310 20141006T000000 365000.0 4 2.25 \n",
"2830 7974200820 20140821T000000 865000.0 5 3.00 \n",
"4106 7701450110 20140815T000000 1038000.0 4 2.50 \n",
"16218 9522300010 20150331T000000 1490000.0 3 3.50 \n",
"19964 9510861140 20140714T000000 711000.0 3 2.50 \n",
"... ... ... ... ... ... \n",
"13674 6163900333 20141110T000000 338000.0 3 1.75 \n",
"20377 3528960020 20140708T000000 673000.0 3 2.75 \n",
"8805 1687000220 20141016T000000 285000.0 4 2.50 \n",
"10168 4141400030 20141201T000000 605000.0 4 1.75 \n",
"2522 1822500160 20141212T000000 356500.0 4 2.50 \n",
"\n",
" sqft_living sqft_lot floors waterfront view ... sqft_basement \\\n",
"735 2070 8893 2.0 0 0 ... 0 \n",
"2830 2900 6730 1.0 0 0 ... 1070 \n",
"4106 3770 10893 2.0 0 2 ... 0 \n",
"16218 4560 14608 2.0 0 2 ... 0 \n",
"19964 2550 5376 2.0 0 0 ... 0 \n",
"... ... ... ... ... ... ... ... \n",
"13674 1250 7710 1.0 0 0 ... 0 \n",
"20377 2830 3496 2.0 0 0 ... 0 \n",
"8805 2434 4400 2.0 0 0 ... 0 \n",
"10168 2250 10108 1.0 0 0 ... 0 \n",
"2522 2570 11473 2.0 0 0 ... 0 \n",
"\n",
" yr_built yr_renovated zipcode lat long sqft_living15 \\\n",
"735 1986 0 98058 47.4388 -122.162 2390 \n",
"2830 1977 0 98115 47.6784 -122.285 2370 \n",
"4106 1997 0 98006 47.5646 -122.129 3710 \n",
"16218 1990 0 98034 47.6995 -122.228 4050 \n",
"19964 2004 0 98052 47.6647 -122.083 2250 \n",
"... ... ... ... ... ... ... \n",
"13674 1947 0 98155 47.7623 -122.317 1340 \n",
"20377 2012 0 98029 47.5606 -122.011 2160 \n",
"8805 2007 0 98001 47.2874 -122.283 2434 \n",
"10168 1967 0 98008 47.5922 -122.118 2050 \n",
"2522 2008 0 98003 47.2809 -122.296 2430 \n",
"\n",
" sqft_lot15 above_median_price price_category \n",
"735 7700 0 1 \n",
"2830 6283 1 2 \n",
"4106 9685 1 2 \n",
"16218 14226 1 2 \n",
"19964 4050 1 2 \n",
"... ... ... ... \n",
"13674 7710 0 1 \n",
"20377 3501 1 1 \n",
"8805 4400 0 0 \n",
"10168 9750 1 1 \n",
"2522 5997 0 1 \n",
"\n",
"[4323 rows x 23 columns]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"'y_test'"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>average_price</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>735</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2830</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4106</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16218</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19964</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13674</th>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20377</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8805</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10168</th>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2522</th>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>4323 rows × 1 columns</p>\n",
"</div>"
],
"text/plain": [
" average_price\n",
"735 0\n",
"2830 1\n",
"4106 1\n",
"16218 1\n",
"19964 1\n",
"... ...\n",
"13674 0\n",
"20377 1\n",
"8805 1\n",
"10168 1\n",
"2522 1\n",
"\n",
"[4323 rows x 1 columns]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from typing import Tuple\n",
"from pandas import DataFrame\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def split_into_train_test(\n",
" df_input: DataFrame,\n",
" target_colname: str = \"average_price\",\n",
" frac_train: float = 0.8,\n",
" random_state: int = None,\n",
") -> Tuple[DataFrame, DataFrame, DataFrame, DataFrame]:\n",
" \n",
" if not (0 < frac_train < 1):\n",
" raise ValueError(\"Fraction must be between 0 and 1.\")\n",
" \n",
" # Проверка наличия целевого признака\n",
" if target_colname not in df_input.columns:\n",
" raise ValueError(f\"{target_colname} is not a column in the DataFrame.\")\n",
" \n",
" # Разделяем данные на признаки и целевую переменную\n",
" X = df_input.drop(columns=[target_colname]) # Признаки\n",
" y = df_input[[target_colname]] # Целевая переменная\n",
"\n",
" # Разделяем данные на обучающую и тестовую выборки\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y,\n",
" test_size=(1.0 - frac_train),\n",
" random_state=random_state\n",
" )\n",
" \n",
" return X_train, X_test, y_train, y_test\n",
"\n",
"# Применение функции для разделения данных\n",
"X_train, X_test, y_train, y_test = split_into_train_test(\n",
" df, \n",
" target_colname=\"average_price\", \n",
" frac_train=0.8, \n",
" random_state=42 # Убедитесь, что вы задали нужное значение random_state\n",
")\n",
"\n",
"# Для отображения результатов\n",
"display(\"X_train\", X_train)\n",
"display(\"y_train\", y_train)\n",
"\n",
"display(\"X_test\", X_test)\n",
"display(\"y_test\", y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Формирование конвейера для решения задачи регрессии"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"from sklearn.compose import ColumnTransformer\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.impute import SimpleImputer\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from sklearn.ensemble import RandomForestRegressor # Пример регрессионной модели\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"class HouseFeatures(BaseEstimator, TransformerMixin):\n",
" def __init__(self):\n",
" pass\n",
" def fit(self, X, y=None):\n",
" return self\n",
" def transform(self, X, y=None):\n",
" # Создание новых признаков\n",
" X = X.copy()\n",
" X[\"Square\"] = X[\"sqft_living\"] / X[\"sqft_lot\"]\n",
" return X\n",
" def get_feature_names_out(self, features_in):\n",
" # Добавление имен новых признаков\n",
" new_features = [\"Square\"]\n",
" return np.append(features_in, new_features, axis=0)\n",
"\n",
"# Указываем столбцы, которые нужно удалить и обрабатывать\n",
"columns_to_drop = [\"date\"]\n",
"num_columns = [\"bathrooms\", \"floors\", \"waterfront\", \"view\"]\n",
"cat_columns = [] \n",
"\n",
"# Определяем предобработку для численных данных\n",
"num_imputer = SimpleImputer(strategy=\"median\")\n",
"num_scaler = StandardScaler()\n",
"preprocessing_num = Pipeline(\n",
" [\n",
" (\"imputer\", num_imputer),\n",
" (\"scaler\", num_scaler),\n",
" ]\n",
")\n",
"\n",
"# Определяем предобработку для категориальных данных\n",
"cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=\"unknown\")\n",
"cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n",
"preprocessing_cat = Pipeline(\n",
" [\n",
" (\"imputer\", cat_imputer),\n",
" (\"encoder\", cat_encoder),\n",
" ]\n",
")\n",
"\n",
"# Подготовка признаков с использованием ColumnTransformer\n",
"features_preprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_num\", preprocessing_num, num_columns),\n",
" (\"preprocessing_cat\", preprocessing_cat, cat_columns),\n",
" ],\n",
" remainder=\"passthrough\"\n",
")\n",
"\n",
"# Удаление нежелательных столбцов\n",
"drop_columns = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"drop_columns\", \"drop\", columns_to_drop),\n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Постобработка признаков\n",
"features_postprocessing = ColumnTransformer(\n",
" verbose_feature_names_out=False,\n",
" transformers=[\n",
" (\"preprocessing_cat\", preprocessing_cat, [\"price_category\"]), \n",
" ],\n",
" remainder=\"passthrough\",\n",
")\n",
"\n",
"# Создание окончательного конвейера\n",
"pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"custom_features\", HouseFeatures()),\n",
" (\"model\", RandomForestRegressor()) # Выбор модели для обучения\n",
" ]\n",
")\n",
"\n",
"# Использование конвейера\n",
"def train_pipeline(X, y):\n",
" pipeline.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Формирование набора моделей для регрессии \n",
"Определение перечня алгоритмов решения задачи аппроксимации (регрессии)"
]
},
{
"cell_type": "code",
"execution_count": 169,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Формирование набора моделей для регрессии"
]
},
{
"cell_type": "code",
"execution_count": 170,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Random Forest: Mean Score = 1.0, Standard Deviation = 0.0\n",
"Linear Regression: Mean Score = 0.6396438910587428, Standard Deviation = 0.006348300027629372\n",
"Gradient Boosting: Mean Score = 0.9999999992943781, Standard Deviation = 6.609300428326041e-14\n",
"Support Vector Regression: Mean Score = -0.4335265257004087, Standard Deviation = 0.012071668862264313\n"
]
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"from sklearn.svm import SVR\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"def train_multiple_models(X, y, models):\n",
" results = {}\n",
"\n",
" # Преобразуем y в одномерный массив numpy только при необходимости\n",
" if hasattr(y, 'values'):\n",
" y = y.values.ravel() # Если y - DataFrame, преобразуем в numpy array\n",
" else:\n",
" y = y.ravel() # Если y - numpy array, просто используем ravel()\n",
"\n",
" for model_name, model in models.items():\n",
" # Создаем конвейер для каждой модели\n",
" model_pipeline = Pipeline(\n",
" [\n",
" (\"features_preprocessing\", features_preprocessing),\n",
" (\"drop_columns\", drop_columns),\n",
" (\"model\", model) # Используем текущую модель\n",
" ]\n",
" )\n",
" \n",
" # Обучаем модель и вычисляем кросс-валидацию\n",
" scores = cross_val_score(model_pipeline, X, y, cv=5, error_score='raise') # 5-кратная кросс-валидация\n",
" results[model_name] = {\n",
" \"mean_score\": scores.mean(),\n",
" \"std_dev\": scores.std()\n",
" }\n",
" \n",
" return results\n",
"\n",
"models = {\n",
" \"Random Forest\": RandomForestRegressor(),\n",
" \"Linear Regression\": LinearRegression(),\n",
" \"Gradient Boosting\": GradientBoostingRegressor(),\n",
" \"Support Vector Regression\": SVR()\n",
"}\n",
"\n",
"results = train_multiple_models(X_train, y_train, models)\n",
"\n",
"# Вывод результатов\n",
"for model_name, scores in results.items():\n",
" print(f\"{model_name}: Mean Score = {scores['mean_score']}, Standard Deviation = {scores['std_dev']}\")"
]
},
{
"cell_type": "code",
"execution_count": 171,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"MSE (train): 0.24060150375939848\n",
"MSE (test): 0.23455933379597502\n",
"MAE (train): 0.24060150375939848\n",
"MAE (test): 0.23455933379597502\n",
"R2 (train): 0.015780807725750634\n",
"R2 (test): 0.045807954005714024\n",
"STD (train): 0.48387852043102103\n",
"STD (test): 0.4780359236045559\n",
"----------------------------------------\n",
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\MII\\laboratory\\mai\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.11596298438403702\n",
"MSE (test): 0.11265325005783021\n",
"MAE (train): 0.11596298438403702\n",
"MAE (test): 0.11265325005783021\n",
"R2 (train): 0.5256347402620505\n",
"R2 (test): 0.541724332939628\n",
"STD (train): 0.3405113334365492\n",
"STD (test): 0.3356321137822519\n",
"----------------------------------------\n",
"Model: decision_tree\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: knn\n",
"MSE (train): 0.1949681897050318\n",
"MSE (test): 0.27989821882951654\n",
"MAE (train): 0.1949681897050318\n",
"MAE (test): 0.27989821882951654\n",
"R2 (train): 0.20245122664507342\n",
"R2 (test): -0.13863153417464114\n",
"STD (train): 0.43948973967967464\n",
"STD (test): 0.5264647910268833\n",
"----------------------------------------\n",
"Model: naive_bayes\n",
"MSE (train): 0.26928860613071137\n",
"MSE (test): 0.2690261392551469\n",
"MAE (train): 0.26928860613071137\n",
"MAE (test): 0.2690261392551469\n",
"R2 (train): -0.10156840366079445\n",
"R2 (test): -0.09440369772322943\n",
"STD (train): 0.47316941542228536\n",
"STD (test): 0.47206502931490235\n",
"----------------------------------------\n",
"Model: gradient_boosting\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: random_forest\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: mlp\n",
"MSE (train): 0.4253903990746096\n",
"MSE (test): 0.4353458246588018\n",
"MAE (train): 0.4253903990746096\n",
"MAE (test): 0.4353458246588018\n",
"R2 (train): -0.7401279228791116\n",
"R2 (test): -0.7709954936501442\n",
"STD (train): 0.4959884986820156\n",
"STD (test): 0.49782384226978177\n",
"----------------------------------------\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# Проверка наличия необходимых переменных\n",
"if 'class_models' not in locals():\n",
" raise ValueError(\"class_models is not defined\")\n",
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
" raise ValueError(\"Train/test data is not defined\")\n",
"\n",
"\n",
"y_train = np.ravel(y_train) \n",
"y_test = np.ravel(y_test) \n",
"\n",
"# Инициализация списка для хранения результатов\n",
"results = []\n",
"\n",
"# Проход по моделям и оценка их качества\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" # Извлечение модели из словаря\n",
" model = class_models[model_name][\"model\"]\n",
" \n",
" # Создание пайплайна\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Сохранение пайплайна и предсказаний\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Вычисление метрик для регрессии\n",
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
"\n",
" # Дополнительные метрики\n",
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
"\n",
" # Вывод результатов для текущей модели\n",
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
" print(\"-\" * 40) # Разделитель для разных моделей"
]
},
{
"cell_type": "code",
"execution_count": 172,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures\n",
"from sklearn import linear_model, tree, neighbors, ensemble, neural_network\n",
"\n",
"random_state = 9\n",
"\n",
"models = {\n",
" \"linear\": {\"model\": linear_model.LinearRegression(n_jobs=-1)},\n",
" \"linear_poly\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(degree=2),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"linear_interact\": {\n",
" \"model\": make_pipeline(\n",
" PolynomialFeatures(interaction_only=True),\n",
" linear_model.LinearRegression(fit_intercept=False, n_jobs=-1),\n",
" )\n",
" },\n",
" \"ridge\": {\"model\": linear_model.RidgeCV()},\n",
" \"decision_tree\": {\n",
" \"model\": tree.DecisionTreeRegressor(max_depth=7, random_state=random_state)\n",
" },\n",
" \"knn\": {\"model\": neighbors.KNeighborsRegressor(n_neighbors=7, n_jobs=-1)},\n",
" \"random_forest\": {\n",
" \"model\": ensemble.RandomForestRegressor(\n",
" max_depth=7, random_state=random_state, n_jobs=-1\n",
" )\n",
" },\n",
" \"mlp\": {\n",
" \"model\": neural_network.MLPRegressor(\n",
" activation=\"tanh\",\n",
" hidden_layer_sizes=(3,),\n",
" max_iter=500,\n",
" early_stopping=True,\n",
" random_state=random_state,\n",
" )\n",
" },\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Обучение и оценка моделей с помощью различных алгоритмов"
]
},
{
"cell_type": "code",
"execution_count": 173,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: logistic\n",
"MSE (train): 0.24060150375939848\n",
"MSE (test): 0.23455933379597502\n",
"MAE (train): 0.24060150375939848\n",
"MAE (test): 0.23455933379597502\n",
"R2 (train): 0.015780807725750634\n",
"R2 (test): 0.045807954005714024\n",
"STD (train): 0.48387852043102103\n",
"STD (test): 0.4780359236045559\n",
"----------------------------------------\n",
"Model: ridge\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"e:\\MII\\laboratory\\mai\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
"STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
"\n",
"Increase the number of iterations (max_iter) or scale the data as shown in:\n",
" https://scikit-learn.org/stable/modules/preprocessing.html\n",
"Please also refer to the documentation for alternative solver options:\n",
" https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
" n_iter_i = _check_optimize_result(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE (train): 0.11596298438403702\n",
"MSE (test): 0.11265325005783021\n",
"MAE (train): 0.11596298438403702\n",
"MAE (test): 0.11265325005783021\n",
"R2 (train): 0.5256347402620505\n",
"R2 (test): 0.541724332939628\n",
"STD (train): 0.3405113334365492\n",
"STD (test): 0.3356321137822519\n",
"----------------------------------------\n",
"Model: decision_tree\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: knn\n",
"MSE (train): 0.1949681897050318\n",
"MSE (test): 0.27989821882951654\n",
"MAE (train): 0.1949681897050318\n",
"MAE (test): 0.27989821882951654\n",
"R2 (train): 0.20245122664507342\n",
"R2 (test): -0.13863153417464114\n",
"STD (train): 0.43948973967967464\n",
"STD (test): 0.5264647910268833\n",
"----------------------------------------\n",
"Model: naive_bayes\n",
"MSE (train): 0.26928860613071137\n",
"MSE (test): 0.2690261392551469\n",
"MAE (train): 0.26928860613071137\n",
"MAE (test): 0.2690261392551469\n",
"R2 (train): -0.10156840366079445\n",
"R2 (test): -0.09440369772322943\n",
"STD (train): 0.47316941542228536\n",
"STD (test): 0.47206502931490235\n",
"----------------------------------------\n",
"Model: gradient_boosting\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: random_forest\n",
"MSE (train): 0.0\n",
"MSE (test): 0.0\n",
"MAE (train): 0.0\n",
"MAE (test): 0.0\n",
"R2 (train): 1.0\n",
"R2 (test): 1.0\n",
"STD (train): 0.0\n",
"STD (test): 0.0\n",
"----------------------------------------\n",
"Model: mlp\n",
"MSE (train): 0.4253903990746096\n",
"MSE (test): 0.4353458246588018\n",
"MAE (train): 0.4253903990746096\n",
"MAE (test): 0.4353458246588018\n",
"R2 (train): -0.7401279228791116\n",
"R2 (test): -0.7709954936501442\n",
"STD (train): 0.4959884986820156\n",
"STD (test): 0.49782384226978177\n",
"----------------------------------------\n"
]
}
],
"source": [
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.pipeline import Pipeline\n",
"\n",
"# Проверка наличия необходимых переменных\n",
"if 'class_models' not in locals():\n",
" raise ValueError(\"class_models is not defined\")\n",
"if 'X_train' not in locals() or 'X_test' not in locals() or 'y_train' not in locals() or 'y_test' not in locals():\n",
" raise ValueError(\"Train/test data is not defined\")\n",
"\n",
"\n",
"y_train = np.ravel(y_train) \n",
"y_test = np.ravel(y_test) \n",
"\n",
"# Инициализация списка для хранения результатов\n",
"results = []\n",
"\n",
"# Проход по моделям и оценка их качества\n",
"for model_name in class_models.keys():\n",
" print(f\"Model: {model_name}\")\n",
" \n",
" # Извлечение модели из словаря\n",
" model = class_models[model_name][\"model\"]\n",
" \n",
" # Создание пайплайна\n",
" model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n",
" \n",
" # Обучение модели\n",
" model_pipeline.fit(X_train, y_train)\n",
"\n",
" # Предсказание для обучающей и тестовой выборки\n",
" y_train_predict = model_pipeline.predict(X_train)\n",
" y_test_predict = model_pipeline.predict(X_test)\n",
"\n",
" # Сохранение пайплайна и предсказаний\n",
" class_models[model_name][\"pipeline\"] = model_pipeline\n",
" class_models[model_name][\"preds\"] = y_test_predict\n",
"\n",
" # Вычисление метрик для регрессии\n",
" class_models[model_name][\"MSE_train\"] = metrics.mean_squared_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MSE_test\"] = metrics.mean_squared_error(y_test, y_test_predict)\n",
" class_models[model_name][\"MAE_train\"] = metrics.mean_absolute_error(y_train, y_train_predict)\n",
" class_models[model_name][\"MAE_test\"] = metrics.mean_absolute_error(y_test, y_test_predict)\n",
" class_models[model_name][\"R2_train\"] = metrics.r2_score(y_train, y_train_predict)\n",
" class_models[model_name][\"R2_test\"] = metrics.r2_score(y_test, y_test_predict)\n",
"\n",
" # Дополнительные метрики\n",
" class_models[model_name][\"STD_train\"] = np.std(y_train - y_train_predict)\n",
" class_models[model_name][\"STD_test\"] = np.std(y_test - y_test_predict)\n",
"\n",
" # Вывод результатов для текущей модели\n",
" print(f\"MSE (train): {class_models[model_name]['MSE_train']}\")\n",
" print(f\"MSE (test): {class_models[model_name]['MSE_test']}\")\n",
" print(f\"MAE (train): {class_models[model_name]['MAE_train']}\")\n",
" print(f\"MAE (test): {class_models[model_name]['MAE_test']}\")\n",
" print(f\"R2 (train): {class_models[model_name]['R2_train']}\")\n",
" print(f\"R2 (test): {class_models[model_name]['R2_test']}\")\n",
" print(f\"STD (train): {class_models[model_name]['STD_train']}\")\n",
" print(f\"STD (test): {class_models[model_name]['STD_test']}\")\n",
" print(\"-\" * 40) # Разделитель для разных моделей"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Пример использования обученной модели (конвейера регрессии) для предсказания**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Подбор гиперпараметров методом поиска по сетке**"
]
},
{
"cell_type": "code",
"execution_count": 174,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n",
"Best parameters: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
"Best MSE: 0.14752641202600872\n"
]
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Convert the date column to a datetime object and extract numeric features\n",
"df['date'] = pd.to_datetime(df['date'], errors='coerce') # Coerce invalid dates to NaT\n",
"df.dropna(subset=['date'], inplace=True) # Drop rows with invalid dates\n",
"df['year'] = df['date'].dt.year\n",
"df['month'] = df['date'].dt.month\n",
"df['day'] = df['date'].dt.day\n",
"\n",
"# Prepare predictors and target\n",
"X = df[['yr_built', 'year', 'month', 'day', 'price', 'price_category']]\n",
"y = df['average_price']\n",
"\n",
"# Split data into training and testing sets\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"\n",
"# Define model and parameter grid\n",
"model = RandomForestRegressor()\n",
"param_grid = {\n",
" 'n_estimators': [50, 100, 200],\n",
" 'max_depth': [None, 10, 20, 30],\n",
" 'min_samples_split': [2, 5, 10]\n",
"}\n",
"\n",
"# Hyperparameter tuning with GridSearchCV\n",
"grid_search = GridSearchCV(estimator=model, param_grid=param_grid,\n",
" scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"# Fit the model\n",
"grid_search.fit(X_train, y_train)\n",
"\n",
"# Output the best parameters and score\n",
"print(\"Best parameters:\", grid_search.best_params_)\n",
"print(\"Best MSE:\", -grid_search.best_score_)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Обучение модели с новыми гиперпараметрами и сравнение новых и старых данных**"
]
},
{
"cell_type": "code",
"execution_count": 175,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 36 candidates, totalling 180 fits\n",
"Старые параметры: {'max_depth': 10, 'min_samples_split': 15, 'n_estimators': 200}\n",
"Лучший результат (MSE) на старых параметрах: 0.14727400921908354\n",
"\n",
"Новые параметры: {'max_depth': 10, 'min_samples_split': 10, 'n_estimators': 200}\n",
"Лучший результат (MSE) на новых параметрах: 0.148833681322309\n",
"Среднеквадратическая ошибка (MSE) на тестовых данных: 0.14451630134635543\n",
"Корень среднеквадратичной ошибки (RMSE) на тестовых данных: 0.3801529972870863\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1cAAAHWCAYAAACbsXOkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAABusklEQVR4nO3deXwN1//H8fdNyL4hJIImttpC7AQVrSVaRWhrbS1V3awNavmWoN821N7yrdIW1SrVoqqtfjX2ncROVdVWJJaSECQk8/vDL/frSkJu3DQSr+fjcR8yZ86c+cyYZPLJOXPGZBiGIQAAAADAA7HL7QAAAAAAID8guQIAAAAAGyC5AgAAAAAbILkCAAAAABsguQIAAAAAGyC5AgAAAAAbILkCAAAAABsguQIAAAAAGyC5AgAAAAAbILkCAAAAABsguQLysaNHj+q1115TmTJl5OTkJA8PDzVs2FDTpk3T9evXczu8R8batWtlMplkMpn05ZdfZlinYcOGMplMCgwMtChPTk7WtGnTVKNGDXl4eMjLy0tVqlTRq6++qt9++81cb+7cueZ9ZPTZunVrjh4jAACQCuR2AAByxo8//qgXXnhBjo6O6tatmwIDA5WcnKyNGzdqyJAhOnDggGbNmpXbYT5SnJyctGDBAr344osW5cePH9fmzZvl5OSUbpvnnntOP//8szp37qzevXvr5s2b+u2337RixQo1aNBAFStWtKg/duxYlS5dOl075cqVs+3BAACAdEiugHzo2LFj6tSpk/z9/bV69WoVL17cvK5Pnz76448/9OOPP+ZihI+mZ555RsuXL9eFCxfk7e1tLl+wYIF8fHxUvnx5Xbp0yVy+Y8cOrVixQu+9955GjBhh0db06dN1+fLldPt4+umnVbt27Rw7BgAAkDmGBQL50AcffKCrV6/qs88+s0is0pQrV04DBgwwL5tMJvXt21dfffWVKlSoICcnJ9WqVUvr16+32O7EiRN68803VaFCBTk7O6tIkSJ64YUXdPz4cYt6dw9Rc3FxUdWqVfXpp59a1OvRo4fc3NzSxfftt9/KZDJp7dq1FuXbtm1Ty5Yt5enpKRcXF4WEhGjTpk0WdUaPHi2TyaQLFy5YlO/cuVMmk0lz58612H9AQIBFvVOnTsnZ2Vkmkyndcf3888964okn5OrqKnd3d7Vq1UoHDhxIF39m2rZtK0dHRy1evNiifMGCBerQoYPs7e0tyo8ePSrp9pDBu9nb26tIkSJZ3ndWHD9+PNNhhXefC0lq0qRJhnXvPMeS9PHHHyswMFAuLi4W9b799tv7xnT69Gn16tVLfn5+cnR0VOnSpfXGG28oOTn5vkMh74xl79696tGjh3mIrK+vr15++WVdvHjRYn9p189vv/2mDh06yMPDQ0WKFNGAAQN048YNi7pp3zeZSYsv7dytXr1adnZ2GjVqlEW9BQsWyGQy6eOPP77nuWjSpImaNGliUbZjxw7zsd5PkyZN0g07laSJEydm+H/8n//8R1WqVJGjo6P8/PzUp0+fdAn93deAt7e3WrVqpf3791vUy41zda/r4s5j/f7779WqVSvzNVa2bFm9++67SklJSddmYGCgoqOj1aBBAzk7O6t06dKaOXOmRb3k5GSNGjVKtWrVkqenp1xdXfXEE09ozZo1FvXu/H5btmyZxbobN26oUKFCMplMmjhxosW606dP6+WXX5aPj48cHR1VpUoVff755+b1dw5DzuwzevRoSdZd77du3dK7776rsmXLytHRUQEBARoxYoSSkpIs6gUEBJj3Y2dnJ19fX3Xs2FEnT5685/8ZkF/QcwXkQz/88IPKlCmjBg0aZHmbdevWadGiRerfv78cHR31n//8Ry1bttT27dvNv5Dt2LFDmzdvVqdOnVSyZEkdP35cH3/8sZo0aaKDBw/KxcXFos0pU6bI29tbCQkJ+vzzz9W7d28FBASoWbNmVh/T6tWr9fTTT6tWrVqKiIiQnZ2d5syZo6eeekobNmxQ3bp1rW4zI6NGjUr3S4UkzZ8/X927d1doaKjGjx+va9eu6eOPP1ajRo20a9eudElaRlxcXNS2bVt9/fXXeuONNyRJe/bs0YEDB/Tpp59q7969FvX9/f0lSV999ZUaNmyoAgXu/yM7Pj4+XWJpMpmsSsQ6d+6sZ555RpL0008/6euvv860bsWKFfWvf/1LknThwgW99dZbFusXLVqkN998U02aNFG/fv3k6uqqQ4cO6f33379vHGfOnFHdunV1+fJlvfrqq6pYsaJOnz6tb7/9VteuXVPjxo01f/58c/333ntPkszxSDJ/D6xatUp//vmnevbsKV9fX/Ow2AMHDmjr1q3pkpMOHTooICBAkZGR2rp1qz788ENdunRJX3zxxX3jzsxTTz2lN998U5GRkQoLC1PNmjV19uxZ9evXT82aNdPrr79udZtDhw7Ndjz3Mnr0aI0ZM0bNmjXTG2+8ocOHD+vjjz/Wjh07tGnTJhUsWNBcN+0aMAxDR48e1eTJk/XMM8880C/TtjhXJUuWVGRkpEVZRtfz3Llz5ebmpvDwcLm5uWn16tUaNWqUEhISNGHCBIu6ly5d0jPPPKMOHTqoc+fO+uabb/TGG2/IwcFBL7/8siQpISFBn376qXko75UrV/TZZ58pNDRU27dvV/Xq1S3adHJy0pw5cxQWFmYuW7JkSYY/h+Li4lS/fn1zslq0aFH9/PPP6tWrlxISEjRw4EBVqlTJ4vti1qxZOnTokKZMmWIuq1atmkW7WbneX3nlFc2bN0/PP/+8Bg0apG3btikyMlKHDh3S0qVLLdp74okn9Oqrryo1NVX79+/X1KlTdebMGW3YsCHdMQH5jgEgX4mPjzckGW3bts3yNpIMScbOnTvNZSdOnDCcnJyMdu3amcuuXbuWbtstW7YYkowvvvjCXDZnzhxDknHs2DFz2e+//25IMj744ANzWffu3Q1XV9d0bS5evNiQZKxZs8YwDMNITU01ypcvb4SGhhqpqakW8ZQuXdpo3ry5uSwiIsKQZJw/f96izR07dhiSjDlz5ljs39/f37y8f/9+w87Oznj66act4r9y5Yrh5eVl9O7d26LN2NhYw9PTM1353dasWWNIMhYvXmysWLHCMJlMxsmTJw3DMIwhQ4YYZcqUMQzDMEJCQowqVaqYt0tNTTVCQkIMSYaPj4/RuXNnY8aMGcaJEyfS7SPtnGf0cXR0vGd8adL+jyZOnGgumzBhQrr/yzQNGzY0nnzySfPysWPH0p3jzp07G15eXsb169czPB/30q1bN8POzs7YsWNHunV3XgdpQkJCjJCQkAzbyuja/frrrw1Jxvr1681laddPmzZtLOq++eabhiRjz5495jJJRp8+fTKNP6Pvg8TERKNcuXJGlSpVjBs3bhitWrUyPDw8Mvw/vd/x/fTTT4Yko2XLlkZWbud3X19p7v4/PnfunOHg4GC0aNHCSElJMdebPn26Icn4/PPPM43JMAxjxIgRhiTj3Llz5rLcOFdZOVbDyPjaeO211wwXFxfjxo0bFm1KMiZNmmQuS0pKMqpXr24UK1bMSE5ONgzDMG7dumUkJSVZtHfp0iXDx8fHePnll81lad8vnTt3NgoUKGDExsaa1zVt2tTo0qWLIcmYMGGCubxXr15G8eLFjQsXLli036lTJ8PT0zPDY7n759ydsnq9796925BkvPLKKxb1Bg8ebEgyVq9ebS7z9/c3unfvblGvS5cuhouLS4YxAPkNwwKBfCYhIUGS5O7ubtV2wcHBqlWrlnn5scceU9u2bfXLL7+Yh8c4Ozub19+8eVMXL15UuXLl5OXlpZiYmHRtXrp0SRcuXNCff/6pKVOmyN7eXiEhIenqXbhwweJz5coVi/W7d+/WkSNH1KVLF128eNFcLzExUU2bNtX69euVmppqsc3ff/9t0WZ8fPx9z8Hw4cNVs2ZNvfDCCxblq1at0uXLl9W5c2eLNu3t7VWvXr10w33upUWLFipcuLAWLlwowzC0cOFCde7cOcO6JpNJv/zyi/7973+rUKFC+vrrr9WnTx/5+/urY8eOGT5zNWPGDK1
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import metrics\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split, GridSearchCV\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"# 1. Настройка параметров для старых значений\n",
"old_param_grid = {\n",
" 'n_estimators': [50, 100, 200], # Количество деревьев\n",
" 'max_depth': [None, 10, 20, 30], # Максимальная глубина дерева\n",
" 'min_samples_split': [2, 10, 15] # Минимальное количество образцов для разбиения узла\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью Grid Search для старых параметров\n",
"old_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=old_param_grid, scoring='neg_mean_squared_error', cv=5, n_jobs=-1, verbose=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"old_grid_search.fit(X_train, y_train)\n",
"\n",
"# 2. Результаты подбора для старых параметров\n",
"old_best_params = old_grid_search.best_params_\n",
"old_best_mse = -old_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"# 3. Настройка параметров для новых значений\n",
"new_param_grid = {\n",
" 'n_estimators': [200],\n",
" 'max_depth': [10],\n",
" 'min_samples_split': [10]\n",
"}\n",
"\n",
"# Подбор гиперпараметров с помощью Grid Search для новых параметров\n",
"new_grid_search = GridSearchCV(estimator=RandomForestRegressor(), \n",
" param_grid=new_param_grid, scoring='neg_mean_squared_error', cv=2)\n",
"\n",
"# Обучение модели на тренировочных данных\n",
"new_grid_search.fit(X_train, y_train)\n",
"\n",
"# 4. Результаты подбора для новых параметров\n",
"new_best_params = new_grid_search.best_params_\n",
"new_best_mse = -new_grid_search.best_score_ # Меняем знак, так как берем отрицательное значение MSE\n",
"\n",
"# 5. Обучение модели с лучшими параметрами для новых значений\n",
"model_best = RandomForestRegressor(**new_best_params)\n",
"model_best.fit(X_train, y_train)\n",
"\n",
"# Прогнозирование на тестовой выборке\n",
"y_pred = model_best.predict(X_test)\n",
"\n",
"# Оценка производительности модели\n",
"mse = metrics.mean_squared_error(y_test, y_pred)\n",
"rmse = np.sqrt(mse)\n",
"\n",
"# Вывод результатов\n",
"print(\"Старые параметры:\", old_best_params)\n",
"print(\"Лучший результат (MSE) на старых параметрах:\", old_best_mse)\n",
"print(\"\\nНовые параметры:\", new_best_params)\n",
"print(\"Лучший результат (MSE) на новых параметрах:\", new_best_mse)\n",
"print(\"Среднеквадратическая ошибка (MSE) на тестовых данных:\", mse)\n",
"print(\"Корень среднеквадратичной ошибки (RMSE) на тестовых данных:\", rmse)\n",
"\n",
"# Визуализация ошибок\n",
"plt.figure(figsize=(10, 5))\n",
"plt.bar(['Старые параметры', 'Новые параметры'], [old_best_mse, new_best_mse], color=['blue', 'orange'])\n",
"plt.xlabel('Подбор параметров')\n",
"plt.ylabel('Среднеквадратическая ошибка (MSE)')\n",
"plt.title('Сравнение MSE для старых и новых параметров')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Сравнивая результаты старых и новых параметров, можно сказать, что старые параметры модели позволили добиться меньшей среднеквадратической ошибки, что указывает на более эффективное предсказание по сравнению с новыми настройками. Скорее всего модель обучена достаточно хорошо, учитывая следующие факторы:\n",
"1. Показатели MSE на тренировочных (0.159) и тестовых данных (0.1589) очень близки. Это говорит о том, что модель не переобучена и не недообучена — она хорошо обобщает на тестовой выборке, что является желаемым результатом. \n",
"2. Старые параметры дали наилучший результат, так что модель способна выдать высокую точность при настройке гиперпараметров. Попытка с новыми параметрами позволила оценить, как модель реагирует на изменения параметров, и выяснить, что увеличение max_depth и снижение min_samples_split улучшили результат. Этот процесс настройки параметров — часть процесса улучшения модели. \n",
"3. Старые параметры дали наилучший результат, так что модель способна выдать высокую точность при настройке гиперпараметров. Попытка с новыми параметрами позволила оценить, как модель реагирует на изменения параметров, и выяснить, что увеличение max_depth и снижение min_samples_split улучшили результат. Этот процесс настройки параметров — часть процесса улучшения модели. \n",
"\n",
"Таким образом, можно сказать, что модель обучена хорошо, но возможны дальнейшие небольшие улучшения за счет оптимизации гиперпараметров."
]
},
{
"cell_type": "code",
"execution_count": 176,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHWCAYAAABACtmGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOx9d3gc1dX3b7Y3adV7tVwkd2yMCxhML6bYFFMSwPSEFiAhvPAmdEIIJBBDIPFHAiQ4gdBMC90QwAY3bLDV3CSrl1Xb1RZtm++PM0d3drVqBmJ4M+d59pF2dubOLeeefs6VZFmWoYEGGmiggQYaaKCBBhpooMGwoDvYHdBAAw000EADDTTQQAMNNPiug6Y4aaCBBhpooIEGGmiggQYajAKa4qSBBhpooIEGGmiggQYaaDAKaIqTBhpooIEGGmiggQYaaKDBKKApThpooIEGGmiggQYaaKCBBqOApjhpoIEGGmiggQYaaKCBBhqMApripIEGGmiggQYaaKCBBhpoMApoipMGGmiggQYaaKCBBhpooMEooClOGmiggQYaaKCBBhpooIEGo4CmOGmggQYaaKDBfwA+/fRTfPTRR4PfP/roI6xfv37Mz0ejUUyfPh333Xfft9A7DRJBSUkJVq5cebC7oYEGXxu6urpgt9vxr3/962B35XsNmuKkQQw8/vjjkCQJ8+fPP9hd0eAgwNNPPw1Jkkb8TJ8+/WB3UwMNvpfQ2NiIq6++Gjt27MCOHTtw9dVXo7GxcczP/+Mf/0BjYyOuvfbawWu8Z7ds2ZLwmSVLlmh79j8EgUAADz/8MObPnw+n0wmLxYLJkyfj2muvxa5duw529w4K3HnnnaPylFNPPfVgd/O/AtLT03H55Zfjl7/85cHuyvcaDAe7Axp8t2DNmjUoKSnBpk2bsGfPHkycOPFgd0mDgwB33303SktLh1zXLN0aaHDgcOaZZ+KRRx7BzJkzAQALFy7EmWeeOebnH3zwQZx33nlwOp3fVhc1OEBwuVw46aSTsHXrVpx66qm44IIL4HA4UFtbi+eeew6rV69GMBg82N08aPDEE0/A4XAMuX7jjTcehN7898KPfvQjrFq1CuvWrcMxxxxzsLvzvQRNcdJgEOrq6rBhwwa8/PLLuOqqq7BmzRrccccdB7tbGhwEOPnkk3HooYcOuf7kk0/C5XIdhB5poMH3H8xmMzZs2ICdO3cCAKZPnw69Xj+mZ7dt24Yvv/wSv/3tb7/NLmpwgLBy5Ups27YNL774Is4666yY3+655x787//+70Hq2XcDzj77bGRkZAy5/otf/OIg9Oa/FyoqKjB9+nQ8/fTTmuJ0gKCF6mkwCGvWrEFqaiqWLl2Ks88+G2vWrBlyT319PSRJwtNPPz14zePxYO7cuSgtLUVra+vgPSN9Vq5ciX379kGSJDz88MND3rNhwwZIkoR//OMfAID9+/fj6quvxpQpU2C1WpGeno5zzjkH9fX1CceyZMmShO9V9xsAXnjhBcydOxdWqxUZGRn44Q9/iObm5ph7EsW4f/TRR5AkKSZfYcmSJViyZEnMfffddx90Oh3+/ve/D1775JNPcM4556CoqAhmsxmFhYW48cYb4ff7E46FYcuWLZAkCc8888yQ39555x1IkoQ33ngDAK3JDTfcgJKSEpjNZmRlZeH444/HF198MeI7DgQkScK1116LNWvWYMqUKbBYLJg7dy4+/vjjIfc2Nzfj0ksvRXZ2NsxmM6ZNm4a//OUvCdsdLsQjfo4BYOPGjTjllFOQmpoKu92OmTNn4ve///3g7ytXrkRJSUnMM88++yx0Oh1+/etfD1776quvsHLlSkyYMAEWiwU5OTm49NJL0dXVFfPsE088gVmzZsHpdMJut2PWrFn485//HHPPWNviccYrpLzeapxduXJlQqvtiy++mBAfRwrRit/LHR0dyMzMxJIlSyDL8uB9e/bsgd1ux7nnnjtsW2Odk/H0f6z7ZKz7EyA8Oemkk+B0OmGz2XDUUUcNyTEa73rE41VjYyOsViskSYqhT9xPvV6PWbNmYdasWXj55ZchSdKQNhLB2rVrYTKZcOSRR45672gQDodxzz33oKysDGazGSUlJbjtttswMDAQc19JScngvtPpdMjJycG5556LhoaGwXsYjx566KFh38dzGg/PPvvsIP1NS0vDeeedN6bQxbHyAw5jXL9+PW666SZkZmbCbrdj+fLl6OzsjLlXlmXce++9KCgogM1mw9FHH43KyspR+wIQXr355pu47LLLhihNACnMPD8rV64clT/yOF599VUsXboUeXl5MJvNKCsrwz333INIJBLTPu/1rVu3YtGiRbBarSgtLcUf//jHmPt4X7z44ovDjiURTkejUTzyyCOYNm0aLBYLsrOzcdVVV6Gnp2dM8zNeeOihh7Bo0SKkp6fDarVi7ty5Cfs8Vt4zXnwxmUxD8OOzzz4bXJ/4sNjR6MpYwhWZVn3bawkAxx9/PF5//fUYOq/B2EHzOGkwCGvWrMGZZ54Jk8mE888/H0888QQ2b96MefPmDftMKBTCWWedhYaGBqxfvx65ubnwer3429/+NnjPyy+/jFdeeSXmWllZGSZMmIDDDz8ca9asGeKuX7NmDZKSknDGGWcAADZv3owNGzbgvPPOQ0FBAerr6/HEE09gyZIlqKqqgs1mG9K38vLyQSufy+Ua8o6nn34al1xyCebNm4f7778f7e3t+P3vf4/169dj27ZtSElJGfccquGpp57CL37xC/z2t7/FBRdcMHj9hRdegM/nw49//GOkp6dj06ZNePTRR9HU1IQXXnhh2PYOPfRQTJgwAf/85z9x8cUXx/z2/PPPIzU1FSeeeCIAcse/+OKLuPbaazF16lR0dXXh008/RXV1NebMmfO1xpUI/v3vf+P555/H9ddfD7PZjMcffxwnnXQSNm3aNCi8t7e3Y8GCBYPMLjMzE2+99RYuu+wyuN1u3HDDDQnbVod43HrrrUN+f++993DqqaciNzcXP/nJT5CTk4Pq6mq88cYb+MlPfpKwzXfffReXXnoprr32WvzP//xPTFv79u3DJZdcgpycHFRWVmL16tWorKzE559/Pij8eTwenHDCCSgrK4Msy/jnP/+Jyy+/HCkpKYOC01jb+q5AVlYWnnjiCZxzzjl49NFHcf311yMajWLlypVISkrC448/PuLzY5mT8cCB7pPhYN26dTj55JMxd+5c3HHHHdDpdHjqqadwzDHH4JNPPsFhhx027jYTwe23345AIDDqfeFweFxeiA0bNmD69OkwGo0Jf+/r60voDQ6FQkOuXX755XjmmWdw9tln46c//Sk2btyI+++/H9XV1XjllVdi7l28eDGuvPJKRKNR7Ny5E4888ghaWlrwySefjLnvieC+++7DL3/5S6xYsQKXX345Ojs78eijj+LII48clf6Olx9cd911SE1NxR133IH6+no88sgjuPbaa/H8888P3nP77bfj3nvvxSmnnIJTTjkFX3zxBU444YQxhde99tprAIALL7xw1HuvuuoqHHfccYPfL7zwQixfvjwmZDMzMxMA8SiHw4GbbroJDocD69atw+233w63240HH3wwpt2enh6ccsopWLFiBc4//3z885//xI9//GOYTCZceumlo/ZrtD4zv7z++utRV1eHxx57DNu2bcP69euHxckDhd///vc4/fTT8YMf/ADBYBDPPfcczjnnHLzxxhtYunRpzL1j4T3jxRe9Xo9nn302RmZ46qmnYLFYhuztsdCVM888Mybt4cYbb0RFRQWuvPLKwWsVFRWD/3+bawkAc+fOxcMPP4zKykot//FAQNZAA1mWt2zZIgOQ33vvPVmWZTkajcoFBQXyT37yk5j76urqZADyU089JUejUfkHP/iBbLPZ5I0bNw7b9h133CEPh2p/+tOfZABydXX14LVgMChnZGTIF1988eA1n8835NnPPvtMBiD/9a9/HfLb4YcfLh999NEJ+83vyMrKkqdPny77/f7B+9544w0ZgHz77bcPXistLZUvuuiimPY//PBDGYD84YcfDl476qij5KOOOkqWZVl
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.scatter(range(len(y_test)), y_test, label=\"Актуалочка\", color=\"black\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_pred, label=\"Предсказанные(новые параметры)\", color=\"blue\", alpha=0.5)\n",
"plt.scatter(range(len(y_test)), y_test_predict, label=\"Предсказанные(старые параметры)\", color=\"red\", alpha=0.5)\n",
"plt.xlabel(\"Выборка\")\n",
"plt.ylabel(\"Значения\")\n",
"plt.legend()\n",
"plt.title(\"Актуалочка vs Предсказанных значений (Новые and Старые Параметры)\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ураааа! Усёёёё, вроде бы всё ^_^"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "mai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}