IIS_2023_1/istyukov_timofey_lab_6/lab6.py
2024-01-11 15:40:25 +04:00

84 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Использовать нейронную сеть по варианту для ваших данных по варианту, самостоятельно сформулировав задачу.
Интерпретировать результаты и оценить, насколько хорошо она подходит для решения сформулированной вами задачи.
"""
"""
Задача, решаемая нейронной сетью:
Регрессия: Предсказание популярности нового музыкального трека на основе его определённых характеристик.
"""
# 12 вариант
# Набор данных по курсовой: "Prediction of music genre"
# Модель мейронной сети: MLPRegressor
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split
DATASET_FILE = 'music_genre.csv'
def main():
df = open_dataset(DATASET_FILE) # берём полный набор данных
print("\033[92m[----------> Набор данных <----------]\033[00m")
print(df)
# Перевод ладов (минор/мажор) в числовые признаки
df_music = df.copy()
df_music['mode'] = df_music['mode'].apply(lambda x: 1 if x == 'Major' else 0)
X = df_music.drop(columns=['popularity']) # характеристики музыкального трека
y = df_music['popularity'] # уровень популярности
# Разделение датасета на тренировочные (99%) и тестовые данные (1%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.01)
model = MLPRegressor(
# несколько размеров слоёв и узлов
hidden_layer_sizes=(50, 50, 50, 50,),
# функция активации (relu, tanh, identity)
activation='relu',
max_iter=2000
)
model.fit(X_train, y_train)
# Предсказание на тестовых данных
y_pred = model.predict(X_test)
print("\033[92m\n[----------> Оценка модели <----------]\033[00m")
print("Коэффициент детерминации = ",
round(metrics.r2_score(y_test, y_pred), 3))
print("Потери регрессии среднеквадратичной логарифмической ошибки = ",
round(metrics.mean_squared_log_error(y_test, y_pred), 3))
# График для наглядности
sns.regplot(x=y_test, y=y_pred, scatter_kws={'s': 10}, line_kws={'color': 'red'})
plt.xlabel('Реальность')
plt.ylabel('Предсказание')
plt.title('MLPRegressor на примере популярности треков')
plt.savefig("1_plot_result")
plt.show()
# Функция считывания и очищения csv-файла
def open_dataset(csv_file):
# открываем файл с указанием знака-отделителя
df = pd.read_csv(csv_file, delimiter=',')
# выбираем необходимые признаки
df = df[['mode', 'tempo', 'instrumentalness', 'acousticness', 'speechiness', 'danceability',
'energy', 'liveness', 'valence', 'loudness', 'popularity']]
# очищаем набор данных от пустых и неподходящих значений
df = df[df['tempo'] != '?']
df = df.dropna()
return df
if __name__ == "__main__":
main()