IIS_2023_1/istyukov_timofey_lab_3/lab3.py
2024-01-06 00:54:29 +04:00

69 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Решите с помощью библиотечной реализации дерева решений задачу изnлабораторной работы
«Веб-сервис «Дерево решений» по предмету «Методы искусственного интеллекта» на 99% ваших данных.
Проверьте работу модели на оставшемся проценте, сделайте вывод.
"""
"""
Задача, решаемая деревом решений: Классификация музыкальных треков на основе их характеристик,
таких как акустика, танцевальность, инструментальность, темп и т.д.
Дерево решений может предсказывать жанр трека, основываясь на его характеристиках.
"""
# 12 вариант
# Набор данных по курсовой: "Prediction of music genre"
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
DATASET_FILE = 'music_genre.csv'
def main():
df = open_dataset(DATASET_FILE)
df = df.sample(frac=.1) # отбираем 10% рандомных строк с набора данных, т.к. он большой
print("\033[92m[-----> Набор данных <-----]\033[00m")
print(df)
X = df.drop(columns=['music_genre']) # набор числовых признаков
y = df['music_genre'] # набор соответствующих им жанров
# Разделение датасета на тренировочные (99,5%) и тестовые данные (0,5%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.005)
# Создание и обучение дерева решений
model = DecisionTreeClassifier()
model.fit(X_train.values, y_train)
# Прогнозирование жанра на тестовых данных
y_pred = model.predict(X_test.values)
print("\033[92m\n\n\n[-----> Сравнение жанров <-----]\033[00m")
df_result = pd.DataFrame({'Прогноз': y_pred, 'Реальность': y_test})
print(df_result)
score = accuracy_score(y_test, y_pred)
print("\033[92m\n> Оценка точности модели: {}\033[00m" .format(round(score, 2)))
print("\033[92m\n\n\n[-----> Оценки важности признаков <-----]\033[00m")
df_feature = pd.DataFrame({'Признак': X.columns, "Важность": model.feature_importances_})
print(df_feature)
# Функция считывания и очищения csv-файла
def open_dataset(csv_file):
# открываем файл с указанием знака-отделителя
df_genres = pd.read_csv(csv_file, delimiter=',')
# выбираем необходимые признаки
df_genres = df_genres[['tempo', 'instrumentalness', 'acousticness', 'speechiness', 'danceability', 'energy', 'liveness', 'music_genre']]
# очищаем набор данных от пустых и неподходящих значений
df_genres = df_genres[df_genres['tempo'] != '?']
df_genres = df_genres.dropna()
return df_genres
if __name__ == "__main__":
main()