import numpy as np from sklearn.datasets import make_classification from sklearn.preprocessing import StandardScaler from sklearn.model_selection import train_test_split from matplotlib.colors import ListedColormap from matplotlib import pyplot as plt from sklearn.linear_model import LinearRegression, Perceptron, Ridge from sklearn.preprocessing import PolynomialFeatures from sklearn.pipeline import Pipeline #Создаем набор данных X, Y = make_classification(n_samples=500, n_features=2, n_redundant=0, n_informative=2, random_state=0, n_clusters_per_class=1) rng = np.random.RandomState(2) X += 2 * rng.uniform(size=X.shape) X = StandardScaler().fit_transform(X) X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.4, random_state=40) #Создаем модели linear = LinearRegression() perseptron = Perceptron() ridge = Ridge(alpha=1.0) polynomial_features = PolynomialFeatures(degree=3) rid_poly = Pipeline([("polynomial_features", polynomial_features),("ridge_regression", ridge)]) #Тренируем модель def train(model, description): model.fit(X_train, Y_train) Y_pred = model.predict(X_test) print(description + ", качество модели = ", model.score(X_test, Y_test)) #Выводим результат на график def plot(model, name): cmap = ListedColormap(['#8b00ff', '#ff294d']) plt.figure(figsize=(10, 7)) subplot = plt.subplot(111) h = .5 # шаг регулярной сетки x0_min, x0_max = X[:, 0].min() - .5, X[:, 0].max() + .5 x1_min, x1_max = X[:, 1].min() - .5, X[:, 1].max() + .5 xx0, xx1 = np.meshgrid(np.arange(x0_min, x0_max, h), np.arange(x1_min, x1_max, h)) Z = model.predict(np.c_[xx0.ravel(), xx1.ravel()]) Z = Z.reshape(xx0.shape) subplot.contourf(xx0, xx1, Z, cmap=cmap, alpha=.3) subplot.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap=cmap) subplot.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, cmap=cmap, alpha=0.4) plt.savefig(name + ".png") #Вызов функций train(linear, "Линейная регрессия") train(perseptron, "Персептрон") train(rid_poly, "Гребневая полиномиальная регрессия") plot(linear, "linear_plot") plot(perseptron, "perseptron_plot") plot(rid_poly, "rid_poly_plot")