55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.datasets import make_moons
|
|
from sklearn.linear_model import LinearRegression
|
|
from sklearn.neural_network import MLPClassifier, MLPRegressor
|
|
from sklearn.metrics import accuracy_score, mean_squared_error
|
|
|
|
|
|
X, y = make_moons(noise=0.3, random_state=42)
|
|
|
|
# Линейная регрессия
|
|
lr = LinearRegression()
|
|
lr.fit(X, y)
|
|
|
|
# Многослойный персептрон
|
|
mlp = MLPRegressor(hidden_layer_sizes=(10,), alpha=0.01, random_state=42)
|
|
mlp.fit(X, y)
|
|
|
|
# Персептрон
|
|
perceptron = MLPClassifier(hidden_layer_sizes=(1,), random_state=42)
|
|
perceptron.fit(X, y)
|
|
|
|
# Создаем сетку точек для предсказания моделей
|
|
xx, yy = np.meshgrid(np.linspace(-2, 3, 1000), np.linspace(-2, 2, 1000))
|
|
Z_lr = lr.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
|
|
Z_mlp = mlp.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
|
|
Z_perceptron = perceptron.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
|
|
|
|
# Отображаем данные и предсказания моделей
|
|
plt.figure(figsize=(18, 6))
|
|
|
|
# График линейной регрессии
|
|
plt.subplot(1, 3, 1)
|
|
plt.contourf(xx, yy, Z_lr, cmap=plt.cm.RdBu, alpha=0.8)
|
|
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdBu, edgecolors='k')
|
|
plt.title("Линейная регрессия")
|
|
|
|
# График многослойного персептрона
|
|
plt.subplot(1, 3, 2)
|
|
plt.contourf(xx, yy, Z_mlp, cmap=plt.cm.RdBu, alpha=0.8)
|
|
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdBu, edgecolors='k')
|
|
plt.title("Многослойный персептрон")
|
|
|
|
# График персептрона
|
|
plt.subplot(1, 3, 3)
|
|
plt.contourf(xx, yy, Z_perceptron, cmap=plt.cm.RdBu, alpha=0.8)
|
|
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdBu, edgecolors='k')
|
|
plt.title("Персептрон")
|
|
|
|
print("Линейная регрессия:", lr.score(X, y))
|
|
print("Многослойный персептрон:", mlp.score(X, y))
|
|
print("Персептрон:", perceptron.score(X, y))
|
|
|
|
# Показываем графики
|
|
plt.show() |