IIS_2023_1/gordeeva_anna_lab_4/laba4.py
2023-11-17 23:58:52 +04:00

58 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
import streamlit as st
import numpy as np
# Загрузка данных
data = pd.read_csv('data.csv')
# Просматриваем нет ли пустых данных
columns_to_check = ['Жанр', 'Поджанр', 'Количество заказов']
data = data.dropna(subset=columns_to_check)
# Выбираем только строки с жанром животные, так как будем кластеризировать именно в этом жанре
data = data[data['Жанр'] == 'Животные']
# Преобразуем строки в числа
label_encoder = LabelEncoder()
data['Поджанр'] = label_encoder.fit_transform(data['Поджанр'])
X = data[['Количество заказов', 'Поджанр']]
# Масштабирование данных (стандартизация)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# Применение DBSCAN
eps = 0.7 # Радиус окрестности
min_samples = 1 # Минимальное количество точек в окрестности
dbscan = DBSCAN(eps=eps, min_samples=min_samples)
labels = dbscan.fit_predict(X_scaled)
# Добавление меток кластеров в исходные данные
data['cluster'] = labels
# Вывод результата
st.write(data[['Количество заказов', 'Поджанр', 'cluster']])
# Обратное преобразование шкалы для первых двух признаков
original_data = scaler.inverse_transform(X_scaled[:, :2])
# Получение списка сопоставления числовых значений поджанров и их текстовых меток
label_mapping = dict(zip(label_encoder.transform(label_encoder.classes_), label_encoder.classes_))
# Преобразование числовых типов данных NumPy в стандартные Python типы
label_mapping_serializable = {str(k): v for k, v in label_mapping.items()}
st.write(label_mapping_serializable)
# Визуализация кластеров с изначальными данными на осях
fig, ax = plt.subplots()
ax.scatter(original_data[:, 1], original_data[:, 0], c=labels, cmap='viridis')
ax.set_title('DBSCAN Clustering')
ax.set_xlabel('Поджанр')
ax.set_ylabel('Количество заказов')
st.pyplot(fig)