import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from matplotlib.colors import ListedColormap
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression, Perceptron, Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline

#Создаем набор данных
X, Y = make_classification(n_samples=500, n_features=2, n_redundant=0, n_informative=2, random_state=0, n_clusters_per_class=1)
rng = np.random.RandomState(2)
X += 2 * rng.uniform(size=X.shape)
X = StandardScaler().fit_transform(X)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.4, random_state=40)

#Создаем модели
linear = LinearRegression()
perseptron = Perceptron()
ridge = Ridge(alpha=1.0)
polynomial_features = PolynomialFeatures(degree=3)
rid_poly = Pipeline([("polynomial_features", polynomial_features),("ridge_regression", ridge)])

#Тренируем модель
def train(model, description):
    model.fit(X_train, Y_train)
    Y_pred = model.predict(X_test)
    print(description + ", качество модели = ", model.score(X_test, Y_test))

#Выводим результат на график
def plot(model, name):
    cmap = ListedColormap(['#8b00ff', '#ff294d'])
    plt.figure(figsize=(10, 7))
    subplot = plt.subplot(111)
    h = .5  # шаг регулярной сетки
    x0_min, x0_max = X[:, 0].min() - .5, X[:, 0].max() + .5
    x1_min, x1_max = X[:, 1].min() - .5, X[:, 1].max() + .5
    xx0, xx1 = np.meshgrid(np.arange(x0_min, x0_max, h), np.arange(x1_min, x1_max, h))

    Z = model.predict(np.c_[xx0.ravel(), xx1.ravel()])
    Z = Z.reshape(xx0.shape)
    subplot.contourf(xx0, xx1, Z, cmap=cmap, alpha=.3)

    subplot.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap=cmap)
    subplot.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, cmap=cmap, alpha=0.4)

    plt.savefig(name + ".png")

#Вызов функций
train(linear, "Линейная регрессия")
train(perseptron, "Персептрон")
train(rid_poly, "Гребневая полиномиальная регрессия")

plot(linear, "linear_plot")
plot(perseptron, "perseptron_plot")
plot(rid_poly, "rid_poly_plot")