IIS_2023_1/martysheva_tamara_lab_1/lab1.py

57 lines
2.3 KiB
Python

import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from matplotlib.colors import ListedColormap
from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression, Perceptron, Ridge
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
#Создаем набор данных
X, Y = make_classification(n_samples=500, n_features=2, n_redundant=0, n_informative=2, random_state=0, n_clusters_per_class=1)
rng = np.random.RandomState(2)
X += 2 * rng.uniform(size=X.shape)
X = StandardScaler().fit_transform(X)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.4, random_state=40)
#Создаем модели
linear = LinearRegression()
perseptron = Perceptron()
ridge = Ridge(alpha=1.0)
polynomial_features = PolynomialFeatures(degree=3)
rid_poly = Pipeline([("polynomial_features", polynomial_features),("ridge_regression", ridge)])
#Тренируем модель
def train(model, description):
model.fit(X_train, Y_train)
Y_pred = model.predict(X_test)
print(description + ", качество модели = ", model.score(X_test, Y_test))
#Выводим результат на график
def plot(model, name):
cmap = ListedColormap(['#8b00ff', '#ff294d'])
plt.figure(figsize=(10, 7))
subplot = plt.subplot(111)
h = .5 # шаг регулярной сетки
x0_min, x0_max = X[:, 0].min() - .5, X[:, 0].max() + .5
x1_min, x1_max = X[:, 1].min() - .5, X[:, 1].max() + .5
xx0, xx1 = np.meshgrid(np.arange(x0_min, x0_max, h), np.arange(x1_min, x1_max, h))
Z = model.predict(np.c_[xx0.ravel(), xx1.ravel()])
Z = Z.reshape(xx0.shape)
subplot.contourf(xx0, xx1, Z, cmap=cmap, alpha=.3)
subplot.scatter(X_train[:, 0], X_train[:, 1], c=Y_train, cmap=cmap)
subplot.scatter(X_test[:, 0], X_test[:, 1], c=Y_test, cmap=cmap, alpha=0.4)
plt.savefig(name + ".png")
#Вызов функций
train(linear, "Линейная регрессия")
train(perseptron, "Персептрон")
train(rid_poly, "Гребневая полиномиальная регрессия")
plot(linear, "linear_plot")
plot(perseptron, "perseptron_plot")
plot(rid_poly, "rid_poly_plot")