IIS_2023_1/ilbekov_dmitriy_lab_1/lab1.py

104 lines
4.7 KiB
Python
Raw Permalink Normal View History

2023-10-15 19:15:47 +04:00
import numpy as np
from sklearn import metrics
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score
from matplotlib import pyplot as plt
#Задание случайного состояния
rs = 35
# Генерации синтетического набора данных в форме двух полумесяцев
# noise - уровень шума данных
# random_state устанавливается в rs для воспроизводимости данных
X, y = make_moons(noise=0.3, random_state=rs)
# test_size какой процент данных пойдет в тестирование
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=rs)
# Подготовка для визуализации
x_minimal, x_maximum = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_minimal, y_maximum = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_minimal, x_maximum, 0.02), np.arange(y_minimal, y_maximum, 0.02))
# ЛИНЕЙНАЯ РЕГРЕССИЯ
# Инициализация модели
linear_regression = LinearRegression()
# Обучение
linear_regression.fit(X_train, y_train)
# Предсказание
y_pred_linear_regression = linear_regression.predict(X_test)
# Оценка точности (MSE)
accuracy_linear_regression = metrics.mean_squared_error(y_test, y_pred_linear_regression)
# Предсказание класса для каждой точки в сетке графика и изменение формы результата
Z_linear_regression = linear_regression.predict(np.c_[xx.ravel(), yy.ravel()])
Z_linear_regression = Z_linear_regression.reshape(xx.shape)
# МНОГОСЛОЙНЫЙ ПЕРСЕПТРОН (10)
# Инициализация модели
multi_layer_perceptron_10 = MLPClassifier(hidden_layer_sizes=(10,), alpha=0.01, random_state=rs)
# Обучение
multi_layer_perceptron_10.fit(X_train, y_train)
# Предсказание
y_pred_multi_layer_perceptron_10 = multi_layer_perceptron_10.predict(X_test)
# Оценка точности каждой модели сравнивается с истинными метками классов на тестовой выборке
accuracy_mlp_10 = accuracy_score(y_test, y_pred_multi_layer_perceptron_10)
# Предсказание класса для каждой точки в сетке графика и изменение формы результата
Z_mlp_10 = multi_layer_perceptron_10.predict(np.c_[xx.ravel(), yy.ravel()])
Z_mlp_10 = Z_mlp_10.reshape(xx.shape)
# МНОГОСЛОЙНЫЙ ПЕРСЕПТРОН (100)
# Инициализация модели
multi_layer_perceptron_100 = MLPClassifier(hidden_layer_sizes=(100,), alpha=0.01, random_state=rs)
# Обучение
multi_layer_perceptron_100.fit(X_train, y_train)
# Предсказание
y_pred_multi_layer_perceptron_100 = multi_layer_perceptron_100.predict(X_test)
# Оценка точности (MSE)
accuracy_mlp_100 = accuracy_score(y_test, y_pred_multi_layer_perceptron_100)
# Предсказание класса для каждой точки в сетке графика и изменение формы результата
Z_mlp_100 = multi_layer_perceptron_100.predict(np.c_[xx.ravel(), yy.ravel()])
Z_mlp_100 = Z_mlp_100.reshape(xx.shape)
# ВЫВОД: результаты оценки точности (в консоли) и график
print("Точность: ")
print("LinearRegression:", accuracy_linear_regression)
print("Multi Layer Perceptron 10 нейронов:", accuracy_mlp_10)
print("Multi Layer Perceptron 100 нейронов:", accuracy_mlp_100)
plt.figure(figsize=(12, 9))
plt.subplot(221)
plt.contourf(xx, yy, Z_linear_regression, alpha=0.8)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolors='k', alpha=0.6)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolors='k')
plt.title('Линейная регрессия')
plt.xlabel('Признак 1')
plt.ylabel('Признак 2')
plt.subplot(222)
plt.contourf(xx, yy, Z_mlp_10, alpha=0.8)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolors='k', alpha=0.6)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolors='k')
plt.title('MLP 10 нейронов')
plt.xlabel('Признак 1')
plt.ylabel('Признак 2')
plt.subplot(223)
plt.contourf(xx, yy, Z_mlp_100, alpha=0.8)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolors='k', alpha=0.6)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolors='k')
plt.title('MLP 100 нейронов')
plt.xlabel('Признак 1')
plt.ylabel('Признак 2')
plt.tight_layout()
plt.show()