This commit is contained in:
gg12 darfren 2024-11-20 19:19:56 +04:00
parent 0e9d03446d
commit f49b209552
7 changed files with 2155 additions and 337 deletions

View File

@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -643,7 +643,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -717,7 +717,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -908,7 +908,7 @@
"[2217 rows x 8 columns]"
]
},
"execution_count": 24,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -950,7 +950,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -993,7 +993,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -1082,7 +1082,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@ -1121,297 +1121,297 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_98a81_row0_col0, #T_98a81_row0_col1, #T_98a81_row0_col2, #T_98a81_row0_col3, #T_98a81_row2_col2, #T_98a81_row3_col2, #T_98a81_row3_col3 {\n",
"#T_17992_row0_col0, #T_17992_row0_col1, #T_17992_row0_col2, #T_17992_row0_col3, #T_17992_row2_col2, #T_17992_row3_col2, #T_17992_row3_col3 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row0_col4, #T_98a81_row0_col5, #T_98a81_row0_col6, #T_98a81_row0_col7 {\n",
"#T_17992_row0_col4, #T_17992_row0_col5, #T_17992_row0_col6, #T_17992_row0_col7 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col0 {\n",
"#T_17992_row1_col0 {\n",
" background-color: #86d549;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col1 {\n",
"#T_17992_row1_col1 {\n",
" background-color: #7ad151;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col2 {\n",
"#T_17992_row1_col2 {\n",
" background-color: #a5db36;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col3 {\n",
"#T_17992_row1_col3 {\n",
" background-color: #98d83e;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row1_col4 {\n",
"#T_17992_row1_col4 {\n",
" background-color: #d45270;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col5, #T_98a81_row2_col5 {\n",
"#T_17992_row1_col5, #T_17992_row2_col5 {\n",
" background-color: #d04d73;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col6 {\n",
"#T_17992_row1_col6 {\n",
" background-color: #d5546e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row1_col7, #T_98a81_row2_col7 {\n",
"#T_17992_row1_col7, #T_17992_row2_col7 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row2_col0 {\n",
"#T_17992_row2_col0 {\n",
" background-color: #a2da37;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row2_col1 {\n",
"#T_17992_row2_col1 {\n",
" background-color: #6ccd5a;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row2_col3 {\n",
"#T_17992_row2_col3 {\n",
" background-color: #9dd93b;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row2_col4, #T_98a81_row2_col6 {\n",
"#T_17992_row2_col4, #T_17992_row2_col6 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col0 {\n",
"#T_17992_row3_col0 {\n",
" background-color: #1e9b8a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col1 {\n",
"#T_17992_row3_col1 {\n",
" background-color: #23888e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col4 {\n",
"#T_17992_row3_col4 {\n",
" background-color: #b42e8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col5, #T_98a81_row5_col6 {\n",
"#T_17992_row3_col5, #T_17992_row5_col6 {\n",
" background-color: #b52f8c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col6 {\n",
"#T_17992_row3_col6 {\n",
" background-color: #bf3984;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row3_col7, #T_98a81_row4_col6 {\n",
"#T_17992_row3_col7, #T_17992_row4_col6 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col0 {\n",
"#T_17992_row4_col0 {\n",
" background-color: #35b779;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col1 {\n",
"#T_17992_row4_col1 {\n",
" background-color: #1f9f88;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col2 {\n",
"#T_17992_row4_col2 {\n",
" background-color: #81d34d;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row4_col3 {\n",
"#T_17992_row4_col3 {\n",
" background-color: #69cd5b;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row4_col4 {\n",
"#T_17992_row4_col4 {\n",
" background-color: #b83289;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col5, #T_98a81_row5_col7 {\n",
"#T_17992_row4_col5, #T_17992_row5_col7 {\n",
" background-color: #aa2395;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row4_col7 {\n",
"#T_17992_row4_col7 {\n",
" background-color: #b6308b;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col0 {\n",
"#T_17992_row5_col0 {\n",
" background-color: #21a585;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col1 {\n",
"#T_17992_row5_col1 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col2 {\n",
"#T_17992_row5_col2 {\n",
" background-color: #73d056;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row5_col3 {\n",
"#T_17992_row5_col3 {\n",
" background-color: #54c568;\n",
" color: #000000;\n",
"}\n",
"#T_98a81_row5_col4 {\n",
"#T_17992_row5_col4 {\n",
" background-color: #a82296;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row5_col5 {\n",
"#T_17992_row5_col5 {\n",
" background-color: #99159f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col0, #T_98a81_row6_col1, #T_98a81_row7_col2, #T_98a81_row7_col3 {\n",
"#T_17992_row6_col0, #T_17992_row6_col1, #T_17992_row7_col2, #T_17992_row7_col3 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col2 {\n",
"#T_17992_row6_col2 {\n",
" background-color: #22a785;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col3 {\n",
"#T_17992_row6_col3 {\n",
" background-color: #23a983;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col4 {\n",
"#T_17992_row6_col4 {\n",
" background-color: #5c01a6;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col5 {\n",
"#T_17992_row6_col5 {\n",
" background-color: #6c00a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col6 {\n",
"#T_17992_row6_col6 {\n",
" background-color: #7501a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row6_col7 {\n",
"#T_17992_row6_col7 {\n",
" background-color: #8104a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row7_col0 {\n",
"#T_17992_row7_col0 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row7_col1 {\n",
"#T_17992_row7_col1 {\n",
" background-color: #34b679;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_98a81_row7_col4, #T_98a81_row7_col5, #T_98a81_row7_col6, #T_98a81_row7_col7 {\n",
"#T_17992_row7_col4, #T_17992_row7_col5, #T_17992_row7_col6, #T_17992_row7_col7 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_98a81\">\n",
"<table id=\"T_17992\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_98a81_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_98a81_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_98a81_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_98a81_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_98a81_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_98a81_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_98a81_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_98a81_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" <th id=\"T_17992_level0_col0\" class=\"col_heading level0 col0\" >Precision_train</th>\n",
" <th id=\"T_17992_level0_col1\" class=\"col_heading level0 col1\" >Precision_test</th>\n",
" <th id=\"T_17992_level0_col2\" class=\"col_heading level0 col2\" >Recall_train</th>\n",
" <th id=\"T_17992_level0_col3\" class=\"col_heading level0 col3\" >Recall_test</th>\n",
" <th id=\"T_17992_level0_col4\" class=\"col_heading level0 col4\" >Accuracy_train</th>\n",
" <th id=\"T_17992_level0_col5\" class=\"col_heading level0 col5\" >Accuracy_test</th>\n",
" <th id=\"T_17992_level0_col6\" class=\"col_heading level0 col6\" >F1_train</th>\n",
" <th id=\"T_17992_level0_col7\" class=\"col_heading level0 col7\" >F1_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row0\" class=\"row_heading level0 row0\" >gradient_boosting</th>\n",
" <td id=\"T_98a81_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col1\" class=\"data row0 col1\" >0.982609</td>\n",
" <td id=\"T_98a81_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col5\" class=\"data row0 col5\" >0.996396</td>\n",
" <td id=\"T_98a81_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_98a81_row0_col7\" class=\"data row0 col7\" >0.991228</td>\n",
" <th id=\"T_17992_level0_row0\" class=\"row_heading level0 row0\" >gradient_boosting</th>\n",
" <td id=\"T_17992_row0_col0\" class=\"data row0 col0\" >1.000000</td>\n",
" <td id=\"T_17992_row0_col1\" class=\"data row0 col1\" >0.982609</td>\n",
" <td id=\"T_17992_row0_col2\" class=\"data row0 col2\" >1.000000</td>\n",
" <td id=\"T_17992_row0_col3\" class=\"data row0 col3\" >1.000000</td>\n",
" <td id=\"T_17992_row0_col4\" class=\"data row0 col4\" >1.000000</td>\n",
" <td id=\"T_17992_row0_col5\" class=\"data row0 col5\" >0.996396</td>\n",
" <td id=\"T_17992_row0_col6\" class=\"data row0 col6\" >1.000000</td>\n",
" <td id=\"T_17992_row0_col7\" class=\"data row0 col7\" >0.991228</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_98a81_row1_col0\" class=\"data row1 col0\" >0.976035</td>\n",
" <td id=\"T_98a81_row1_col1\" class=\"data row1 col1\" >0.956522</td>\n",
" <td id=\"T_98a81_row1_col2\" class=\"data row1 col2\" >0.993348</td>\n",
" <td id=\"T_98a81_row1_col3\" class=\"data row1 col3\" >0.973451</td>\n",
" <td id=\"T_98a81_row1_col4\" class=\"data row1 col4\" >0.993685</td>\n",
" <td id=\"T_98a81_row1_col5\" class=\"data row1 col5\" >0.985586</td>\n",
" <td id=\"T_98a81_row1_col6\" class=\"data row1 col6\" >0.984615</td>\n",
" <td id=\"T_98a81_row1_col7\" class=\"data row1 col7\" >0.964912</td>\n",
" <th id=\"T_17992_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_17992_row1_col0\" class=\"data row1 col0\" >0.976035</td>\n",
" <td id=\"T_17992_row1_col1\" class=\"data row1 col1\" >0.956522</td>\n",
" <td id=\"T_17992_row1_col2\" class=\"data row1 col2\" >0.993348</td>\n",
" <td id=\"T_17992_row1_col3\" class=\"data row1 col3\" >0.973451</td>\n",
" <td id=\"T_17992_row1_col4\" class=\"data row1 col4\" >0.993685</td>\n",
" <td id=\"T_17992_row1_col5\" class=\"data row1 col5\" >0.985586</td>\n",
" <td id=\"T_17992_row1_col6\" class=\"data row1 col6\" >0.984615</td>\n",
" <td id=\"T_17992_row1_col7\" class=\"data row1 col7\" >0.964912</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_98a81_row2_col0\" class=\"data row2 col0\" >0.995585</td>\n",
" <td id=\"T_98a81_row2_col1\" class=\"data row2 col1\" >0.948718</td>\n",
" <td id=\"T_98a81_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_98a81_row2_col3\" class=\"data row2 col3\" >0.982301</td>\n",
" <td id=\"T_98a81_row2_col4\" class=\"data row2 col4\" >0.999098</td>\n",
" <td id=\"T_98a81_row2_col5\" class=\"data row2 col5\" >0.985586</td>\n",
" <td id=\"T_98a81_row2_col6\" class=\"data row2 col6\" >0.997788</td>\n",
" <td id=\"T_98a81_row2_col7\" class=\"data row2 col7\" >0.965217</td>\n",
" <th id=\"T_17992_level0_row2\" class=\"row_heading level0 row2\" >random_forest</th>\n",
" <td id=\"T_17992_row2_col0\" class=\"data row2 col0\" >0.995585</td>\n",
" <td id=\"T_17992_row2_col1\" class=\"data row2 col1\" >0.948718</td>\n",
" <td id=\"T_17992_row2_col2\" class=\"data row2 col2\" >1.000000</td>\n",
" <td id=\"T_17992_row2_col3\" class=\"data row2 col3\" >0.982301</td>\n",
" <td id=\"T_17992_row2_col4\" class=\"data row2 col4\" >0.999098</td>\n",
" <td id=\"T_17992_row2_col5\" class=\"data row2 col5\" >0.985586</td>\n",
" <td id=\"T_17992_row2_col6\" class=\"data row2 col6\" >0.997788</td>\n",
" <td id=\"T_17992_row2_col7\" class=\"data row2 col7\" >0.965217</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
" <td id=\"T_98a81_row3_col0\" class=\"data row3 col0\" >0.846154</td>\n",
" <td id=\"T_98a81_row3_col1\" class=\"data row3 col1\" >0.837037</td>\n",
" <td id=\"T_98a81_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_98a81_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_98a81_row3_col4\" class=\"data row3 col4\" >0.963013</td>\n",
" <td id=\"T_98a81_row3_col5\" class=\"data row3 col5\" >0.960360</td>\n",
" <td id=\"T_98a81_row3_col6\" class=\"data row3 col6\" >0.916667</td>\n",
" <td id=\"T_98a81_row3_col7\" class=\"data row3 col7\" >0.911290</td>\n",
" <th id=\"T_17992_level0_row3\" class=\"row_heading level0 row3\" >ridge</th>\n",
" <td id=\"T_17992_row3_col0\" class=\"data row3 col0\" >0.846154</td>\n",
" <td id=\"T_17992_row3_col1\" class=\"data row3 col1\" >0.837037</td>\n",
" <td id=\"T_17992_row3_col2\" class=\"data row3 col2\" >1.000000</td>\n",
" <td id=\"T_17992_row3_col3\" class=\"data row3 col3\" >1.000000</td>\n",
" <td id=\"T_17992_row3_col4\" class=\"data row3 col4\" >0.963013</td>\n",
" <td id=\"T_17992_row3_col5\" class=\"data row3 col5\" >0.960360</td>\n",
" <td id=\"T_17992_row3_col6\" class=\"data row3 col6\" >0.916667</td>\n",
" <td id=\"T_17992_row3_col7\" class=\"data row3 col7\" >0.911290</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_98a81_row4_col0\" class=\"data row4 col0\" >0.903846</td>\n",
" <td id=\"T_98a81_row4_col1\" class=\"data row4 col1\" >0.870690</td>\n",
" <td id=\"T_98a81_row4_col2\" class=\"data row4 col2\" >0.937916</td>\n",
" <td id=\"T_98a81_row4_col3\" class=\"data row4 col3\" >0.893805</td>\n",
" <td id=\"T_98a81_row4_col4\" class=\"data row4 col4\" >0.967073</td>\n",
" <td id=\"T_98a81_row4_col5\" class=\"data row4 col5\" >0.951351</td>\n",
" <td id=\"T_98a81_row4_col6\" class=\"data row4 col6\" >0.920566</td>\n",
" <td id=\"T_98a81_row4_col7\" class=\"data row4 col7\" >0.882096</td>\n",
" <th id=\"T_17992_level0_row4\" class=\"row_heading level0 row4\" >knn</th>\n",
" <td id=\"T_17992_row4_col0\" class=\"data row4 col0\" >0.903846</td>\n",
" <td id=\"T_17992_row4_col1\" class=\"data row4 col1\" >0.870690</td>\n",
" <td id=\"T_17992_row4_col2\" class=\"data row4 col2\" >0.937916</td>\n",
" <td id=\"T_17992_row4_col3\" class=\"data row4 col3\" >0.893805</td>\n",
" <td id=\"T_17992_row4_col4\" class=\"data row4 col4\" >0.967073</td>\n",
" <td id=\"T_17992_row4_col5\" class=\"data row4 col5\" >0.951351</td>\n",
" <td id=\"T_17992_row4_col6\" class=\"data row4 col6\" >0.920566</td>\n",
" <td id=\"T_17992_row4_col7\" class=\"data row4 col7\" >0.882096</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_98a81_row5_col0\" class=\"data row5 col0\" >0.867368</td>\n",
" <td id=\"T_98a81_row5_col1\" class=\"data row5 col1\" >0.849558</td>\n",
" <td id=\"T_98a81_row5_col2\" class=\"data row5 col2\" >0.913525</td>\n",
" <td id=\"T_98a81_row5_col3\" class=\"data row5 col3\" >0.849558</td>\n",
" <td id=\"T_98a81_row5_col4\" class=\"data row5 col4\" >0.953992</td>\n",
" <td id=\"T_98a81_row5_col5\" class=\"data row5 col5\" >0.938739</td>\n",
" <td id=\"T_98a81_row5_col6\" class=\"data row5 col6\" >0.889849</td>\n",
" <td id=\"T_98a81_row5_col7\" class=\"data row5 col7\" >0.849558</td>\n",
" <th id=\"T_17992_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_17992_row5_col0\" class=\"data row5 col0\" >0.867368</td>\n",
" <td id=\"T_17992_row5_col1\" class=\"data row5 col1\" >0.849558</td>\n",
" <td id=\"T_17992_row5_col2\" class=\"data row5 col2\" >0.913525</td>\n",
" <td id=\"T_17992_row5_col3\" class=\"data row5 col3\" >0.849558</td>\n",
" <td id=\"T_17992_row5_col4\" class=\"data row5 col4\" >0.953992</td>\n",
" <td id=\"T_17992_row5_col5\" class=\"data row5 col5\" >0.938739</td>\n",
" <td id=\"T_17992_row5_col6\" class=\"data row5 col6\" >0.889849</td>\n",
" <td id=\"T_17992_row5_col7\" class=\"data row5 col7\" >0.849558</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_98a81_row6_col0\" class=\"data row6 col0\" >0.794045</td>\n",
" <td id=\"T_98a81_row6_col1\" class=\"data row6 col1\" >0.824742</td>\n",
" <td id=\"T_98a81_row6_col2\" class=\"data row6 col2\" >0.709534</td>\n",
" <td id=\"T_98a81_row6_col3\" class=\"data row6 col3\" >0.707965</td>\n",
" <td id=\"T_98a81_row6_col4\" class=\"data row6 col4\" >0.903473</td>\n",
" <td id=\"T_98a81_row6_col5\" class=\"data row6 col5\" >0.909910</td>\n",
" <td id=\"T_98a81_row6_col6\" class=\"data row6 col6\" >0.749415</td>\n",
" <td id=\"T_98a81_row6_col7\" class=\"data row6 col7\" >0.761905</td>\n",
" <th id=\"T_17992_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_17992_row6_col0\" class=\"data row6 col0\" >0.794045</td>\n",
" <td id=\"T_17992_row6_col1\" class=\"data row6 col1\" >0.824742</td>\n",
" <td id=\"T_17992_row6_col2\" class=\"data row6 col2\" >0.709534</td>\n",
" <td id=\"T_17992_row6_col3\" class=\"data row6 col3\" >0.707965</td>\n",
" <td id=\"T_17992_row6_col4\" class=\"data row6 col4\" >0.903473</td>\n",
" <td id=\"T_17992_row6_col5\" class=\"data row6 col5\" >0.909910</td>\n",
" <td id=\"T_17992_row6_col6\" class=\"data row6 col6\" >0.749415</td>\n",
" <td id=\"T_17992_row6_col7\" class=\"data row6 col7\" >0.761905</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_98a81_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_98a81_row7_col0\" class=\"data row7 col0\" >0.910112</td>\n",
" <td id=\"T_98a81_row7_col1\" class=\"data row7 col1\" >0.907692</td>\n",
" <td id=\"T_98a81_row7_col2\" class=\"data row7 col2\" >0.538803</td>\n",
" <td id=\"T_98a81_row7_col3\" class=\"data row7 col3\" >0.522124</td>\n",
" <td id=\"T_98a81_row7_col4\" class=\"data row7 col4\" >0.895354</td>\n",
" <td id=\"T_98a81_row7_col5\" class=\"data row7 col5\" >0.891892</td>\n",
" <td id=\"T_98a81_row7_col6\" class=\"data row7 col6\" >0.676880</td>\n",
" <td id=\"T_98a81_row7_col7\" class=\"data row7 col7\" >0.662921</td>\n",
" <th id=\"T_17992_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_17992_row7_col0\" class=\"data row7 col0\" >0.910112</td>\n",
" <td id=\"T_17992_row7_col1\" class=\"data row7 col1\" >0.907692</td>\n",
" <td id=\"T_17992_row7_col2\" class=\"data row7 col2\" >0.538803</td>\n",
" <td id=\"T_17992_row7_col3\" class=\"data row7 col3\" >0.522124</td>\n",
" <td id=\"T_17992_row7_col4\" class=\"data row7 col4\" >0.895354</td>\n",
" <td id=\"T_17992_row7_col5\" class=\"data row7 col5\" >0.891892</td>\n",
" <td id=\"T_17992_row7_col6\" class=\"data row7 col6\" >0.676880</td>\n",
" <td id=\"T_17992_row7_col7\" class=\"data row7 col7\" >0.662921</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x2464b71fd10>"
"<pandas.io.formats.style.Styler at 0x21de793ddc0>"
]
},
"execution_count": 30,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -1458,214 +1458,214 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_fd909_row0_col0, #T_fd909_row2_col0 {\n",
"#T_ea44f_row0_col0, #T_ea44f_row2_col0 {\n",
" background-color: #8bd646;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row0_col1, #T_fd909_row2_col1 {\n",
"#T_ea44f_row0_col1, #T_ea44f_row2_col1 {\n",
" background-color: #90d743;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row0_col2, #T_fd909_row1_col3, #T_fd909_row1_col4 {\n",
"#T_ea44f_row0_col2, #T_ea44f_row1_col3, #T_ea44f_row1_col4 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row0_col3, #T_fd909_row2_col3 {\n",
"#T_ea44f_row0_col3, #T_ea44f_row2_col3 {\n",
" background-color: #d24f71;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row0_col4, #T_fd909_row2_col4 {\n",
"#T_ea44f_row0_col4, #T_ea44f_row2_col4 {\n",
" background-color: #d14e72;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row1_col0, #T_fd909_row1_col1 {\n",
"#T_ea44f_row1_col0, #T_ea44f_row1_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row1_col2 {\n",
"#T_ea44f_row1_col2 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row2_col2 {\n",
"#T_ea44f_row2_col2 {\n",
" background-color: #cf4c74;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col0, #T_fd909_row5_col1 {\n",
"#T_ea44f_row3_col0, #T_ea44f_row5_col1 {\n",
" background-color: #3bbb75;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col1 {\n",
"#T_ea44f_row3_col1 {\n",
" background-color: #50c46a;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row3_col2 {\n",
"#T_ea44f_row3_col2 {\n",
" background-color: #bc3587;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col3 {\n",
"#T_ea44f_row3_col3 {\n",
" background-color: #b32c8e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row3_col4 {\n",
"#T_ea44f_row3_col4 {\n",
" background-color: #b02991;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row4_col0 {\n",
"#T_ea44f_row4_col0 {\n",
" background-color: #4ec36b;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row4_col1 {\n",
"#T_ea44f_row4_col1 {\n",
" background-color: #65cb5e;\n",
" color: #000000;\n",
"}\n",
"#T_fd909_row4_col2 {\n",
"#T_ea44f_row4_col2 {\n",
" background-color: #99159f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row4_col3 {\n",
"#T_ea44f_row4_col3 {\n",
" background-color: #be3885;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row4_col4 {\n",
"#T_ea44f_row4_col4 {\n",
" background-color: #bd3786;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col0 {\n",
"#T_ea44f_row5_col0 {\n",
" background-color: #29af7f;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col2 {\n",
"#T_ea44f_row5_col2 {\n",
" background-color: #9613a1;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col3 {\n",
"#T_ea44f_row5_col3 {\n",
" background-color: #a62098;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row5_col4 {\n",
"#T_ea44f_row5_col4 {\n",
" background-color: #a01a9c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col0 {\n",
"#T_ea44f_row6_col0 {\n",
" background-color: #20928c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col1 {\n",
"#T_ea44f_row6_col1 {\n",
" background-color: #1fa088;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col2 {\n",
"#T_ea44f_row6_col2 {\n",
" background-color: #7401a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col3 {\n",
"#T_ea44f_row6_col3 {\n",
" background-color: #7d03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row6_col4 {\n",
"#T_ea44f_row6_col4 {\n",
" background-color: #7201a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row7_col0, #T_fd909_row7_col1 {\n",
"#T_ea44f_row7_col0, #T_ea44f_row7_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_fd909_row7_col2, #T_fd909_row7_col3, #T_fd909_row7_col4 {\n",
"#T_ea44f_row7_col2, #T_ea44f_row7_col3, #T_ea44f_row7_col4 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"</style>\n",
"<table id=\"T_fd909\">\n",
"<table id=\"T_ea44f\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_fd909_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_fd909_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_fd909_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_fd909_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_fd909_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" <th id=\"T_ea44f_level0_col0\" class=\"col_heading level0 col0\" >Accuracy_test</th>\n",
" <th id=\"T_ea44f_level0_col1\" class=\"col_heading level0 col1\" >F1_test</th>\n",
" <th id=\"T_ea44f_level0_col2\" class=\"col_heading level0 col2\" >ROC_AUC_test</th>\n",
" <th id=\"T_ea44f_level0_col3\" class=\"col_heading level0 col3\" >Cohen_kappa_test</th>\n",
" <th id=\"T_ea44f_level0_col4\" class=\"col_heading level0 col4\" >MCC_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_fd909_row0_col0\" class=\"data row0 col0\" >0.985586</td>\n",
" <td id=\"T_fd909_row0_col1\" class=\"data row0 col1\" >0.965217</td>\n",
" <td id=\"T_fd909_row0_col2\" class=\"data row0 col2\" >0.999039</td>\n",
" <td id=\"T_fd909_row0_col3\" class=\"data row0 col3\" >0.956130</td>\n",
" <td id=\"T_fd909_row0_col4\" class=\"data row0 col4\" >0.956360</td>\n",
" <th id=\"T_ea44f_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_ea44f_row0_col0\" class=\"data row0 col0\" >0.985586</td>\n",
" <td id=\"T_ea44f_row0_col1\" class=\"data row0 col1\" >0.965217</td>\n",
" <td id=\"T_ea44f_row0_col2\" class=\"data row0 col2\" >0.999039</td>\n",
" <td id=\"T_ea44f_row0_col3\" class=\"data row0 col3\" >0.956130</td>\n",
" <td id=\"T_ea44f_row0_col4\" class=\"data row0 col4\" >0.956360</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_fd909_row1_col0\" class=\"data row1 col0\" >0.996396</td>\n",
" <td id=\"T_fd909_row1_col1\" class=\"data row1 col1\" >0.991228</td>\n",
" <td id=\"T_fd909_row1_col2\" class=\"data row1 col2\" >0.998118</td>\n",
" <td id=\"T_fd909_row1_col3\" class=\"data row1 col3\" >0.988961</td>\n",
" <td id=\"T_fd909_row1_col4\" class=\"data row1 col4\" >0.989021</td>\n",
" <th id=\"T_ea44f_level0_row1\" class=\"row_heading level0 row1\" >gradient_boosting</th>\n",
" <td id=\"T_ea44f_row1_col0\" class=\"data row1 col0\" >0.996396</td>\n",
" <td id=\"T_ea44f_row1_col1\" class=\"data row1 col1\" >0.991228</td>\n",
" <td id=\"T_ea44f_row1_col2\" class=\"data row1 col2\" >0.998118</td>\n",
" <td id=\"T_ea44f_row1_col3\" class=\"data row1 col3\" >0.988961</td>\n",
" <td id=\"T_ea44f_row1_col4\" class=\"data row1 col4\" >0.989021</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_fd909_row2_col0\" class=\"data row2 col0\" >0.985586</td>\n",
" <td id=\"T_fd909_row2_col1\" class=\"data row2 col1\" >0.964912</td>\n",
" <td id=\"T_fd909_row2_col2\" class=\"data row2 col2\" >0.995745</td>\n",
" <td id=\"T_fd909_row2_col3\" class=\"data row2 col3\" >0.955843</td>\n",
" <td id=\"T_fd909_row2_col4\" class=\"data row2 col4\" >0.955901</td>\n",
" <th id=\"T_ea44f_level0_row2\" class=\"row_heading level0 row2\" >decision_tree</th>\n",
" <td id=\"T_ea44f_row2_col0\" class=\"data row2 col0\" >0.985586</td>\n",
" <td id=\"T_ea44f_row2_col1\" class=\"data row2 col1\" >0.964912</td>\n",
" <td id=\"T_ea44f_row2_col2\" class=\"data row2 col2\" >0.995745</td>\n",
" <td id=\"T_ea44f_row2_col3\" class=\"data row2 col3\" >0.955843</td>\n",
" <td id=\"T_ea44f_row2_col4\" class=\"data row2 col4\" >0.955901</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_fd909_row3_col0\" class=\"data row3 col0\" >0.951351</td>\n",
" <td id=\"T_fd909_row3_col1\" class=\"data row3 col1\" >0.882096</td>\n",
" <td id=\"T_fd909_row3_col2\" class=\"data row3 col2\" >0.990049</td>\n",
" <td id=\"T_fd909_row3_col3\" class=\"data row3 col3\" >0.851456</td>\n",
" <td id=\"T_fd909_row3_col4\" class=\"data row3 col4\" >0.851572</td>\n",
" <th id=\"T_ea44f_level0_row3\" class=\"row_heading level0 row3\" >knn</th>\n",
" <td id=\"T_ea44f_row3_col0\" class=\"data row3 col0\" >0.951351</td>\n",
" <td id=\"T_ea44f_row3_col1\" class=\"data row3 col1\" >0.882096</td>\n",
" <td id=\"T_ea44f_row3_col2\" class=\"data row3 col2\" >0.990049</td>\n",
" <td id=\"T_ea44f_row3_col3\" class=\"data row3 col3\" >0.851456</td>\n",
" <td id=\"T_ea44f_row3_col4\" class=\"data row3 col4\" >0.851572</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_fd909_row4_col0\" class=\"data row4 col0\" >0.960360</td>\n",
" <td id=\"T_fd909_row4_col1\" class=\"data row4 col1\" >0.911290</td>\n",
" <td id=\"T_fd909_row4_col2\" class=\"data row4 col2\" >0.982001</td>\n",
" <td id=\"T_fd909_row4_col3\" class=\"data row4 col3\" >0.886026</td>\n",
" <td id=\"T_fd909_row4_col4\" class=\"data row4 col4\" >0.891838</td>\n",
" <th id=\"T_ea44f_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_ea44f_row4_col0\" class=\"data row4 col0\" >0.960360</td>\n",
" <td id=\"T_ea44f_row4_col1\" class=\"data row4 col1\" >0.911290</td>\n",
" <td id=\"T_ea44f_row4_col2\" class=\"data row4 col2\" >0.982001</td>\n",
" <td id=\"T_ea44f_row4_col3\" class=\"data row4 col3\" >0.886026</td>\n",
" <td id=\"T_ea44f_row4_col4\" class=\"data row4 col4\" >0.891838</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_fd909_row5_col0\" class=\"data row5 col0\" >0.938739</td>\n",
" <td id=\"T_fd909_row5_col1\" class=\"data row5 col1\" >0.849558</td>\n",
" <td id=\"T_fd909_row5_col2\" class=\"data row5 col2\" >0.981520</td>\n",
" <td id=\"T_fd909_row5_col3\" class=\"data row5 col3\" >0.811096</td>\n",
" <td id=\"T_fd909_row5_col4\" class=\"data row5 col4\" >0.811096</td>\n",
" <th id=\"T_ea44f_level0_row5\" class=\"row_heading level0 row5\" >logistic</th>\n",
" <td id=\"T_ea44f_row5_col0\" class=\"data row5 col0\" >0.938739</td>\n",
" <td id=\"T_ea44f_row5_col1\" class=\"data row5 col1\" >0.849558</td>\n",
" <td id=\"T_ea44f_row5_col2\" class=\"data row5 col2\" >0.981520</td>\n",
" <td id=\"T_ea44f_row5_col3\" class=\"data row5 col3\" >0.811096</td>\n",
" <td id=\"T_ea44f_row5_col4\" class=\"data row5 col4\" >0.811096</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_fd909_row6_col0\" class=\"data row6 col0\" >0.909910</td>\n",
" <td id=\"T_fd909_row6_col1\" class=\"data row6 col1\" >0.761905</td>\n",
" <td id=\"T_fd909_row6_col2\" class=\"data row6 col2\" >0.974813</td>\n",
" <td id=\"T_fd909_row6_col3\" class=\"data row6 col3\" >0.706746</td>\n",
" <td id=\"T_fd909_row6_col4\" class=\"data row6 col4\" >0.709879</td>\n",
" <th id=\"T_ea44f_level0_row6\" class=\"row_heading level0 row6\" >naive_bayes</th>\n",
" <td id=\"T_ea44f_row6_col0\" class=\"data row6 col0\" >0.909910</td>\n",
" <td id=\"T_ea44f_row6_col1\" class=\"data row6 col1\" >0.761905</td>\n",
" <td id=\"T_ea44f_row6_col2\" class=\"data row6 col2\" >0.974813</td>\n",
" <td id=\"T_ea44f_row6_col3\" class=\"data row6 col3\" >0.706746</td>\n",
" <td id=\"T_ea44f_row6_col4\" class=\"data row6 col4\" >0.709879</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_fd909_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_fd909_row7_col0\" class=\"data row7 col0\" >0.891892</td>\n",
" <td id=\"T_fd909_row7_col1\" class=\"data row7 col1\" >0.662921</td>\n",
" <td id=\"T_fd909_row7_col2\" class=\"data row7 col2\" >0.968086</td>\n",
" <td id=\"T_fd909_row7_col3\" class=\"data row7 col3\" >0.604043</td>\n",
" <td id=\"T_fd909_row7_col4\" class=\"data row7 col4\" >0.636838</td>\n",
" <th id=\"T_ea44f_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_ea44f_row7_col0\" class=\"data row7 col0\" >0.891892</td>\n",
" <td id=\"T_ea44f_row7_col1\" class=\"data row7 col1\" >0.662921</td>\n",
" <td id=\"T_ea44f_row7_col2\" class=\"data row7 col2\" >0.968086</td>\n",
" <td id=\"T_ea44f_row7_col3\" class=\"data row7 col3\" >0.604043</td>\n",
" <td id=\"T_ea44f_row7_col4\" class=\"data row7 col4\" >0.636838</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x24665055160>"
"<pandas.io.formats.style.Styler at 0x21de8b38e60>"
]
},
"execution_count": 31,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -1702,7 +1702,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@ -1730,7 +1730,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@ -1806,7 +1806,7 @@
"1969 32 1 female 23.65 1 0 southeast 17626.23951"
]
},
"execution_count": 34,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -1838,7 +1838,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -2014,31 +2014,9 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\frenk\\OneDrive\\Рабочий стол\\MII_Salin_Oleg_PIbd-33\\.venv\\Lib\\site-packages\\numpy\\ma\\core.py:2881: RuntimeWarning: invalid value encountered in cast\n",
" _data = np.array(data, dtype=dtype, copy=copy,\n"
]
},
{
"data": {
"text/plain": [
"{'model__criterion': 'entropy',\n",
" 'model__max_depth': 10,\n",
" 'model__max_features': 'log2',\n",
" 'model__n_estimators': 250}"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
@ -2069,7 +2047,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -2111,7 +2089,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@ -2135,7 +2113,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -2251,7 +2229,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -2354,7 +2332,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": null,
"metadata": {},
"outputs": [
{

View File

@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@ -212,7 +212,7 @@
"[2772 rows x 9 columns]"
]
},
"execution_count": 74,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@ -255,7 +255,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@ -897,7 +897,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@ -952,7 +952,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -985,7 +985,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -1047,177 +1047,177 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style type=\"text/css\">\n",
"#T_995df_row0_col0, #T_995df_row0_col1 {\n",
"#T_3759d_row0_col0, #T_3759d_row0_col1 {\n",
" background-color: #26818e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row0_col2, #T_995df_row7_col3 {\n",
"#T_3759d_row0_col2, #T_3759d_row7_col3 {\n",
" background-color: #4e02a2;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row0_col3, #T_995df_row7_col2 {\n",
"#T_3759d_row0_col3, #T_3759d_row7_col2 {\n",
" background-color: #da5a6a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col0 {\n",
"#T_3759d_row1_col0 {\n",
" background-color: #25838e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col1 {\n",
"#T_3759d_row1_col1 {\n",
" background-color: #26828e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col2 {\n",
"#T_3759d_row1_col2 {\n",
" background-color: #5102a3;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row1_col3 {\n",
"#T_3759d_row1_col3 {\n",
" background-color: #d9586a;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col0 {\n",
"#T_3759d_row2_col0 {\n",
" background-color: #228b8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col1, #T_995df_row3_col1 {\n",
"#T_3759d_row2_col1, #T_3759d_row3_col1 {\n",
" background-color: #24878e;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col2 {\n",
"#T_3759d_row2_col2 {\n",
" background-color: #6300a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row2_col3, #T_995df_row3_col3 {\n",
"#T_3759d_row2_col3, #T_3759d_row3_col3 {\n",
" background-color: #d7566c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row3_col0 {\n",
"#T_3759d_row3_col0 {\n",
" background-color: #228c8d;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row3_col2 {\n",
"#T_3759d_row3_col2 {\n",
" background-color: #6400a7;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col0, #T_995df_row5_col0 {\n",
"#T_3759d_row4_col0, #T_3759d_row5_col0 {\n",
" background-color: #1f948c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col1, #T_995df_row5_col1 {\n",
"#T_3759d_row4_col1, #T_3759d_row5_col1 {\n",
" background-color: #21918c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col2, #T_995df_row5_col2 {\n",
"#T_3759d_row4_col2, #T_3759d_row5_col2 {\n",
" background-color: #7e03a8;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row4_col3, #T_995df_row5_col3 {\n",
"#T_3759d_row4_col3, #T_3759d_row5_col3 {\n",
" background-color: #d35171;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col0 {\n",
"#T_3759d_row6_col0 {\n",
" background-color: #20a486;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col1 {\n",
"#T_3759d_row6_col1 {\n",
" background-color: #24aa83;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col2 {\n",
"#T_3759d_row6_col2 {\n",
" background-color: #a01a9c;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row6_col3 {\n",
"#T_3759d_row6_col3 {\n",
" background-color: #c13b82;\n",
" color: #f1f1f1;\n",
"}\n",
"#T_995df_row7_col0, #T_995df_row7_col1 {\n",
"#T_3759d_row7_col0, #T_3759d_row7_col1 {\n",
" background-color: #a8db34;\n",
" color: #000000;\n",
"}\n",
"</style>\n",
"<table id=\"T_995df\">\n",
"<table id=\"T_3759d\">\n",
" <thead>\n",
" <tr>\n",
" <th class=\"blank level0\" >&nbsp;</th>\n",
" <th id=\"T_995df_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_995df_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_995df_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_995df_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" <th id=\"T_3759d_level0_col0\" class=\"col_heading level0 col0\" >RMSE_train</th>\n",
" <th id=\"T_3759d_level0_col1\" class=\"col_heading level0 col1\" >RMSE_test</th>\n",
" <th id=\"T_3759d_level0_col2\" class=\"col_heading level0 col2\" >RMAE_test</th>\n",
" <th id=\"T_3759d_level0_col3\" class=\"col_heading level0 col3\" >R2_test</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_995df_row0_col0\" class=\"data row0 col0\" >3221.469707</td>\n",
" <td id=\"T_995df_row0_col1\" class=\"data row0 col1\" >3953.661053</td>\n",
" <td id=\"T_995df_row0_col2\" class=\"data row0 col2\" >45.741609</td>\n",
" <td id=\"T_995df_row0_col3\" class=\"data row0 col3\" >0.901103</td>\n",
" <th id=\"T_3759d_level0_row0\" class=\"row_heading level0 row0\" >random_forest</th>\n",
" <td id=\"T_3759d_row0_col0\" class=\"data row0 col0\" >3221.469707</td>\n",
" <td id=\"T_3759d_row0_col1\" class=\"data row0 col1\" >3953.661053</td>\n",
" <td id=\"T_3759d_row0_col2\" class=\"data row0 col2\" >45.741609</td>\n",
" <td id=\"T_3759d_row0_col3\" class=\"data row0 col3\" >0.901103</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_995df_row1_col0\" class=\"data row1 col0\" >3643.279193</td>\n",
" <td id=\"T_995df_row1_col1\" class=\"data row1 col1\" >4288.040726</td>\n",
" <td id=\"T_995df_row1_col2\" class=\"data row1 col2\" >47.359073</td>\n",
" <td id=\"T_995df_row1_col3\" class=\"data row1 col3\" >0.883668</td>\n",
" <th id=\"T_3759d_level0_row1\" class=\"row_heading level0 row1\" >decision_tree</th>\n",
" <td id=\"T_3759d_row1_col0\" class=\"data row1 col0\" >3643.279193</td>\n",
" <td id=\"T_3759d_row1_col1\" class=\"data row1 col1\" >4288.040726</td>\n",
" <td id=\"T_3759d_row1_col2\" class=\"data row1 col2\" >47.359073</td>\n",
" <td id=\"T_3759d_row1_col3\" class=\"data row1 col3\" >0.883668</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row2\" class=\"row_heading level0 row2\" >linear_poly</th>\n",
" <td id=\"T_995df_row2_col0\" class=\"data row2 col0\" >4731.024654</td>\n",
" <td id=\"T_995df_row2_col1\" class=\"data row2 col1\" >4868.817371</td>\n",
" <td id=\"T_995df_row2_col2\" class=\"data row2 col2\" >54.257745</td>\n",
" <td id=\"T_995df_row2_col3\" class=\"data row2 col3\" >0.850021</td>\n",
" <th id=\"T_3759d_level0_row2\" class=\"row_heading level0 row2\" >linear_poly</th>\n",
" <td id=\"T_3759d_row2_col0\" class=\"data row2 col0\" >4731.024654</td>\n",
" <td id=\"T_3759d_row2_col1\" class=\"data row2 col1\" >4868.817371</td>\n",
" <td id=\"T_3759d_row2_col2\" class=\"data row2 col2\" >54.257745</td>\n",
" <td id=\"T_3759d_row2_col3\" class=\"data row2 col3\" >0.850021</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
" <td id=\"T_995df_row3_col0\" class=\"data row3 col0\" >4776.393716</td>\n",
" <td id=\"T_995df_row3_col1\" class=\"data row3 col1\" >4938.699556</td>\n",
" <td id=\"T_995df_row3_col2\" class=\"data row3 col2\" >54.641209</td>\n",
" <td id=\"T_995df_row3_col3\" class=\"data row3 col3\" >0.845685</td>\n",
" <th id=\"T_3759d_level0_row3\" class=\"row_heading level0 row3\" >linear_interact</th>\n",
" <td id=\"T_3759d_row3_col0\" class=\"data row3 col0\" >4776.393716</td>\n",
" <td id=\"T_3759d_row3_col1\" class=\"data row3 col1\" >4938.699556</td>\n",
" <td id=\"T_3759d_row3_col2\" class=\"data row3 col2\" >54.641209</td>\n",
" <td id=\"T_3759d_row3_col3\" class=\"data row3 col3\" >0.845685</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_995df_row4_col0\" class=\"data row4 col0\" >6028.427617</td>\n",
" <td id=\"T_995df_row4_col1\" class=\"data row4 col1\" >6216.544081</td>\n",
" <td id=\"T_995df_row4_col2\" class=\"data row4 col2\" >65.584948</td>\n",
" <td id=\"T_995df_row4_col3\" class=\"data row4 col3\" >0.755499</td>\n",
" <th id=\"T_3759d_level0_row4\" class=\"row_heading level0 row4\" >ridge</th>\n",
" <td id=\"T_3759d_row4_col0\" class=\"data row4 col0\" >6028.427617</td>\n",
" <td id=\"T_3759d_row4_col1\" class=\"data row4 col1\" >6216.544081</td>\n",
" <td id=\"T_3759d_row4_col2\" class=\"data row4 col2\" >65.584948</td>\n",
" <td id=\"T_3759d_row4_col3\" class=\"data row4 col3\" >0.755499</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
" <td id=\"T_995df_row5_col0\" class=\"data row5 col0\" >6028.426993</td>\n",
" <td id=\"T_995df_row5_col1\" class=\"data row5 col1\" >6216.588829</td>\n",
" <td id=\"T_995df_row5_col2\" class=\"data row5 col2\" >65.580879</td>\n",
" <td id=\"T_995df_row5_col3\" class=\"data row5 col3\" >0.755496</td>\n",
" <th id=\"T_3759d_level0_row5\" class=\"row_heading level0 row5\" >linear</th>\n",
" <td id=\"T_3759d_row5_col0\" class=\"data row5 col0\" >6028.426993</td>\n",
" <td id=\"T_3759d_row5_col1\" class=\"data row5 col1\" >6216.588829</td>\n",
" <td id=\"T_3759d_row5_col2\" class=\"data row5 col2\" >65.580879</td>\n",
" <td id=\"T_3759d_row5_col3\" class=\"data row5 col3\" >0.755496</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_995df_row6_col0\" class=\"data row6 col0\" >8230.959070</td>\n",
" <td id=\"T_995df_row6_col1\" class=\"data row6 col1\" >9715.102581</td>\n",
" <td id=\"T_995df_row6_col2\" class=\"data row6 col2\" >81.129201</td>\n",
" <td id=\"T_995df_row6_col3\" class=\"data row6 col3\" >0.402859</td>\n",
" <th id=\"T_3759d_level0_row6\" class=\"row_heading level0 row6\" >knn</th>\n",
" <td id=\"T_3759d_row6_col0\" class=\"data row6 col0\" >8230.959070</td>\n",
" <td id=\"T_3759d_row6_col1\" class=\"data row6 col1\" >9715.102581</td>\n",
" <td id=\"T_3759d_row6_col2\" class=\"data row6 col2\" >81.129201</td>\n",
" <td id=\"T_3759d_row6_col3\" class=\"data row6 col3\" >0.402859</td>\n",
" </tr>\n",
" <tr>\n",
" <th id=\"T_995df_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_995df_row7_col0\" class=\"data row7 col0\" >17848.198895</td>\n",
" <td id=\"T_995df_row7_col1\" class=\"data row7 col1\" >18518.275054</td>\n",
" <td id=\"T_995df_row7_col2\" class=\"data row7 col2\" >116.605174</td>\n",
" <td id=\"T_995df_row7_col3\" class=\"data row7 col3\" >-1.169619</td>\n",
" <th id=\"T_3759d_level0_row7\" class=\"row_heading level0 row7\" >mlp</th>\n",
" <td id=\"T_3759d_row7_col0\" class=\"data row7 col0\" >17848.198895</td>\n",
" <td id=\"T_3759d_row7_col1\" class=\"data row7 col1\" >18518.275054</td>\n",
" <td id=\"T_3759d_row7_col2\" class=\"data row7 col2\" >116.605174</td>\n",
" <td id=\"T_3759d_row7_col3\" class=\"data row7 col3\" >-1.169619</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n"
],
"text/plain": [
"<pandas.io.formats.style.Styler at 0x1d2bcd71160>"
"<pandas.io.formats.style.Styler at 0x203e1a15460>"
]
},
"execution_count": 79,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -1226,7 +1226,7 @@
"reg_metrics = pd.DataFrame.from_dict(models, \"index\")[\n",
" [\"RMSE_train\", \"RMSE_test\", \"RMAE_test\", \"R2_test\"]\n",
"]\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient(\n",
"reg_metrics.sort_values(by=\"RMSE_test\").style.background_gradient( # type: ignore\n",
" cmap=\"viridis\", low=1, high=0.3, subset=[\"RMSE_train\", \"RMSE_test\"]\n",
").background_gradient(cmap=\"plasma\", low=0.3, high=1, subset=[\"RMAE_test\", \"R2_test\"])"
]
@ -1247,7 +1247,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": null,
"metadata": {},
"outputs": [
{
@ -1261,7 +1261,7 @@
}
],
"source": [
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name)\n",
"best_model = str(reg_metrics.sort_values(by=\"RMSE_test\").iloc[0].name) # type: ignore\n",
"\n",
"display(best_model)"
]
@ -1275,7 +1275,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@ -1397,7 +1397,7 @@
"2028 0.0 0.0 13143.86485 13575.291528 "
]
},
"execution_count": 81,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -1426,7 +1426,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@ -1548,7 +1548,7 @@
"2090 0.0 0.0 8930.93455 11318.629065 "
]
},
"execution_count": 82,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}

1471
lec5.ipynb Normal file

File diff suppressed because one or more lines are too long

BIN
requirements.txt Normal file

Binary file not shown.

27
transformers.py Normal file
View File

@ -0,0 +1,27 @@
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
class TitanicFeatures(BaseEstimator, TransformerMixin):
def __init__(self):
pass
def fit(self, X, y=None):
return self
def transform(self, X, y=None):
def get_title(name) -> str:
return name.split(",")[1].split(".")[0].strip()
def get_cabin_type(cabin) -> str:
if pd.isna(cabin):
return "unknown"
return cabin[0]
X["Is_married"] = [1 if get_title(name) == "Mrs" else 0 for name in X["Name"]]
X["Cabin_type"] = [get_cabin_type(cabin) for cabin in X["Cabin"]]
return X
def get_feature_names_out(self, features_in):
return np.append(features_in, ["Is_married", "Cabin_type"], axis=0)

100
utils_clusters.py Normal file
View File

@ -0,0 +1,100 @@
import math
from typing import Dict, List, Tuple
import numpy as np
from pandas import DataFrame
from sklearn import cluster
from sklearn.metrics import silhouette_samples, silhouette_score
def run_agglomerative(
df: DataFrame, num_clusters: int | None = 2
) -> cluster.AgglomerativeClustering:
agglomerative = cluster.AgglomerativeClustering(
n_clusters=num_clusters,
compute_distances=True,
)
return agglomerative.fit(df)
def get_linkage_matrix(model: cluster.AgglomerativeClustering) -> np.ndarray:
counts = np.zeros(model.children_.shape[0]) # type: ignore
n_samples = len(model.labels_)
for i, merge in enumerate(model.children_): # type: ignore
current_count = 0
for child_idx in merge:
if child_idx < n_samples:
current_count += 1
else:
current_count += counts[child_idx - n_samples]
counts[i] = current_count
return np.column_stack([model.children_, model.distances_, counts]).astype(float)
def print_cluster_result(
df: DataFrame, clusters_num: int, labels: np.ndarray, separator: str = ", "
):
for cluster_id in range(clusters_num):
cluster_indices = np.where(labels == cluster_id)[0]
print(f"Cluster {cluster_id + 1} ({len(cluster_indices)}):")
rules = [str(df.index[idx]) for idx in cluster_indices]
print(separator.join(rules))
print("")
print("--------")
def run_kmeans(
df: DataFrame, num_clusters: int, random_state: int
) -> Tuple[np.ndarray, np.ndarray]:
kmeans = cluster.KMeans(n_clusters=num_clusters, random_state=random_state)
labels = kmeans.fit_predict(df)
return labels, kmeans.cluster_centers_
def fit_kmeans(
reduced_data: np.ndarray, num_clusters: int, random_state: int
) -> cluster.KMeans:
kmeans = cluster.KMeans(n_clusters=num_clusters, random_state=random_state)
kmeans.fit(reduced_data)
return kmeans
def _get_kmeans_range(
df: DataFrame | np.ndarray, random_state: int
) -> Tuple[List, range]:
max_clusters = int(math.sqrt(len(df)))
clusters_range = range(2, max_clusters + 1)
kmeans_per_k = [
cluster.KMeans(n_clusters=k, random_state=random_state).fit(df)
for k in clusters_range
]
return kmeans_per_k, clusters_range
def get_clusters_inertia(df: DataFrame, random_state: int) -> Tuple[List, range]:
kmeans_per_k, clusters_range = _get_kmeans_range(df, random_state)
return [model.inertia_ for model in kmeans_per_k], clusters_range
def get_clusters_silhouette_scores(
df: DataFrame, random_state: int
) -> Tuple[List, range]:
kmeans_per_k, clusters_range = _get_kmeans_range(df, random_state)
return [
float(silhouette_score(df, model.labels_)) for model in kmeans_per_k
], clusters_range
def get_clusters_silhouettes(df: np.ndarray, random_state: int) -> Dict:
kmeans_per_k, _ = _get_kmeans_range(df, random_state)
clusters_silhouettes: Dict = {}
for model in kmeans_per_k:
silhouette_value = silhouette_score(df, model.labels_)
sample_silhouette_values = silhouette_samples(df, model.labels_)
clusters_silhouettes[model.n_clusters] = (
silhouette_value,
sample_silhouette_values,
model,
)
return clusters_silhouettes

242
visual.py Normal file
View File

@ -0,0 +1,242 @@
from typing import Any, Dict, List
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
from scipy.cluster import hierarchy
from sklearn.cluster import KMeans
def draw_data_2d(
df: DataFrame,
col1: int,
col2: int,
y: List | None = None,
classes: List | None = None,
subplot: Any | None = None,
):
ax = None
if subplot is None:
_, ax = plt.subplots()
else:
ax = subplot
scatter = ax.scatter(df[df.columns[col1]], df[df.columns[col2]], c=y)
ax.set(xlabel=df.columns[col1], ylabel=df.columns[col2])
if classes is not None:
ax.legend(
scatter.legend_elements()[0], classes, loc="lower right", title="Classes"
)
def draw_dendrogram(linkage_matrix: np.ndarray):
hierarchy.dendrogram(linkage_matrix, truncate_mode="level", p=3)
def draw_cluster_results(
df: DataFrame,
col1: int,
col2: int,
labels: np.ndarray,
cluster_centers: np.ndarray,
subplot: Any | None = None,
):
ax = None
if subplot is None:
ax = plt
else:
ax = subplot
centroids = cluster_centers
u_labels = np.unique(labels)
for i in u_labels:
ax.scatter(
df[labels == i][df.columns[col1]],
df[labels == i][df.columns[col2]],
label=i,
)
ax.scatter(centroids[:, col1], centroids[:, col2], s=80, color="k")
def draw_clusters(reduced_data: np.ndarray, kmeans: KMeans):
h = 0.02
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure(1)
plt.clf()
plt.imshow(
Z,
interpolation="nearest",
extent=(xx.min(), xx.max(), yy.min(), yy.max()),
cmap=plt.cm.Paired, # type: ignore
aspect="auto",
origin="lower",
)
plt.plot(reduced_data[:, 0], reduced_data[:, 1], "k.", markersize=2)
centroids = kmeans.cluster_centers_
plt.scatter(
centroids[:, 0],
centroids[:, 1],
marker="x",
s=169,
linewidths=3,
color="w",
zorder=10,
)
plt.title(
"K-means clustering (PCA-reduced data)\n"
"Centroids are marked with white cross"
)
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.xticks(())
plt.yticks(())
def _draw_cluster_scores(
data: List,
clusters_range: range,
score_name: str,
title: str,
):
plt.figure(figsize=(8, 5))
plt.plot(clusters_range, data, "bo-")
plt.xlabel("$k$", fontsize=8)
plt.ylabel(score_name, fontsize=8)
plt.title(title)
def draw_elbow_diagram(inertias: List, clusters_range: range):
_draw_cluster_scores(inertias, clusters_range, "Inertia", "The Elbow Diagram")
def draw_silhouettes_diagram(silhouette: List, clusters_range: range):
_draw_cluster_scores(
silhouette, clusters_range, "Silhouette score", "The Silhouette score"
)
def _draw_silhouette(
ax: Any,
reduced_data: np.ndarray,
n_clusters: int,
silhouette_avg: float,
sample_silhouette_values: List,
cluster_labels: List,
):
ax.set_xlim([-0.1, 1])
ax.set_ylim([0, len(reduced_data) + (n_clusters + 1) * 10])
y_lower = 10
for i in range(n_clusters):
ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
ith_cluster_silhouette_values.sort()
size_cluster_i = ith_cluster_silhouette_values.shape[0]
y_upper = y_lower + size_cluster_i
color = cm.nipy_spectral(float(i) / n_clusters) # type: ignore
ax.fill_betweenx(
np.arange(y_lower, y_upper),
0,
ith_cluster_silhouette_values,
facecolor=color,
edgecolor=color,
alpha=0.7,
)
ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
y_lower = y_upper + 10 # 10 for the 0 samples
ax.set_title("The silhouette plot for the various clusters.")
ax.set_xlabel("The silhouette coefficient values")
ax.set_ylabel("Cluster label")
ax.axvline(x=silhouette_avg, color="red", linestyle="--")
ax.set_yticks([])
ax.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])
def _draw_cluster_data(
ax: Any,
reduced_data: np.ndarray,
n_clusters: int,
cluster_labels: np.ndarray,
cluster_centers: np.ndarray,
):
colors = cm.nipy_spectral(cluster_labels.astype(float) / n_clusters) # type: ignore
ax.scatter(
reduced_data[:, 0],
reduced_data[:, 1],
marker=".",
s=30,
lw=0,
alpha=0.7,
c=colors,
edgecolor="k",
)
ax.scatter(
cluster_centers[:, 0],
cluster_centers[:, 1],
marker="o",
c="white",
alpha=1,
s=200,
edgecolor="k",
)
for i, c in enumerate(cluster_centers):
ax.scatter(c[0], c[1], marker="$%d$" % i, alpha=1, s=50, edgecolor="k")
ax.set_title("The visualization of the clustered data.")
ax.set_xlabel("Feature space for the 1st feature")
ax.set_ylabel("Feature space for the 2nd feature")
def draw_silhouettes(reduced_data: np.ndarray, silhouettes: Dict):
for key, value in silhouettes.items():
fig, (ax1, ax2) = plt.subplots(1, 2)
fig.set_size_inches(18, 7)
n_clusters = key
silhouette_avg = value[0]
sample_silhouette_values = value[1]
cluster_labels = value[2].labels_
cluster_centers = value[2].cluster_centers_
_draw_silhouette(
ax1,
reduced_data,
n_clusters,
silhouette_avg,
sample_silhouette_values,
cluster_labels,
)
_draw_cluster_data(
ax2,
reduced_data,
n_clusters,
cluster_labels,
cluster_centers,
)
plt.suptitle(
"Silhouette analysis for KMeans clustering on sample data with n_clusters = %d"
% n_clusters,
fontsize=14,
fontweight="bold",
)