from sklearn.utils import check_X_y, check_random_state
from sklearn.linear_model import Lasso
from scipy.sparse import issparse
from scipy import sparse


def _rescale_data(x, weights):
    """Scale every column ``j`` of ``x`` by ``1 - weights[j]``.

    Parameters
    ----------
    x : ndarray or scipy sparse container, shape = [n_samples, n_features]
        Data whose columns are rescaled. The input is not modified; a new
        object is returned.

    weights : ndarray, shape = [n_features]
        Per-column weights. Column ``j`` is multiplied by ``1 - weights[j]``.

    Returns
    -------
    x_rescaled : same kind of container as ``x``
        The column-rescaled data.
    """
    if issparse(x):
        size = weights.shape[0]
        weight_dia = sparse.dia_matrix((1 - weights, 0), (size, size))
        # Right-multiplying by a diagonal matrix scales the columns. Use
        # ``dot`` rather than ``*``: ``*`` is a matrix product only for the
        # legacy sparse *matrix* classes, while for sparse *arrays*
        # (``csr_array`` and friends) it is element-wise and would wipe out
        # every off-diagonal entry of ``x``.
        x_rescaled = x.dot(weight_dia)
    else:
        # Dense case: broadcasting applies one factor per column.
        x_rescaled = x * (1 - weights)

    return x_rescaled


class RandomizedLasso(Lasso):
    """
    Lasso estimator with randomly perturbed per-feature penalties.

    This is a subclass of scikit-learn's Lasso implementing the randomized
    LASSO, a generalization of the LASSO. The plain LASSO penalises the
    absolute value of the coefficients with a penalty term proportional to
    `alpha`; the randomized LASSO instead changes the penalty to a randomly
    chosen value in the range `[alpha, alpha/weakness]`.

    The randomization is applied at fit time by rescaling the columns of
    `X`: each column is either left as it is or multiplied by `weakness`,
    according to a random 0/1 draw from `random_state`. The parent Lasso is
    then fitted on the rescaled data, so the fitted coefficients are
    expressed relative to the rescaled columns.

    Parameters
    ----------
    weakness : float
        Weakness value for randomized LASSO. Must be in (0, 1].

    alpha, fit_intercept, precompute, copy_X, max_iter, tol, warm_start,
    positive, random_state, selection :
        Forwarded unchanged to the constructor of
        sklearn.linear_model.Lasso. `random_state` is additionally used to
        draw the per-feature rescaling in `fit`.

    See also
    --------
    sklearn.linear_model.Lasso : the estimator this class extends.
    """
    def __init__(self, weakness=0.5, alpha=1.0, fit_intercept=True,
                 precompute=False, copy_X=True, max_iter=1000,
                 tol=1e-4, warm_start=False, positive=False,
                 random_state=None, selection='cyclic'):
        super(RandomizedLasso, self).__init__(
            alpha=alpha,
            fit_intercept=fit_intercept,
            precompute=precompute,
            copy_X=copy_X,
            max_iter=max_iter,
            tol=tol,
            warm_start=warm_start,
            positive=positive,
            random_state=random_state,
            selection=selection,
        )
        self.weakness = weakness

    def fit(self, X, y):
        """Fit a Lasso on a randomly column-rescaled copy of the training data.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = [n_samples, n_features]
            The training input samples.

        y : array-like, shape = [n_samples]
            The target values.

        Returns
        -------
        Whatever the parent class' ``fit`` returns.

        Raises
        ------
        ValueError
            If `weakness` is not a float in (0, 1].
        """
        weakness_is_valid = (isinstance(self.weakness, float)
                             and 0.0 < self.weakness <= 1.0)
        if not weakness_is_valid:
            raise ValueError('weakness should be a float in (0, 1], got %s' % self.weakness)

        X, y = check_X_y(X, y, accept_sparse=True)

        # One 0/1 draw per feature: a 0 keeps the column unchanged, a 1
        # gives it weight (1 - weakness), i.e. the column ends up multiplied
        # by `weakness` once _rescale_data applies the factor (1 - weight).
        rng = check_random_state(self.random_state)
        draws = rng.randint(0, 2, size=(X.shape[1],))
        weights = (1. - self.weakness) * draws

        # TODO: running scikit-learn's `_preprocess_data` on X and y at this
        # point was considered but left out, out of concern that it would
        # normalize the data twice.
        # TODO: Check if it is a problem that the rescaling happens before
        # standardization.
        return super(RandomizedLasso, self).fit(_rescale_data(X, weights), y)