IIS_2023_1/shadaev_anton_lab_5/main.py
2023-11-04 20:32:30 +04:00

33 lines
1.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
from sklearn.linear_model import Ridge
from sklearn import metrics
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
# Загрузка данных и разделение на обучающий и тестовый наборы
data = pd.read_csv('stroke_prediction_ds.csv')
data = data.dropna()
scaler = MinMaxScaler()
x = data[['age', 'hypertension', 'heart_disease']] # выделение признаков (возраст, гипертензия, сердечные заболевания)
y = scaler.fit_transform(data['avg_glucose_level'].values.reshape(-1, 1)).flatten() # масштабирование данных
split_point = round(len(data) * 0.99) # 99% данных
x_train, x_test = x.iloc[:split_point], x.iloc[split_point:] # x_train читает 99% данных, а x_test - оставшийся 1%
y_train, y_test = y[:split_point], y[split_point:] # y_train читает 99% данных, а y_test - оставшийся 1%
# Обучение модели и прогнозирование, применение алгоритма гребневой регрессии
rid = Ridge(alpha=1.0).fit(x_train.values, y_train)
y_predict = rid.predict(x_test.values)
# Вычисление метрик и построение графика (среднеквадратичная ошибка и коэффициент детерминации)
mid_square = round(metrics.mean_squared_error(y_test, y_predict) ** 0.5, 3)
coeff_determ = round(metrics.r2_score(y_test, y_predict), 2)
# Визуализация данных
plt.plot(y_test, c="red", label="y_test")
plt.plot(y_predict, c="orange", label=f"y_pred\nmean_squared_error (mid_square) = {mid_square}\nCoefficient of determination (coeff_determ) = {coeff_determ}")
plt.legend(loc='upper right')
plt.title("Гребневая регрессия")
plt.show()