2023-09-30 20:26:46 +04:00

50 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from sklearn.datasets import make_regression
from sklearn.feature_selection import RFECV, f_regression
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
# генерируем исходные данные: 100 строк-наблюдений и 10 столбцов-признаков
X, y = make_regression(n_samples=100, n_features=10, random_state=42)
# линейная модель
linear_reg = LinearRegression()
linear_reg.fit(X, y)
linear_ranking_lr = np.abs(linear_reg.coef_)
# cокращение признаков cлучайными деревьями (Random Forest Regressor)
rf_reg = RandomForestRegressor()
rfecv = RFECV(estimator=rf_reg)
rfecv.fit(X, y)
rfecv_ranking = rfecv.ranking_
# линейная корреляция (f_regression)
f_reg, _ = f_regression(X, y)
linear_corr_ranking = f_reg
# ранжирование признаков и вычисление средней оценки
all_rankings = np.vstack((linear_ranking_lr, rfecv_ranking, linear_corr_ranking))
average_ranking = np.mean(all_rankings, axis=0)
# средние показатели четырех наиболее важных характеристик
most_important_indices = np.argsort(average_ranking)[-4:]
# результаты
print("ранги линейной модели:")
print(linear_ranking_lr)
print("")
print("ранги после сокращения признаков Random Forest:")
print(rfecv_ranking)
print("")
print("ранги линейнейной корреляции (f_regression):")
print(linear_corr_ranking)
print("")
print("ранги по средней оценке:")
print(average_ranking)
print("")
print("4 выделенных главных признака:")
print(most_important_indices)