IIS_2023_1/simonov_nikita_lab_3/lab3.py
2023-11-29 19:36:57 +04:00

39 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score, classification_report
# Загрузка данных
data = pd.read_csv("train_bikes.csv", sep=',').dropna()
# Подготовка данных
# Здесь определяются пороги для категорий спроса
low_demand_threshold = 100 # Порог для "Низкого спроса"
medium_demand_threshold = 300 # Порог для "Среднего спроса"
# Создание новой категориальной переменной на основе порогов
data['demand_category'] = pd.cut(data['count'], bins=[0, low_demand_threshold, medium_demand_threshold, float('inf')],
labels=["Low Demand", "Medium Demand", "High Demand"])
# Выделение признаков и целевой переменной
X = data.drop(['count', 'demand_category', 'datetime'], axis=1) # Удаление ненужных столбцов
y = data['demand_category']
# Разделение данных на обучающий и тестовый наборы
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Создание и обучение модели дерева решений
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
# Предсказание категорий спроса на тестовом наборе
y_pred = clf.predict(X_test)
# Оценка качества модели
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
# Вывод отчета о классификации
report = classification_report(y_test, y_pred)
print(report)