{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#### Загрузка набора данных" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datepricebedroomsbathroomssqft_livingsqft_lotfloorswaterfrontviewconditiongradesqft_abovesqft_basementyr_builtyr_renovatedzipcodelatlongsqft_living15sqft_lot15
id
712930052020141013T000000221900.031.00118056501.0003711800195509817847.5112-122.25713405650
641410019220141209T000000538000.032.25257072422.000372170400195119919812547.7210-122.31916907639
563150040020150225T000000180000.021.00770100001.000367700193309802847.7379-122.23327208062
248720087520141209T000000604000.043.00196050001.000571050910196509813647.5208-122.39313605000
195440051020150218T000000510000.032.00168080801.0003816800198709807447.6168-122.04518007503
...............................................................
26300001820140521T000000360000.032.50153011313.0003815300200909810347.6993-122.34615301509
660006012020150223T000000400000.042.50231058132.0003823100201409814647.5107-122.36218307200
152330014120140623T000000402101.020.75102013502.0003710200200909814447.5944-122.29910202007
29131010020150116T000000400000.032.50160023882.0003816000200409802747.5345-122.06914101287
152330015720141015T000000325000.020.75102010762.0003710200200809814447.5941-122.29910201357
\n", "

21613 rows × 20 columns

\n", "
" ], "text/plain": [ " date price bedrooms bathrooms sqft_living \\\n", "id \n", "7129300520 20141013T000000 221900.0 3 1.00 1180 \n", "6414100192 20141209T000000 538000.0 3 2.25 2570 \n", "5631500400 20150225T000000 180000.0 2 1.00 770 \n", "2487200875 20141209T000000 604000.0 4 3.00 1960 \n", "1954400510 20150218T000000 510000.0 3 2.00 1680 \n", "... ... ... ... ... ... \n", "263000018 20140521T000000 360000.0 3 2.50 1530 \n", "6600060120 20150223T000000 400000.0 4 2.50 2310 \n", "1523300141 20140623T000000 402101.0 2 0.75 1020 \n", "291310100 20150116T000000 400000.0 3 2.50 1600 \n", "1523300157 20141015T000000 325000.0 2 0.75 1020 \n", "\n", " sqft_lot floors waterfront view condition grade sqft_above \\\n", "id \n", "7129300520 5650 1.0 0 0 3 7 1180 \n", "6414100192 7242 2.0 0 0 3 7 2170 \n", "5631500400 10000 1.0 0 0 3 6 770 \n", "2487200875 5000 1.0 0 0 5 7 1050 \n", "1954400510 8080 1.0 0 0 3 8 1680 \n", "... ... ... ... ... ... ... ... \n", "263000018 1131 3.0 0 0 3 8 1530 \n", "6600060120 5813 2.0 0 0 3 8 2310 \n", "1523300141 1350 2.0 0 0 3 7 1020 \n", "291310100 2388 2.0 0 0 3 8 1600 \n", "1523300157 1076 2.0 0 0 3 7 1020 \n", "\n", " sqft_basement yr_built yr_renovated zipcode lat long \\\n", "id \n", "7129300520 0 1955 0 98178 47.5112 -122.257 \n", "6414100192 400 1951 1991 98125 47.7210 -122.319 \n", "5631500400 0 1933 0 98028 47.7379 -122.233 \n", "2487200875 910 1965 0 98136 47.5208 -122.393 \n", "1954400510 0 1987 0 98074 47.6168 -122.045 \n", "... ... ... ... ... ... ... \n", "263000018 0 2009 0 98103 47.6993 -122.346 \n", "6600060120 0 2014 0 98146 47.5107 -122.362 \n", "1523300141 0 2009 0 98144 47.5944 -122.299 \n", "291310100 0 2004 0 98027 47.5345 -122.069 \n", "1523300157 0 2008 0 98144 47.5941 -122.299 \n", "\n", " sqft_living15 sqft_lot15 \n", "id \n", "7129300520 1340 5650 \n", "6414100192 1690 7639 \n", "5631500400 2720 8062 \n", "2487200875 1360 5000 \n", "1954400510 1800 7503 \n", "... ... ... \n", "263000018 1530 1509 \n", "6600060120 1830 7200 \n", "1523300141 1020 2007 \n", "291310100 1410 1287 \n", "1523300157 1020 1357 \n", "\n", "[21613 rows x 20 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "from sklearn import set_config\n", "\n", "set_config(transform_output=\"pandas\")\n", "\n", "random_state=9\n", "\n", "df = pd.read_csv(\"data/kc_house_data.csv\", index_col=\"id\")\n", "\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Разделение набора данных на обучающую и тестовые выборки (80/20) для задачи классификации\n", "\n", "Целевой признак -- waterfront" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'X_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datepricebedroomsbathroomssqft_livingsqft_lotfloorswaterfrontviewconditiongradesqft_abovesqft_basementyr_builtyr_renovatedzipcodelatlongsqft_living15sqft_lot15
id
304620012520150406T000000202000.021.0074065501.000457400194609816847.4807-122.33210808515
185300003020150416T000000775000.032.503550328072.0003935500198909807747.7292-122.082327035001
182507900520140609T000000739000.042.5028002461142.0003928000199909801447.6586-121.962275060351
252303931520141022T000000481000.032.002580156531.5003925800199009816647.4561-122.36119209840
662340024620140523T000000200000.041.001350115071.0003713500196609805547.4269-122.197132025675
...............................................................
252306913420150406T000000495000.042.502480919111.0024714701010197309802747.4579-121.981254091911
193130041220150416T000000475000.032.25119012003.0003811900200809810347.6542-122.34611801224
433100040020150220T000000252000.031.501150132001.0003711500195609816647.4752-122.345122013066
921290018020140625T000000760000.042.50276060002.000572230530194209811547.6877-122.29516006000
700010077520140721T000000625000.032.001730122191.0004717300198609800447.5825-122.189247013594
\n", "

17290 rows × 20 columns

\n", "
" ], "text/plain": [ " date price bedrooms bathrooms sqft_living \\\n", "id \n", "3046200125 20150406T000000 202000.0 2 1.00 740 \n", "1853000030 20150416T000000 775000.0 3 2.50 3550 \n", "1825079005 20140609T000000 739000.0 4 2.50 2800 \n", "2523039315 20141022T000000 481000.0 3 2.00 2580 \n", "6623400246 20140523T000000 200000.0 4 1.00 1350 \n", "... ... ... ... ... ... \n", "2523069134 20150406T000000 495000.0 4 2.50 2480 \n", "1931300412 20150416T000000 475000.0 3 2.25 1190 \n", "4331000400 20150220T000000 252000.0 3 1.50 1150 \n", "9212900180 20140625T000000 760000.0 4 2.50 2760 \n", "7000100775 20140721T000000 625000.0 3 2.00 1730 \n", "\n", " sqft_lot floors waterfront view condition grade sqft_above \\\n", "id \n", "3046200125 6550 1.0 0 0 4 5 740 \n", "1853000030 32807 2.0 0 0 3 9 3550 \n", "1825079005 246114 2.0 0 0 3 9 2800 \n", "2523039315 15653 1.5 0 0 3 9 2580 \n", "6623400246 11507 1.0 0 0 3 7 1350 \n", "... ... ... ... ... ... ... ... \n", "2523069134 91911 1.0 0 2 4 7 1470 \n", "1931300412 1200 3.0 0 0 3 8 1190 \n", "4331000400 13200 1.0 0 0 3 7 1150 \n", "9212900180 6000 2.0 0 0 5 7 2230 \n", "7000100775 12219 1.0 0 0 4 7 1730 \n", "\n", " sqft_basement yr_built yr_renovated zipcode lat long \\\n", "id \n", "3046200125 0 1946 0 98168 47.4807 -122.332 \n", "1853000030 0 1989 0 98077 47.7292 -122.082 \n", "1825079005 0 1999 0 98014 47.6586 -121.962 \n", "2523039315 0 1990 0 98166 47.4561 -122.361 \n", "6623400246 0 1966 0 98055 47.4269 -122.197 \n", "... ... ... ... ... ... ... \n", "2523069134 1010 1973 0 98027 47.4579 -121.981 \n", "1931300412 0 2008 0 98103 47.6542 -122.346 \n", "4331000400 0 1956 0 98166 47.4752 -122.345 \n", "9212900180 530 1942 0 98115 47.6877 -122.295 \n", "7000100775 0 1986 0 98004 47.5825 -122.189 \n", "\n", " sqft_living15 sqft_lot15 \n", "id \n", "3046200125 1080 8515 \n", "1853000030 3270 35001 \n", "1825079005 2750 60351 \n", "2523039315 1920 9840 \n", "6623400246 1320 25675 \n", "... ... ... \n", "2523069134 2540 91911 \n", "1931300412 1180 1224 \n", "4331000400 1220 13066 \n", "9212900180 1600 6000 \n", "7000100775 2470 13594 \n", "\n", "[17290 rows x 20 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_train'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
waterfront
id
30462001250
18530000300
18250790050
25230393150
66234002460
......
25230691340
19313004120
43310004000
92129001800
70001007750
\n", "

17290 rows × 1 columns

\n", "
" ], "text/plain": [ " waterfront\n", "id \n", "3046200125 0\n", "1853000030 0\n", "1825079005 0\n", "2523039315 0\n", "6623400246 0\n", "... ...\n", "2523069134 0\n", "1931300412 0\n", "4331000400 0\n", "9212900180 0\n", "7000100775 0\n", "\n", "[17290 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'X_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datepricebedroomsbathroomssqft_livingsqft_lotfloorswaterfrontviewconditiongradesqft_abovesqft_basementyr_builtyr_renovatedzipcodelatlongsqft_living15sqft_lot15
id
177595010020150113T000000357823.031.50124091961.0003812400196809807247.7562-122.094169010800
355080004020141114T000000223000.031.0094079801.000369400196109814647.5107-122.34510507980
145460025620141013T000000710000.052.50257096001.002381620950195609812547.7216-122.28226809900
146740009520150224T000000545000.041.752040535781.000571160880195909803847.3844-122.000204053578
62406900320150102T000000829000.042.752970596771.0024816101360197309807547.5953-122.080293042489
...............................................................
350010018920140630T000000300000.021.0096081531.000369600194709815547.7341-122.30011608199
95200149520150306T000000588000.041.75217057501.002371370800197509811647.5668-122.38314505750
607230080020150505T000000595000.041.75251089891.000481680830196409800647.5569-122.17225108931
294401024020140908T000000988000.043.004040197002.00031140400198709805247.7205-122.127393021887
789380267020150424T000000279900.033.25224050002.000391540700198909819847.4114-122.33418007500
\n", "

4323 rows × 20 columns

\n", "
" ], "text/plain": [ " date price bedrooms bathrooms sqft_living \\\n", "id \n", "1775950100 20150113T000000 357823.0 3 1.50 1240 \n", "3550800040 20141114T000000 223000.0 3 1.00 940 \n", "1454600256 20141013T000000 710000.0 5 2.50 2570 \n", "1467400095 20150224T000000 545000.0 4 1.75 2040 \n", "624069003 20150102T000000 829000.0 4 2.75 2970 \n", "... ... ... ... ... ... \n", "3500100189 20140630T000000 300000.0 2 1.00 960 \n", "952001495 20150306T000000 588000.0 4 1.75 2170 \n", "6072300800 20150505T000000 595000.0 4 1.75 2510 \n", "2944010240 20140908T000000 988000.0 4 3.00 4040 \n", "7893802670 20150424T000000 279900.0 3 3.25 2240 \n", "\n", " sqft_lot floors waterfront view condition grade sqft_above \\\n", "id \n", "1775950100 9196 1.0 0 0 3 8 1240 \n", "3550800040 7980 1.0 0 0 3 6 940 \n", "1454600256 9600 1.0 0 2 3 8 1620 \n", "1467400095 53578 1.0 0 0 5 7 1160 \n", "624069003 59677 1.0 0 2 4 8 1610 \n", "... ... ... ... ... ... ... ... \n", "3500100189 8153 1.0 0 0 3 6 960 \n", "952001495 5750 1.0 0 2 3 7 1370 \n", "6072300800 8989 1.0 0 0 4 8 1680 \n", "2944010240 19700 2.0 0 0 3 11 4040 \n", "7893802670 5000 2.0 0 0 3 9 1540 \n", "\n", " sqft_basement yr_built yr_renovated zipcode lat long \\\n", "id \n", "1775950100 0 1968 0 98072 47.7562 -122.094 \n", "3550800040 0 1961 0 98146 47.5107 -122.345 \n", "1454600256 950 1956 0 98125 47.7216 -122.282 \n", "1467400095 880 1959 0 98038 47.3844 -122.000 \n", "624069003 1360 1973 0 98075 47.5953 -122.080 \n", "... ... ... ... ... ... ... \n", "3500100189 0 1947 0 98155 47.7341 -122.300 \n", "952001495 800 1975 0 98116 47.5668 -122.383 \n", "6072300800 830 1964 0 98006 47.5569 -122.172 \n", "2944010240 0 1987 0 98052 47.7205 -122.127 \n", "7893802670 700 1989 0 98198 47.4114 -122.334 \n", "\n", " sqft_living15 sqft_lot15 \n", "id \n", "1775950100 1690 10800 \n", "3550800040 1050 7980 \n", "1454600256 2680 9900 \n", "1467400095 2040 53578 \n", "624069003 2930 42489 \n", "... ... ... \n", "3500100189 1160 8199 \n", "952001495 1450 5750 \n", "6072300800 2510 8931 \n", "2944010240 3930 21887 \n", "7893802670 1800 7500 \n", "\n", "[4323 rows x 20 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'y_test'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
waterfront
id
17759501000
35508000400
14546002560
14674000950
6240690030
......
35001001890
9520014950
60723008000
29440102400
78938026700
\n", "

4323 rows × 1 columns

\n", "
" ], "text/plain": [ " waterfront\n", "id \n", "1775950100 0\n", "3550800040 0\n", "1454600256 0\n", "1467400095 0\n", "624069003 0\n", "... ...\n", "3500100189 0\n", "952001495 0\n", "6072300800 0\n", "2944010240 0\n", "7893802670 0\n", "\n", "[4323 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from utils import split_stratified_into_train_val_test\n", "\n", "X_train, X_val, X_test, y_train, y_val, y_test = split_stratified_into_train_val_test(\n", " df,\n", " stratify_colname=\"waterfront\",\n", " frac_train=0.80,\n", " frac_val=0,\n", " frac_test=0.20,\n", " random_state=random_state,\n", ")\n", "\n", "display(\"X_train\", X_train)\n", "display(\"y_train\", y_train)\n", "\n", "display(\"X_test\", X_test)\n", "display(\"y_test\", y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование конвейера для классификации данных\n", "\n", "preprocessing_num -- конвейер для обработки числовых данных: заполнение пропущенных значений и стандартизация\n", "\n", "preprocessing_cat -- конвейер для обработки категориальных данных: заполнение пропущенных данных и унитарное кодирование\n", "\n", "features_preprocessing -- трансформер для предобработки признаков\n", "\n", "features_engineering -- трансформер для конструирования признаков\n", "\n", "drop_columns -- трансформер для удаления колонок\n", "\n", "features_postprocessing -- трансформер для унитарного кодирования новых признаков\n", "\n", "pipeline_end -- основной конвейер предобработки данных и конструирования признаков\n", "\n", "Конвейер выполняется последовательно.\n", "\n", "Трансформер выполняет параллельно для указанного набора колонок.\n", "\n", "Документация: \n", "\n", "https://scikit-learn.org/1.5/api/sklearn.pipeline.html\n", "\n", "https://scikit-learn.org/1.5/modules/generated/sklearn.compose.ColumnTransformer.html#sklearn.compose.ColumnTransformer" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from sklearn.compose import ColumnTransformer\n", "from sklearn.discriminant_analysis import StandardScaler\n", "from sklearn.impute import SimpleImputer\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import OneHotEncoder\n", "\n", "from custom_transformers import HouseFeatures\n", "\n", "\n", "columns_to_drop = [\"waterfront\", \"yr_built\", \"zipcode\"]\n", "num_columns = [\n", " column\n", " for column in df.columns\n", " if column not in columns_to_drop and df[column].dtype != \"object\"\n", "]\n", "cat_columns = [\n", " column\n", " for column in df.columns\n", " if column not in columns_to_drop and df[column].dtype == \"object\"\n", "]\n", "\n", "num_imputer = SimpleImputer(strategy=\"median\")\n", "num_scaler = StandardScaler()\n", "preprocessing_num = Pipeline(\n", " [\n", " (\"imputer\", num_imputer),\n", " (\"scaler\", num_scaler),\n", " ]\n", ")\n", "\n", "cat_imputer = SimpleImputer(strategy=\"constant\", fill_value=-1)\n", "cat_encoder = OneHotEncoder(handle_unknown=\"ignore\", sparse_output=False, drop=\"first\")\n", "preprocessing_cat = Pipeline(\n", " [\n", " (\"imputer\", cat_imputer),\n", " (\"encoder\", cat_encoder),\n", " ]\n", ")\n", "\n", "features_preprocessing = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"prepocessing_num\", preprocessing_num, num_columns),\n", " (\"prepocessing_cat\", preprocessing_cat, cat_columns),\n", " (\"prepocessing_features\", cat_imputer, [\"yr_built\", \"zipcode\"]),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "features_engineering = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"add_features\", HouseFeatures(), [\"yr_built\", \"zipcode\"]),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "drop_columns = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"drop_columns\", \"drop\", columns_to_drop),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "features_postprocessing = ColumnTransformer(\n", " verbose_feature_names_out=False,\n", " transformers=[\n", " (\"prepocessing_cat\", preprocessing_cat, [\"Region\"]),\n", " ],\n", " remainder=\"passthrough\",\n", ")\n", "\n", "pipeline_end = Pipeline(\n", " [\n", " (\"features_preprocessing\", features_preprocessing),\n", " (\"features_engineering\", features_engineering),\n", " (\"drop_columns\", drop_columns),\n", " (\"features_postprocessing\", features_postprocessing),\n", " ]\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Демонстрация работы конвейера для предобработки данных при классификации" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Region_northHouse_agepricebedroomsbathroomssqft_livingsqft_lotfloorsviewcondition...date_20150506T000000date_20150507T000000date_20150508T000000date_20150509T000000date_20150510T000000date_20150511T000000date_20150512T000000date_20150513T000000date_20150514T000000date_20150515T000000
id
30462001250.078-0.945119-1.468373-1.448400-1.462069-0.205788-0.918509-0.3058830.909775...0.00.00.00.00.00.00.00.00.00.0
18530000301.0350.667867-0.3932860.5033451.6056530.4052880.935992-0.305883-0.628763...0.00.00.00.00.00.00.00.00.00.0
18250790051.0250.5665280.6818000.5033450.7868665.3695560.935992-0.305883-0.628763...0.00.00.00.00.00.00.00.00.00.0
25230393150.034-0.159739-0.393286-0.1472370.5466880.0060650.008742-0.305883-0.628763...0.00.00.00.00.00.00.00.00.00.0
66234002461.058-0.9507490.681800-1.448400-0.796122-0.090424-0.918509-0.305883-0.628763...0.00.00.00.00.00.00.00.00.00.0
..................................................................
25230691341.051-0.1203290.6818000.5033450.4375171.780808-0.9185092.3084110.909775...0.00.00.00.00.00.00.00.00.00.0
19313004121.016-0.176628-0.3932860.178054-0.970797-0.3302982.790494-0.305883-0.628763...0.00.00.00.00.00.00.00.00.00.0
43310004000.068-0.804370-0.393286-0.797819-1.014465-0.051023-0.918509-0.305883-0.628763...0.00.00.00.00.00.00.00.00.00.0
92129001801.0820.6256420.6818000.5033450.743197-0.2185880.935992-0.3058832.448313...0.00.00.00.00.00.00.00.00.00.0
70001007751.0380.245619-0.393286-0.147237-0.381270-0.073854-0.918509-0.3058830.909775...0.00.00.00.00.00.00.00.00.00.0
\n", "

17290 rows × 384 columns

\n", "
" ], "text/plain": [ " Region_north House_age price bedrooms bathrooms \\\n", "id \n", "3046200125 0.0 78 -0.945119 -1.468373 -1.448400 \n", "1853000030 1.0 35 0.667867 -0.393286 0.503345 \n", "1825079005 1.0 25 0.566528 0.681800 0.503345 \n", "2523039315 0.0 34 -0.159739 -0.393286 -0.147237 \n", "6623400246 1.0 58 -0.950749 0.681800 -1.448400 \n", "... ... ... ... ... ... \n", "2523069134 1.0 51 -0.120329 0.681800 0.503345 \n", "1931300412 1.0 16 -0.176628 -0.393286 0.178054 \n", "4331000400 0.0 68 -0.804370 -0.393286 -0.797819 \n", "9212900180 1.0 82 0.625642 0.681800 0.503345 \n", "7000100775 1.0 38 0.245619 -0.393286 -0.147237 \n", "\n", " sqft_living sqft_lot floors view condition ... \\\n", "id ... \n", "3046200125 -1.462069 -0.205788 -0.918509 -0.305883 0.909775 ... \n", "1853000030 1.605653 0.405288 0.935992 -0.305883 -0.628763 ... \n", "1825079005 0.786866 5.369556 0.935992 -0.305883 -0.628763 ... \n", "2523039315 0.546688 0.006065 0.008742 -0.305883 -0.628763 ... \n", "6623400246 -0.796122 -0.090424 -0.918509 -0.305883 -0.628763 ... \n", "... ... ... ... ... ... ... \n", "2523069134 0.437517 1.780808 -0.918509 2.308411 0.909775 ... \n", "1931300412 -0.970797 -0.330298 2.790494 -0.305883 -0.628763 ... \n", "4331000400 -1.014465 -0.051023 -0.918509 -0.305883 -0.628763 ... \n", "9212900180 0.743197 -0.218588 0.935992 -0.305883 2.448313 ... \n", "7000100775 -0.381270 -0.073854 -0.918509 -0.305883 0.909775 ... \n", "\n", " date_20150506T000000 date_20150507T000000 date_20150508T000000 \\\n", "id \n", "3046200125 0.0 0.0 0.0 \n", "1853000030 0.0 0.0 0.0 \n", "1825079005 0.0 0.0 0.0 \n", "2523039315 0.0 0.0 0.0 \n", "6623400246 0.0 0.0 0.0 \n", "... ... ... ... \n", "2523069134 0.0 0.0 0.0 \n", "1931300412 0.0 0.0 0.0 \n", "4331000400 0.0 0.0 0.0 \n", "9212900180 0.0 0.0 0.0 \n", "7000100775 0.0 0.0 0.0 \n", "\n", " date_20150509T000000 date_20150510T000000 date_20150511T000000 \\\n", "id \n", "3046200125 0.0 0.0 0.0 \n", "1853000030 0.0 0.0 0.0 \n", "1825079005 0.0 0.0 0.0 \n", "2523039315 0.0 0.0 0.0 \n", "6623400246 0.0 0.0 0.0 \n", "... ... ... ... \n", "2523069134 0.0 0.0 0.0 \n", "1931300412 0.0 0.0 0.0 \n", "4331000400 0.0 0.0 0.0 \n", "9212900180 0.0 0.0 0.0 \n", "7000100775 0.0 0.0 0.0 \n", "\n", " date_20150512T000000 date_20150513T000000 date_20150514T000000 \\\n", "id \n", "3046200125 0.0 0.0 0.0 \n", "1853000030 0.0 0.0 0.0 \n", "1825079005 0.0 0.0 0.0 \n", "2523039315 0.0 0.0 0.0 \n", "6623400246 0.0 0.0 0.0 \n", "... ... ... ... \n", "2523069134 0.0 0.0 0.0 \n", "1931300412 0.0 0.0 0.0 \n", "4331000400 0.0 0.0 0.0 \n", "9212900180 0.0 0.0 0.0 \n", "7000100775 0.0 0.0 0.0 \n", "\n", " date_20150515T000000 \n", "id \n", "3046200125 0.0 \n", "1853000030 0.0 \n", "1825079005 0.0 \n", "2523039315 0.0 \n", "6623400246 0.0 \n", "... ... \n", "2523069134 0.0 \n", "1931300412 0.0 \n", "4331000400 0.0 \n", "9212900180 0.0 \n", "7000100775 0.0 \n", "\n", "[17290 rows x 384 columns]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocessing_result = pipeline_end.fit_transform(X_train)\n", "preprocessed_df = pd.DataFrame(\n", " preprocessing_result,\n", " columns=pipeline_end.get_feature_names_out(),\n", ")\n", "\n", "preprocessed_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Формирование набора моделей для классификации\n", "\n", "logistic -- логистическая регрессия\n", "\n", "ridge -- гребневая регрессия\n", "\n", "decision_tree -- дерево решений\n", "\n", "knn -- k-ближайших соседей\n", "\n", "naive_bayes -- наивный Байесовский классификатор\n", "\n", "gradient_boosting -- метод градиентного бустинга (набор деревьев решений)\n", "\n", "random_forest -- метод случайного леса (набор деревьев решений)\n", "\n", "mlp -- многослойный персептрон (нейронная сеть)\n", "\n", "Документация: https://scikit-learn.org/1.5/supervised_learning.html" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from sklearn import ensemble, linear_model, naive_bayes, neighbors, neural_network, tree\n", "\n", "class_models = {\n", " \"logistic\": {\"model\": linear_model.LogisticRegression()},\n", " # \"ridge\": {\"model\": linear_model.RidgeClassifierCV(cv=5, class_weight=\"balanced\")},\n", " \"ridge\": {\"model\": linear_model.LogisticRegression(penalty=\"l2\", class_weight=\"balanced\")},\n", " \"decision_tree\": {\n", " \"model\": tree.DecisionTreeClassifier(max_depth=7, random_state=random_state)\n", " },\n", " \"knn\": {\"model\": neighbors.KNeighborsClassifier(n_neighbors=7)},\n", " \"naive_bayes\": {\"model\": naive_bayes.GaussianNB()},\n", " \"gradient_boosting\": {\n", " \"model\": ensemble.GradientBoostingClassifier(n_estimators=210)\n", " },\n", " \"random_forest\": {\n", " \"model\": ensemble.RandomForestClassifier(\n", " max_depth=11, class_weight=\"balanced\", random_state=random_state\n", " )\n", " },\n", " \"mlp\": {\n", " \"model\": neural_network.MLPClassifier(\n", " hidden_layer_sizes=(7,),\n", " max_iter=500,\n", " early_stopping=True,\n", " random_state=random_state,\n", " )\n", " },\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Обучение моделей на обучающем наборе данных и оценка на тестовом" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: logistic\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: ridge\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\linear_model\\_logistic.py:469: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " n_iter_i = _check_optimize_result(\n", "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: decision_tree\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: knn\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: naive_bayes\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: gradient_boosting\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: random_forest\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Model: mlp\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] } ], "source": [ "import numpy as np\n", "from sklearn import metrics\n", "\n", "for model_name in class_models.keys():\n", " print(f\"Model: {model_name}\")\n", " model = class_models[model_name][\"model\"]\n", "\n", " model_pipeline = Pipeline([(\"pipeline\", pipeline_end), (\"model\", model)])\n", " model_pipeline = model_pipeline.fit(X_train, y_train.values.ravel())\n", "\n", " y_train_predict = model_pipeline.predict(X_train)\n", " y_test_probs = model_pipeline.predict_proba(X_test)[:, 1]\n", " y_test_predict = np.where(y_test_probs > 0.5, 1, 0)\n", "\n", " class_models[model_name][\"pipeline\"] = model_pipeline\n", " class_models[model_name][\"probs\"] = y_test_probs\n", " class_models[model_name][\"preds\"] = y_test_predict\n", "\n", " class_models[model_name][\"Precision_train\"] = metrics.precision_score(\n", " y_train, y_train_predict\n", " )\n", " class_models[model_name][\"Precision_test\"] = metrics.precision_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Recall_train\"] = metrics.recall_score(\n", " y_train, y_train_predict\n", " )\n", " class_models[model_name][\"Recall_test\"] = metrics.recall_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Accuracy_train\"] = metrics.accuracy_score(\n", " y_train, y_train_predict\n", " )\n", " class_models[model_name][\"Accuracy_test\"] = metrics.accuracy_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"ROC_AUC_test\"] = metrics.roc_auc_score(\n", " y_test, y_test_probs\n", " )\n", " class_models[model_name][\"F1_train\"] = metrics.f1_score(y_train, y_train_predict)\n", " class_models[model_name][\"F1_test\"] = metrics.f1_score(y_test, y_test_predict)\n", " class_models[model_name][\"MCC_test\"] = metrics.matthews_corrcoef(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(\n", " y_test, y_test_predict\n", " )\n", " class_models[model_name][\"Confusion_matrix\"] = metrics.confusion_matrix(\n", " y_test, y_test_predict\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Сводная таблица оценок качества для использованных моделей классификации\n", "\n", "Документация: https://scikit-learn.org/1.5/modules/model_evaluation.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Матрица неточностей" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import ConfusionMatrixDisplay\n", "import matplotlib.pyplot as plt\n", "\n", "_, ax = plt.subplots(int(len(class_models) / 2), 2, figsize=(12, 10), sharex=False, sharey=False)\n", "for index, key in enumerate(class_models.keys()):\n", " c_matrix = class_models[key][\"Confusion_matrix\"]\n", " disp = ConfusionMatrixDisplay(\n", " confusion_matrix=c_matrix, display_labels=[\"no water\", \"water\"]\n", " ).plot(ax=ax.flat[index])\n", " disp.ax_.set_title(key)\n", "\n", "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.1)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Точность, полнота, верность (аккуратность), F-мера" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Precision_trainPrecision_testRecall_trainRecall_testAccuracy_trainAccuracy_testF1_trainF1_test
logistic0.8137250.6764710.6384620.6969700.9961830.9951420.7155170.686567
decision_tree0.9343070.6785710.9846150.5757580.9993640.9946800.9588010.622951
gradient_boosting1.0000000.6129031.0000000.5757581.0000000.9939861.0000000.593750
mlp0.7894740.5862070.5769230.5151520.9956620.9935230.6666670.548387
knn0.9500001.0000000.1461540.0606060.9935220.9928290.2533330.114286
random_forest0.3724930.3333331.0000000.8787880.9873340.9856580.5427970.483333
ridge0.3439150.3009711.0000000.9393940.9856560.9828820.5118110.455882
naive_bayes0.0186190.0069161.0000000.3636360.6037020.5965760.0365580.013575
\n" ], "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", " [\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " \"Accuracy_train\",\n", " \"Accuracy_test\",\n", " \"F1_train\",\n", " \"F1_test\",\n", " ]\n", "]\n", "class_metrics.sort_values(\n", " by=\"Accuracy_test\", ascending=False\n", ").style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " ],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ROC-кривая, каппа Коэна, коэффициент корреляции Мэтьюса" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Accuracy_testF1_testROC_AUC_testCohen_kappa_testMCC_test
logistic0.9951420.6865670.9960730.6841200.684197
ridge0.9828820.4558820.9954160.4495170.526537
mlp0.9935230.5483870.9944200.5451390.546293
gradient_boosting0.9939860.5937500.9941370.5907230.591016
random_forest0.9856580.4833330.9928800.4775500.536289
knn0.9928290.1142860.8449710.1135120.245298
decision_tree0.9946800.6229510.7861800.6202900.622414
naive_bayes0.5965760.0135750.481002-0.001429-0.006747
\n" ], "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "class_metrics = pd.DataFrame.from_dict(class_models, \"index\")[\n", " [\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " \"ROC_AUC_test\",\n", " \"Cohen_kappa_test\",\n", " \"MCC_test\",\n", " ]\n", "]\n", "class_metrics.sort_values(by=\"ROC_AUC_test\", ascending=False).style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\n", " \"ROC_AUC_test\",\n", " \"MCC_test\",\n", " \"Cohen_kappa_test\",\n", " ],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " ],\n", ")" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'logistic'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "best_model = str(class_metrics.sort_values(by=\"MCC_test\", ascending=False).iloc[0].name)\n", "\n", "display(best_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Вывод данных с ошибкой предсказания для оценки" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\ogoro\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\sklearn\\preprocessing\\_encoders.py:242: UserWarning: Found unknown categories in columns [0] during transform. These unknown categories will be encoded as all zeros\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "'Error items count: 21'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datePredictedpricebedroomsbathroomssqft_livingsqft_lotfloorswaterfrontview...gradesqft_abovesqft_basementyr_builtyr_renovatedzipcodelatlongsqft_living15sqft_lot15
id
12103904220150313T0000000425000.032.7536101073861.513...83130480191819629802347.3351-122.362263042126
62406910820140812T00000003200000.043.257000282061.014...1235003500199109807547.5928-122.086491314663
102503908620140916T00000001875000.032.503280291112.013...1132800192509819947.6699-122.416353021074
173280078020150212T00000013065000.053.00415075002.504...113510640190909811947.6303-122.36222504050
212203909420141126T0000000705000.033.001970209782.013...91770200198009807047.3844-122.438228075396
292303924320141113T0000000340000.041.001200118341.013...612000197209807047.4557-122.443167047462
302405901420150325T00000001900000.042.253020114891.513...102110910191619889804047.5395-122.210389011489
322204902420140522T0000001361000.031.00110040461.504...611000192209819847.3440-122.33125507847
342604928420140819T00000002300000.043.254110159292.014...1227201390200109811547.6934-122.271264015929
374160002020140915T0000001540000.032.252100200181.004...81470630194809816647.4544-122.366241017196
376050033620141126T00000012125000.042.753190195132.004...1031900198209803447.6991-122.235275013496
386740017520150224T0000001850000.021.50180041441.004...7900900196209811647.5934-122.39020904173
632900005020150310T0000000641500.011.00100090841.013...710000195009814647.5007-122.38210906536
676270002020141013T00000017700000.068.0012050276002.503...1385703480191019879810247.6298-122.32339408800
727810051520140821T00000001295000.022.502910194492.014...91940970198509817747.7729-122.393254023598
749000004020140718T00000012535000.053.253730106261.004...1037300196309800447.6240-122.221418019110
763120029220140626T0000001669000.021.751950107661.003...61160790195209816647.4504-122.377178011721
763680004120140625T0000000995000.034.504380470442.013...93720660196819909816647.4734-122.365246018512
890750007020150413T00000015350000.055.008000239852.004...1267201280200909800447.6232-122.220460021750
896480089020150109T00000013200000.033.254560133631.004...1127601800199509800447.6205-122.214406013362
920890003720140919T00000016885000.067.759890313742.004...1388601030200109803947.6305-122.240454042730
\n", "

21 rows × 21 columns

\n", "
" ], "text/plain": [ " date Predicted price bedrooms bathrooms \\\n", "id \n", "121039042 20150313T000000 0 425000.0 3 2.75 \n", "624069108 20140812T000000 0 3200000.0 4 3.25 \n", "1025039086 20140916T000000 0 1875000.0 3 2.50 \n", "1732800780 20150212T000000 1 3065000.0 5 3.00 \n", "2122039094 20141126T000000 0 705000.0 3 3.00 \n", "2923039243 20141113T000000 0 340000.0 4 1.00 \n", "3024059014 20150325T000000 0 1900000.0 4 2.25 \n", "3222049024 20140522T000000 1 361000.0 3 1.00 \n", "3426049284 20140819T000000 0 2300000.0 4 3.25 \n", "3741600020 20140915T000000 1 540000.0 3 2.25 \n", "3760500336 20141126T000000 1 2125000.0 4 2.75 \n", "3867400175 20150224T000000 1 850000.0 2 1.50 \n", "6329000050 20150310T000000 0 641500.0 1 1.00 \n", "6762700020 20141013T000000 1 7700000.0 6 8.00 \n", "7278100515 20140821T000000 0 1295000.0 2 2.50 \n", "7490000040 20140718T000000 1 2535000.0 5 3.25 \n", "7631200292 20140626T000000 1 669000.0 2 1.75 \n", "7636800041 20140625T000000 0 995000.0 3 4.50 \n", "8907500070 20150413T000000 1 5350000.0 5 5.00 \n", "8964800890 20150109T000000 1 3200000.0 3 3.25 \n", "9208900037 20140919T000000 1 6885000.0 6 7.75 \n", "\n", " sqft_living sqft_lot floors waterfront view ... grade \\\n", "id ... \n", "121039042 3610 107386 1.5 1 3 ... 8 \n", "624069108 7000 28206 1.0 1 4 ... 12 \n", "1025039086 3280 29111 2.0 1 3 ... 11 \n", "1732800780 4150 7500 2.5 0 4 ... 11 \n", "2122039094 1970 20978 2.0 1 3 ... 9 \n", "2923039243 1200 11834 1.0 1 3 ... 6 \n", "3024059014 3020 11489 1.5 1 3 ... 10 \n", "3222049024 1100 4046 1.5 0 4 ... 6 \n", "3426049284 4110 15929 2.0 1 4 ... 12 \n", "3741600020 2100 20018 1.0 0 4 ... 8 \n", "3760500336 3190 19513 2.0 0 4 ... 10 \n", "3867400175 1800 4144 1.0 0 4 ... 7 \n", "6329000050 1000 9084 1.0 1 3 ... 7 \n", "6762700020 12050 27600 2.5 0 3 ... 13 \n", "7278100515 2910 19449 2.0 1 4 ... 9 \n", "7490000040 3730 10626 1.0 0 4 ... 10 \n", "7631200292 1950 10766 1.0 0 3 ... 6 \n", "7636800041 4380 47044 2.0 1 3 ... 9 \n", "8907500070 8000 23985 2.0 0 4 ... 12 \n", "8964800890 4560 13363 1.0 0 4 ... 11 \n", "9208900037 9890 31374 2.0 0 4 ... 13 \n", "\n", " sqft_above sqft_basement yr_built yr_renovated zipcode \\\n", "id \n", "121039042 3130 480 1918 1962 98023 \n", "624069108 3500 3500 1991 0 98075 \n", "1025039086 3280 0 1925 0 98199 \n", "1732800780 3510 640 1909 0 98119 \n", "2122039094 1770 200 1980 0 98070 \n", "2923039243 1200 0 1972 0 98070 \n", "3024059014 2110 910 1916 1988 98040 \n", "3222049024 1100 0 1922 0 98198 \n", "3426049284 2720 1390 2001 0 98115 \n", "3741600020 1470 630 1948 0 98166 \n", "3760500336 3190 0 1982 0 98034 \n", "3867400175 900 900 1962 0 98116 \n", "6329000050 1000 0 1950 0 98146 \n", "6762700020 8570 3480 1910 1987 98102 \n", "7278100515 1940 970 1985 0 98177 \n", "7490000040 3730 0 1963 0 98004 \n", "7631200292 1160 790 1952 0 98166 \n", "7636800041 3720 660 1968 1990 98166 \n", "8907500070 6720 1280 2009 0 98004 \n", "8964800890 2760 1800 1995 0 98004 \n", "9208900037 8860 1030 2001 0 98039 \n", "\n", " lat long sqft_living15 sqft_lot15 \n", "id \n", "121039042 47.3351 -122.362 2630 42126 \n", "624069108 47.5928 -122.086 4913 14663 \n", "1025039086 47.6699 -122.416 3530 21074 \n", "1732800780 47.6303 -122.362 2250 4050 \n", "2122039094 47.3844 -122.438 2280 75396 \n", "2923039243 47.4557 -122.443 1670 47462 \n", "3024059014 47.5395 -122.210 3890 11489 \n", "3222049024 47.3440 -122.331 2550 7847 \n", "3426049284 47.6934 -122.271 2640 15929 \n", "3741600020 47.4544 -122.366 2410 17196 \n", "3760500336 47.6991 -122.235 2750 13496 \n", "3867400175 47.5934 -122.390 2090 4173 \n", "6329000050 47.5007 -122.382 1090 6536 \n", "6762700020 47.6298 -122.323 3940 8800 \n", "7278100515 47.7729 -122.393 2540 23598 \n", "7490000040 47.6240 -122.221 4180 19110 \n", "7631200292 47.4504 -122.377 1780 11721 \n", "7636800041 47.4734 -122.365 2460 18512 \n", "8907500070 47.6232 -122.220 4600 21750 \n", "8964800890 47.6205 -122.214 4060 13362 \n", "9208900037 47.6305 -122.240 4540 42730 \n", "\n", "[21 rows x 21 columns]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocessing_result = pipeline_end.transform(X_test)\n", "preprocessed_df = pd.DataFrame(\n", " preprocessing_result,\n", " columns=pipeline_end.get_feature_names_out(),\n", ")\n", "\n", "y_pred = class_models[best_model][\"preds\"]\n", "\n", "error_index = y_test[y_test[\"waterfront\"] != y_pred].index.tolist()\n", "display(f\"Error items count: {len(error_index)}\")\n", "\n", "error_predicted = pd.Series(y_pred, index=y_test.index).loc[error_index]\n", "error_df = X_test.loc[error_index].copy()\n", "error_df.insert(loc=1, column=\"Predicted\", value=error_predicted)\n", "error_df.sort_index()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Пример использования обученной модели (конвейера) для предсказания" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
datepricebedroomsbathroomssqft_livingsqft_lotfloorswaterfrontviewconditiongradesqft_abovesqft_basementyr_builtyr_renovatedzipcodelatlongsqft_living15sqft_lot15
62406910820140812T0000003200000.043.257000282061.01441235003500199109807547.5928-122.086491314663
\n", "
" ], "text/plain": [ " date price bedrooms bathrooms sqft_living sqft_lot \\\n", "624069108 20140812T000000 3200000.0 4 3.25 7000 28206 \n", "\n", " floors waterfront view condition grade sqft_above sqft_basement \\\n", "624069108 1.0 1 4 4 12 3500 3500 \n", "\n", " yr_built yr_renovated zipcode lat long sqft_living15 \\\n", "624069108 1991 0 98075 47.5928 -122.086 4913 \n", "\n", " sqft_lot15 \n", "624069108 14663 " ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Region_northHouse_agepricebedroomsbathroomssqft_livingsqft_lotfloorsviewcondition...date_20150506T000000date_20150507T000000date_20150508T000000date_20150509T000000date_20150510T000000date_20150511T000000date_20150512T000000date_20150513T000000date_20150514T000000date_20150515T000000
6240691081.033.07.4942060.68181.4792175.3720720.29821-0.9185094.9227040.909775...0.00.00.00.00.00.00.00.00.00.0
\n", "

1 rows × 384 columns

\n", "
" ], "text/plain": [ " Region_north House_age price bedrooms bathrooms \\\n", "624069108 1.0 33.0 7.494206 0.6818 1.479217 \n", "\n", " sqft_living sqft_lot floors view condition ... \\\n", "624069108 5.372072 0.29821 -0.918509 4.922704 0.909775 ... \n", "\n", " date_20150506T000000 date_20150507T000000 date_20150508T000000 \\\n", "624069108 0.0 0.0 0.0 \n", "\n", " date_20150509T000000 date_20150510T000000 date_20150511T000000 \\\n", "624069108 0.0 0.0 0.0 \n", "\n", " date_20150512T000000 date_20150513T000000 date_20150514T000000 \\\n", "624069108 0.0 0.0 0.0 \n", "\n", " date_20150515T000000 \n", "624069108 0.0 \n", "\n", "[1 rows x 384 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'predicted: 0 (proba: [0.8437713 0.1562287])'" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "'real: 1'" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = class_models[best_model][\"pipeline\"]\n", "\n", "example_id = 624069108\n", "test = pd.DataFrame(X_test.loc[example_id, :]).T\n", "test_preprocessed = pd.DataFrame(preprocessed_df.loc[example_id, :]).T\n", "display(test)\n", "display(test_preprocessed)\n", "result_proba = model.predict_proba(test)[0]\n", "result = model.predict(test)[0]\n", "real = int(y_test.loc[example_id].values[0])\n", "display(f\"predicted: {result} (proba: {result_proba})\")\n", "display(f\"real: {real}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Подбор гиперпараметров методом поиска по сетке\n", "\n", "https://www.kaggle.com/code/sociopath00/random-forest-using-gridsearchcv\n", "\n", "https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "optimized_model_type = \"random_forest\"\n", "\n", "random_forest_model = class_models[optimized_model_type][\"pipeline\"]\n", "\n", "param_grid = {\n", " \"model__n_estimators\": [10, 20, 30, 40, 50, 100, 150, 200, 250, 500],\n", " \"model__max_features\": [\"sqrt\", \"log2\", 2],\n", " \"model__max_depth\": [2, 3, 4, 5, 6, 7, 8, 9 ,10],\n", " \"model__criterion\": [\"gini\", \"entropy\", \"log_loss\"],\n", "}\n", "\n", "gs_optomizer = GridSearchCV(\n", " estimator=random_forest_model, param_grid=param_grid, n_jobs=-1\n", ")\n", "gs_optomizer.fit(X_train, y_train.values.ravel())\n", "gs_optomizer.best_params_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Обучение модели с новыми гиперпараметрами" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "optimized_model = ensemble.RandomForestClassifier(\n", " random_state=random_state,\n", " criterion=\"gini\",\n", " max_depth=7,\n", " max_features=\"sqrt\",\n", " n_estimators=30,\n", ")\n", "\n", "result = {}\n", "\n", "result[\"pipeline\"] = Pipeline([(\"pipeline\", pipeline_end), (\"model\", optimized_model)]).fit(X_train, y_train.values.ravel())\n", "result[\"train_preds\"] = result[\"pipeline\"].predict(X_train)\n", "result[\"probs\"] = result[\"pipeline\"].predict_proba(X_test)[:, 1]\n", "result[\"preds\"] = np.where(result[\"probs\"] > 0.5, 1, 0)\n", "\n", "result[\"Precision_train\"] = metrics.precision_score(y_train, result[\"train_preds\"])\n", "result[\"Precision_test\"] = metrics.precision_score(y_test, result[\"preds\"])\n", "result[\"Recall_train\"] = metrics.recall_score(y_train, result[\"train_preds\"])\n", "result[\"Recall_test\"] = metrics.recall_score(y_test, result[\"preds\"])\n", "result[\"Accuracy_train\"] = metrics.accuracy_score(y_train, result[\"train_preds\"])\n", "result[\"Accuracy_test\"] = metrics.accuracy_score(y_test, result[\"preds\"])\n", "result[\"ROC_AUC_test\"] = metrics.roc_auc_score(y_test, result[\"probs\"])\n", "result[\"F1_train\"] = metrics.f1_score(y_train, result[\"train_preds\"])\n", "result[\"F1_test\"] = metrics.f1_score(y_test, result[\"preds\"])\n", "result[\"MCC_test\"] = metrics.matthews_corrcoef(y_test, result[\"preds\"])\n", "result[\"Cohen_kappa_test\"] = metrics.cohen_kappa_score(y_test, result[\"preds\"])\n", "result[\"Confusion_matrix\"] = metrics.confusion_matrix(y_test, result[\"preds\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Формирование данных для оценки старой и новой версии модели" ] }, { "cell_type": "code", "execution_count": 98, "metadata": {}, "outputs": [], "source": [ "optimized_metrics = pd.DataFrame(columns=list(result.keys()))\n", "optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n", " data=class_models[optimized_model_type]\n", ")\n", "optimized_metrics.loc[len(optimized_metrics)] = pd.Series(\n", " data=result\n", ")\n", "optimized_metrics.insert(loc=0, column=\"Name\", value=[\"Old\", \"New\"])\n", "optimized_metrics = optimized_metrics.set_index(\"Name\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Оценка параметров старой и новой модели" ] }, { "cell_type": "code", "execution_count": 99, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Precision_trainPrecision_testRecall_trainRecall_testAccuracy_trainAccuracy_testF1_trainF1_test
Name        
Old0.8943400.7941180.8681320.7826090.9101120.8379890.8810410.788321
New0.8672200.8225810.7655680.7391300.8651690.8379890.8132300.778626
\n" ], "text/plain": [ "" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "optimized_metrics[\n", " [\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " \"Accuracy_train\",\n", " \"Accuracy_test\",\n", " \"F1_train\",\n", " \"F1_test\",\n", " ]\n", "].style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\"Accuracy_train\", \"Accuracy_test\", \"F1_train\", \"F1_test\"],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Precision_train\",\n", " \"Precision_test\",\n", " \"Recall_train\",\n", " \"Recall_test\",\n", " ],\n", ")" ] }, { "cell_type": "code", "execution_count": 100, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
 Accuracy_testF1_testROC_AUC_testCohen_kappa_testMCC_test
Name     
Old0.8379890.7883210.8588930.6571110.657157
New0.8379890.7786260.8597500.6514470.653765
\n" ], "text/plain": [ "" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "optimized_metrics[\n", " [\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " \"ROC_AUC_test\",\n", " \"Cohen_kappa_test\",\n", " \"MCC_test\",\n", " ]\n", "].style.background_gradient(\n", " cmap=\"plasma\",\n", " low=0.3,\n", " high=1,\n", " subset=[\n", " \"ROC_AUC_test\",\n", " \"MCC_test\",\n", " \"Cohen_kappa_test\",\n", " ],\n", ").background_gradient(\n", " cmap=\"viridis\",\n", " low=1,\n", " high=0.3,\n", " subset=[\n", " \"Accuracy_test\",\n", " \"F1_test\",\n", " ],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_, ax = plt.subplots(1, 2, figsize=(10, 4), sharex=False, sharey=False\n", ")\n", "\n", "for index in range(0, len(optimized_metrics)):\n", " c_matrix = optimized_metrics.iloc[index][\"Confusion_matrix\"]\n", " disp = ConfusionMatrixDisplay(\n", " confusion_matrix=c_matrix, display_labels=[\"no water\", \"water\"]\n", " ).plot(ax=ax.flat[index])\n", "\n", "plt.subplots_adjust(top=1, bottom=0, hspace=0.4, wspace=0.3)\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }