MII_Labs_Mochalov_PI-33/lab11/lab11.py

109 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import random
import numpy as np
import matplotlib.pyplot as plt
def get_distance(first: np.ndarray, second: np.ndarray) -> float:
return math.sqrt(sum([(first[i] - second[i]) ** 2 for i in range(first.shape[0])])) + 1e-5
# Расчёт степени принадлежности
def affiliation_calculation(data: np.ndarray, centers: np.ndarray, k: int, m: int) -> np.ndarray:
data_len = data.shape[0]
u = np.zeros((data_len, k))
for i in range(data_len):
for j in range(k):
total = 0
distance = get_distance(data[i], centers[j])
for c in range(k):
total += (distance / get_distance(data[i], centers[c])) ** (2 / (m - 1))
u[i, j] = 1 / total
return u
# Расчёт отклонения
def variance_calculation(data: np.ndarray, centers: np.ndarray, u: np.ndarray) -> float:
value = 0
for j in range(k):
for i in range(data.shape[0]):
value += get_distance(data[i], centers[j]) ** 2 * u[i, j]
return value
# Обновление центров кластеров
def center_update(data: np.ndarray, u: np.ndarray, k: int, m: int) -> np.ndarray:
centers = np.zeros((k, data.shape[1]))
for j in range(k):
total = 0
for i in range(data.shape[0]):
total += u[i, j] ** m * data[i]
centers[j] = total / np.sum(u[:, j] ** m)
return centers
def fuzzy_c_means(data: np.ndarray, k: int, m: int, max_iter: int = 100, tol: float = 1e-5) -> (
np.ndarray, np.ndarray, float):
centers = np.array([[random.randint(data.min(), data.max()) for i in range(data.shape[1])] for j in range(k)])
u = None
value = 0
for _ in range(max_iter):
u = affiliation_calculation(data, centers, k, m)
new_value = variance_calculation(data, centers, u)
if abs(new_value - value) <= tol:
return centers, u, value
value = new_value
centers = center_update(data, u, k, m)
return centers, u, value
# Работа с plt для визуализации результата
def visualise_resout(centers: np.ndarray, u: np.ndarray):
center_colors = [[random.random(), random.random(), random.random()] for i in range(k)]
point_colors = []
for i in u:
tmp_color = [0, 0, 0]
for j in range(k):
tmp_color[0] += center_colors[j][0] * i[j]
tmp_color[1] += center_colors[j][1] * i[j]
tmp_color[2] += center_colors[j][1] * i[j]
point_colors.append(tmp_color)
plt.title("Нечёткая кластеризация")
plt.xlabel("Размер зарплаты")
# Визуализация
if data.shape[1] == 1:
plt.scatter(data[:, 0], [0] * data.shape[0], c=point_colors)
plt.scatter(centers[:, 0], [0] * centers.shape[0], marker='*', edgecolor='black', s=100, c=center_colors)
plt.gca().axes.get_yaxis().set_visible(False)
else:
plt.scatter(data[:, 0], data[:, 1], c=point_colors)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', edgecolor='black', s=100, c=center_colors)
plt.show()
if __name__ == '__main__':
data: np.ndarray = np.array(
[
[
random.randint(0, 500)
]
for i in range(random.randint(40, 100))
])
k = 3
m = 2
centers, u, value = fuzzy_c_means(data, k, m)
print(f"Значение функции отклонений: {value}")
print("Степени принадлежности первых 10 точек:")
print(*u[:10], sep="\n")
print("Центры всех кластеров:")
print(*centers, sep="\n")
visualise_resout(centers, u)