IIS_2023_1/basharin_sevastyan_lab_1/main.py

61 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from random import randrange
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.datasets import make_circles
rs = randrange(50)
X, y = make_circles(noise=0.2, factor=0.5, random_state=rs) # Сгенерируем данные
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,
random_state=rs) # Разделим данные на обучающий и тестовый наборы
# Линейная модель
linear_reg = LinearRegression()
# Полиномиальная регрессия (со степенью 4)
poly_reg = make_pipeline(PolynomialFeatures(degree=4), StandardScaler(), LogisticRegression(random_state=rs))
# Гребневая полиномиальная регрессия (со степенью 4 и alpha=1.0)
ridge_poly_reg = make_pipeline(PolynomialFeatures(degree=4), StandardScaler(), LogisticRegression(penalty='l2', C=1.0,
random_state=rs))
# Обучение моделей
def mid_sq_n_det(name, model):
model.fit(X_train, y_train)
y_predict = model.predict(X_test)
print(f'Рассчёт среднеквадратичной ошибки для {name}: '
f'{np.round(np.sqrt(metrics.mean_squared_error(y_test, y_predict)),3)}') # Рассчёт среднеквадратичной ошибки модели
print(f'Рассчёт коэфициента детерминации для {name}: {np.round(metrics.r2_score(y_test, y_predict), 2)}') # Рассчёт коэфициента детерминации модели
return name, model
# Графики
models = [mid_sq_n_det("Линейная регрессия", linear_reg),
mid_sq_n_det("Полиномиальная регрессия (со степенью 4)", poly_reg),
mid_sq_n_det("Гребневая полиномиальная регрессия (со степенью 4, alpha = 1.0)", ridge_poly_reg)]
cmap_background = ListedColormap(['#FFAAAA', '#AAAAFF'])
cmap_points = ListedColormap(['#FF0000', '#0000FF'])
plt.figure(figsize=(15, 4))
for i, (name, model) in enumerate(models):
plt.subplot(1, 3, i + 1)
xx, yy = np.meshgrid(np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 100),
np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 100))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.contourf(xx, yy, Z, cmap=cmap_background, alpha=0.5)
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cmap_points, marker='o', label='Тестовые точки')
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cmap_points, marker='x', label='Обучающие точки')
plt.legend()
plt.title(name)
plt.text(0.5, -1.2, 'Красный класс', color='r', fontsize=12)
plt.text(0.5, -1.7, 'Синий класс', color='b', fontsize=12)
plt.tight_layout()
plt.show()