IIS_2023_1/shadaev_anton_lab_3/stroke_prediction.py

42 lines
1.6 KiB
Python
Raw Normal View History

2023-11-03 18:11:03 +04:00
import pandas as pd
from sklearn.metrics import accuracy_score, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
# Загрузка данных
data = pd.read_csv('stroke_prediction_ds.csv')
# Приведение данных к цифровому значению
data['gender'] = data['gender'].map({'Male': 0, 'Female': 1})
data['ever_married'] = data['ever_married'].map({'No': 0, 'Yes': 1})
# Определение признаков
X = data[['hypertension', 'heart_disease', 'ever_married', 'gender']]
# Целевая переменная
y = data['stroke']
# Разделение данных на обучающую и тестовую выборки
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Создание модели
dt_classifier = DecisionTreeClassifier(random_state=42)
# Обучение модели
dt_classifier.fit(X_train, y_train)
# Вычисление важности признаков
feature_importances = dt_classifier.feature_importances_
# Вычисление 'Accuracy' модели
predictions = dt_classifier.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
# Вычисление средней квадратичной ошибки
mse = mean_squared_error(y_test, predictions)
# Вывод результатов
print(X.head(10))
print(f'Важность признаков:" {feature_importances}')
print(f'Accuracy модели: {round(accuracy * 100)}%')
print("Средняя квадратичная ошибка: {:.2f}%".format(mse * 100))