167 lines
5.5 KiB
Python
167 lines
5.5 KiB
Python
|
import os
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
|
||
|
import matplotlib.pyplot as plt
|
||
|
import seaborn as sns
|
||
|
|
||
|
from sklearn import cluster, mixture
|
||
|
from sklearn.decomposition import PCA
|
||
|
from sklearn.cluster import KMeans, DBSCAN, OPTICS
|
||
|
from sklearn.preprocessing import StandardScaler
|
||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||
|
from sklearn.neighbors import kneighbors_graph
|
||
|
from itertools import cycle, islice
|
||
|
|
||
|
import warnings
|
||
|
# Install a global "ignore" filter so that no warning of any category is
# reported for the rest of the process (this includes deprecation and
# convergence warnings emitted by the libraries imported above).
warnings.simplefilter('ignore')
|
||
|
|
||
|
|
||
|
def generate_clustering_algorithms(Z, n_clusters, m):
    """Build an unfitted scikit-learn clustering estimator selected by name.

    Parameters
    ----------
    Z : array-like
        Data passed to ``cluster.estimate_bandwidth``. It is only read when
        ``m == 'MeanShift'``.
    n_clusters : int
        Number of clusters for 'KMeans' and 'MiniBatchKMeans'; not used by
        'MeanShift'.
    m : str
        Name of the method: 'MeanShift', 'KMeans' or 'MiniBatchKMeans'.

    Returns
    -------
    The unfitted estimator (``cluster.MeanShift``, ``cluster.KMeans`` or
    ``cluster.MiniBatchKMeans``).

    Raises
    ------
    ValueError
        If ``m`` is not one of the supported names. (Previously this case
        fell through to ``return cl`` and failed with UnboundLocalError.)
    """
    supported = ('MeanShift', 'KMeans', 'MiniBatchKMeans')
    # Validate first so a typo in the method name fails fast with a clear
    # message, before any potentially expensive work on Z.
    if m not in supported:
        raise ValueError(
            f"Unknown clustering method {m!r}; expected one of {supported}")

    if m == 'MeanShift':
        # The bandwidth estimate is only needed by MeanShift, so it is
        # computed here instead of unconditionally for every method.
        bandwidth = cluster.estimate_bandwidth(Z, quantile=0.2)
        return cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)

    if m == 'KMeans':
        # Fixed seed keeps the KMeans result reproducible between runs.
        return cluster.KMeans(n_clusters=n_clusters, random_state=1000)

    # m == 'MiniBatchKMeans'
    return cluster.MiniBatchKMeans(n_clusters=n_clusters)
|
||
|
|
||
|
|
||
|
def clustering_df(X, n, m, output_hist, title='clusters_by'):
    """Standardize ``X``, cluster it with method ``m`` and return labelled frames.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature table; it is passed to ``StandardScaler`` as is.
    n : int
        Number of clusters forwarded to ``generate_clustering_algorithms``.
    m : str
        Clustering method name forwarded to ``generate_clustering_algorithms``.
    output_hist : bool
        When true, one histogram grid per feature (one facet per cluster) is
        written to ``{title}_by_method_{m}_{feature}.png``.
    title : str
        Prefix of the histogram file names.

    Returns
    -------
    clusters : pandas.DataFrame
        Standardized features plus an integer 'cluster' column.
    clusters_inv : pandas.DataFrame
        Features mapped back through ``scaler.inverse_transform`` plus the
        same 'cluster' column.
    """
    # Standardization (the rebuilt frame gets a fresh 0..n-1 index, which is
    # what lets the label column line up in the concat below).
    feature_names = X.columns
    scaler = StandardScaler()
    X = pd.DataFrame(scaler.fit_transform(X), columns=feature_names)

    cl = generate_clustering_algorithms(X, n, m)
    cl.fit(X)
    raw_labels = cl.labels_ if hasattr(cl, 'labels_') else cl.predict(X)
    # Plain int instead of uint8: uint8 silently wraps negative labels
    # and any label above 255.
    labels = np.asarray(raw_labels).astype(int)

    label_frame = pd.DataFrame({'cluster': labels})
    clusters = pd.concat([X, label_frame], axis=1)

    # Inverse Standardization
    X_inv = pd.DataFrame(scaler.inverse_transform(X), columns=feature_names)
    clusters_inv = pd.concat([X_inv, label_frame], axis=1)

    # Number of points in clusters
    print("Number of points in clusters:\n", clusters['cluster'].value_counts())

    # Data in clusters - thanks to https://www.kaggle.com/sabanasimbutt/clustering-visualization-of-clusters-using-pca
    if output_hist:
        for feature in feature_names:
            grid = sns.FacetGrid(clusters_inv, col='cluster')
            grid.map(plt.hist, feature)
            # Save inside the loop: saving once after the loop only captured
            # the last grid. Close the figure afterwards so figures do not
            # pile up across features and methods.
            plt.savefig(f'{title}_by_method_{m}_{feature}.png')
            plt.close()

    return clusters, clusters_inv
|
||
|
|
||
|
|
||
|
def plot_draw(X, title, m):
    """Draw the clustered rows on a plane and save the figure as a PNG.

    The rows of ``X`` are embedded in 2-D by running PCA on their pairwise
    cosine-distance matrix, then plotted with one colour per cluster.
    Thanks to https://www.kaggle.com/sabanasimbutt/clustering-visualization-of-clusters-using-pca

    Parameters
    ----------
    X : pandas.DataFrame
        Numeric table containing a 'cluster' column with integer labels.
    title : str
        Plot title prefix; also the prefix of the output file name.
    m : str
        Clustering method name, used in the title and the file name.

    The figure is written to ``{title}_by_method_{m}.png``, shown, and closed.
    """
    # NOTE(review): the 'cluster' column is part of X here, so the labels
    # themselves contribute to the distances - confirm this is intended.
    dist = 1 - cosine_similarity(X)

    # PCA transform
    pca = PCA(2)
    pca.fit(dist)
    X_PCA = pca.transform(dist)

    # Colours for clusters. One more colour than the label span is sampled
    # from 'hsv'; truncating the evenly spaced label positions to int maps
    # every non-negative integer label in [lo, hi] to its own colour (the
    # lowest label appears twice and keeps the later colour).
    labels = X['cluster']
    lo, hi = int(labels.min()), int(labels.max())
    n_clusters = hi - lo + 2
    hsv = plt.get_cmap('hsv')
    colors = list(hsv(np.linspace(0, 1, n_clusters)))
    colors_num = [int(v) for v in np.linspace(lo, hi, n_clusters)]
    colors_dict = dict(zip(colors_num, colors))

    # Visualization
    df = pd.DataFrame({'x': X_PCA[:, 0],
                       'y': X_PCA[:, 1],
                       'label': labels.tolist()})

    fig, ax = plt.subplots(figsize=(12, 8))

    for name, group in df.groupby('label'):
        ax.plot(group.x, group.y, marker='o', linestyle='', ms=10,
                color=colors_dict[int(name)],
                label=str(int(name)),
                mec='none')

    ax.set_aspect('auto')
    # Real booleans are required: the string 'off' is truthy, so it left the
    # ticks and tick labels visible instead of hiding them.
    ax.tick_params(axis='x', which='both', bottom=False, top=False,
                   labelbottom=False)
    # 'top' does not apply to a y axis; 'right' is the matching side.
    ax.tick_params(axis='y', which='both', left=False, right=False,
                   labelleft=False)

    ax.legend(loc='upper right')
    ax.set_title(f"{title} by method {m}")
    plt.savefig(f'{title}_by_method_{m}.png')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Load the data set and drop exact duplicate rows.
    data = pd.read_csv("..//heart_disease_uci.csv")
    data = data.drop_duplicates().reset_index(drop=True)

    # Replace every object-typed column by integer category codes.
    # NOTE(review): pd.factorize encodes missing values as -1, so rows with
    # missing values in these columns survive the NaN filter below - confirm
    # that this is intended.
    object_columns = data.select_dtypes(include='object').columns.tolist()
    print(object_columns)
    for column in object_columns:
        data[column] = pd.factorize(data[column])[0]

    methods_all = ['KMeans', 'MiniBatchKMeans', 'MeanShift']
    n_default = 6

    # Keep only the rows without any missing value.
    data = data.dropna()

    # Per-method outcome: success flag and number of clusters found.
    res = dict.fromkeys(methods_all, False)
    n_clust = dict.fromkeys(methods_all, 1)
    for method in methods_all:
        print(f"Method - {method}")
        Y, Y_inv = clustering_df(data.copy(), n_default, method, True)

        # A clustering that yields fewer than 2 clusters counts as a failure.
        n_cl = Y['cluster'].nunique()
        if n_cl <= 1:
            print('Clustering is not successful because all data is in one cluster!\n')
            continue

        res[method] = True
        n_clust[method] = n_cl
        plot_draw(Y, "Data clustering", method)
|
||
|
|