55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
|
import matplotlib.pyplot as plt
|
||
|
import numpy as np
|
||
|
from sklearn.datasets import make_classification
|
||
|
from sklearn.model_selection import train_test_split
|
||
|
from sklearn.linear_model import Perceptron
|
||
|
from sklearn.neural_network import MLPClassifier
|
||
|
from sklearn.metrics import accuracy_score
|
||
|
|
||
|
# Установите random_state, чтобы результаты были воспроизводимыми
|
||
|
rs = 42
|
||
|
|
||
|
# Генерация данных
|
||
|
X, y = make_classification(
|
||
|
n_samples=500, n_features=2, n_redundant=0, n_informative=2,
|
||
|
random_state=rs, n_clusters_per_class=1
|
||
|
)
|
||
|
|
||
|
# Разделение данных на обучающий и тестовый наборы
|
||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=rs)
|
||
|
|
||
|
# Создание моделей
|
||
|
models = [
|
||
|
('Perceptron', Perceptron(random_state=rs)),
|
||
|
('MLP (10 neurons)', MLPClassifier(hidden_layer_sizes=(10,), alpha=0.01, random_state=rs)),
|
||
|
('MLP (100 neurons)', MLPClassifier(hidden_layer_sizes=(100,), alpha=0.01, random_state=rs))
|
||
|
]
|
||
|
|
||
|
# Обучение и оценка моделей
|
||
|
results = {}
|
||
|
|
||
|
plt.figure(figsize=(15, 5))
|
||
|
|
||
|
for i, (name, model) in enumerate(models, 1):
|
||
|
plt.subplot(1, 3, i)
|
||
|
model.fit(X_train, y_train)
|
||
|
y_pred = model.predict(X_test)
|
||
|
accuracy = accuracy_score(y_test, y_pred)
|
||
|
results[name] = accuracy
|
||
|
|
||
|
# Разбиение точек на классы
|
||
|
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=plt.cm.Paired, edgecolors='k')
|
||
|
|
||
|
# Построение границы решения для каждой модели
|
||
|
h = .02 # Шаг сетки
|
||
|
x_min, x_max = X_test[:, 0].min() - 1, X_test[:, 0].max() + 1
|
||
|
y_min, y_max = X_test[:, 1].min() - 1, X_test[:, 1].max() + 1
|
||
|
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
|
||
|
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
|
||
|
Z = Z.reshape(xx.shape)
|
||
|
plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
|
||
|
|
||
|
plt.title(f'{name}\nAccuracy: {accuracy:.2f}')
|
||
|
|
||
|
plt.show()
|