IIS_2023_1/simonov_nikita_lab_5/lab5.py

64 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error
import matplotlib.pyplot as plt
import numpy as np
import math
from sklearn.impute import SimpleImputer
# Загрузка данных
df = pd.read_csv('train_bikes.csv').dropna()
# Определение признаков (X) и целевой переменной (y)
X = df[['humidity', 'windspeed']]
y = df['count']
# Обработка пропущенных значений с использованием SimpleImputer
imputer = SimpleImputer(strategy='mean')
X = imputer.fit_transform(X)
# Разделение данных на обучающий, валидационный и тестовый наборы
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=0)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=0)
# Создание и обучение модели линейной регрессии
linear_model = LinearRegression()
linear_model.fit(X_train, y_train)
# Вывод коэффициентов и пересечения
print(f'Коэффициенты линейной регрессии: {linear_model.coef_}')
print(f'Пересечение линейной регрессии: {linear_model.intercept_}')
# Предсказание значений на тестовом наборе
y_pred = linear_model.predict(X_test)
# Оценка модели
train_score = linear_model.score(X_train, y_train)
val_score = linear_model.score(X_val, y_val)
test_score = linear_model.score(X_test, y_test)
print(f'R^2 на обучающем наборе: {train_score}')
print(f'R^2 на валидационном наборе: {val_score}')
print(f'R^2 на тестовом наборе: {test_score}')
# Оценка качества предсказаний
MSE = mean_squared_error(y_test, y_pred)
RMSE = math.sqrt(MSE)
print(f'Среднеквадратичная ошибка: {MSE}')
print(f'Корень из среднеквадратичной ошибки: {RMSE}')
# Применение стиля графика
plt.style.use(['dark_background'])
# Визуализация предсказаний
plt.figure(figsize=(8, 6), dpi=80)
plt.scatter(y_test, y_pred, alpha=0.2, color='red')
m, b = np.polyfit(y_test, y_pred, 1)
plt.plot(y_test, m * y_test + b, color='yellow')
plt.xlabel('Фактическое значение (тестовый набор)', fontsize=14)
plt.ylabel('Предсказанное значение (тестовый набор)', fontsize=14)
plt.title('Линейная регрессия: предсказанные и фактические значения (тестовый набор)', fontsize=16)
plt.grid(linewidth=0.5)
plt.show()