from matplotlib import pyplot as plt
from sklearn.linear_model import LinearRegression
from RadomizedLasso import RandomizedLasso
from sklearn.feature_selection import RFE
from sklearn.preprocessing import MinMaxScaler
import numpy as np

names = ["x%s" % i for i in range(1, 15)]


def start_point():
    X,Y = generation_data()
    # Линейная модель
    lr = LinearRegression()
    lr.fit(X, Y)
    # Рекурсивное сокращение признаков
    rfe = RFE(lr)
    rfe.fit(X, Y)
    # Случайное Лассо
    randomized_lasso = RandomizedLasso(alpha=.01)
    randomized_lasso.fit(X, Y)

    ranks = {"Linear Regression": rank_to_dict(lr.coef_), "Recursive Feature Elimination": rank_to_dict(rfe.ranking_),
             "Randomize Lasso": rank_to_dict(randomized_lasso.coef_)}

    get_estimation(ranks)
    print_sorted_data(ranks)


def generation_data():
    np.random.seed(0)
    size = 750
    X = np.random.uniform(0, 1, (size, 14))
    Y = (10 * np.sin(np.pi * X[:, 0] * X[:, 1]) + 20 * (X[:, 2] - .5) ** 2 +
         10 * X[:, 3] + 5 * X[:, 4] ** 5 + np.random.normal(0, 1))
    X[:, 10:] = X[:, :4] + np.random.normal(0, .025, (size, 4))
    return X, Y


def rank_to_dict(ranks):
    ranks = np.abs(ranks)
    minmax = MinMaxScaler()
    ranks = minmax.fit_transform(np.array(ranks).reshape(14, 1)).ravel()
    ranks = map(lambda x: round(x, 2), ranks)
    return dict(zip(names, ranks))


def get_estimation(ranks: {}):
    mean = {}
    #«Бежим» по списку ranks
    for key, value in ranks.items():
        for item in value.items():
            if(item[0] not in mean):
                mean[item[0]] = 0
            mean[item[0]] += item[1]

    for key, value in mean.items():
        res = value/len(ranks)
        mean[key] = round(res, 2)

    mean_sorted = sorted(mean.items(), key=lambda item: item[1], reverse=True)
    print("Средние значения")
    print(mean_sorted)


    print("4 самых важных признака по среднему значению")
    for item in mean_sorted[:4]:
        print('Параметр - {0}, значение - {1}'.format(item[0], item[1]))



def print_sorted_data(ranks: {}):
    print()
    for key, value in ranks.items():
        ranks[key] = sorted(value.items(), key=lambda item: item[1], reverse=True)
    for key, value in ranks.items():
        print(key)
        print(value)


start_point()