45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
import numbers

from pandas._libs import sparse
from scipy.sparse import dia_matrix
from scipy.sparse import issparse
from sklearn.linear_model import Lasso
from sklearn.utils import check_X_y, check_random_state
|
|
|
|
|
|
def _rescale_data(x, weights):
    """Scale every column of ``x`` by ``1 - weights``.

    Parameters
    ----------
    x : ndarray or scipy sparse matrix, shape (n_samples, n_features)
        Design matrix whose columns are rescaled.
    weights : ndarray, shape (n_features,)
        Per-column weights; column ``j`` is multiplied by ``1 - weights[j]``.

    Returns
    -------
    x_rescaled : same kind of container as ``x``
        A new object holding the rescaled data; ``x`` itself is not modified.
    """
    scale = 1 - weights

    if not issparse(x):
        # Broadcasting along the last axis multiplies column j by scale[j].
        return x * scale

    size = weights.shape[0]
    # Right-multiplying by a diagonal matrix scales the columns while keeping
    # the result sparse.  ``dia_matrix`` comes from SciPy (the pandas-internal
    # ``sparse`` module offers no such constructor), and ``.dot`` is used
    # because ``*`` is element-wise for SciPy sparse *arrays*.
    weight_dia = dia_matrix((scale, 0), shape=(size, size))
    return x.dot(weight_dia)
|
|
|
|
|
|
class RandomizedLasso(Lasso):
    """Lasso fitted on a randomly perturbed copy of the design matrix.

    On every call to ``fit`` each feature column is, with equal probability,
    either left untouched or multiplied by ``weakness``; an ordinary Lasso is
    then fitted on the rescaled data.

    Parameters
    ----------
    weakness : real number in (0, 1], default=0.5
        Factor applied to the randomly selected columns.  ``1`` leaves the
        data unchanged.
    copy_x : bool, default=True
        Forwarded to ``Lasso`` as its ``copy_X`` argument.
    random_state : int, RandomState instance or None, default=None
        Passed through ``check_random_state`` to draw the column selection,
        and also forwarded to ``Lasso``.

    All remaining parameters are forwarded unchanged to ``Lasso.__init__``.
    """

    def __init__(self, weakness=0.5, alpha=1.0, fit_intercept=True, normalize=False,
                 precompute=False, copy_x=True, max_iter=1000,
                 tol=1e-4, warm_start=False, positive=False,
                 random_state=None, selection='cyclic'):
        self.weakness = weakness
        # The parent stores ``copy_X``; the property below redirects that
        # assignment so the value also lives under ``copy_x``.
        super(RandomizedLasso, self).__init__(
            alpha=alpha, fit_intercept=fit_intercept,
            normalize=normalize, precompute=precompute, copy_X=copy_x,
            max_iter=max_iter, tol=tol, warm_start=warm_start,
            positive=positive, random_state=random_state,
            selection=selection)

    @property
    def copy_X(self):
        """Alias of ``copy_x`` under the attribute name the parent class reads."""
        return self.copy_x

    @copy_X.setter
    def copy_X(self, value):
        # scikit-learn's ``get_params``/``clone`` look up one attribute per
        # ``__init__`` parameter, i.e. ``copy_x``; keeping a single backing
        # attribute means ``set_params(copy_x=...)`` is honoured by ``fit``.
        self.copy_x = value

    def fit(self, x, y):
        """Fit a Lasso on ``x`` after randomly down-weighting its columns.

        Parameters
        ----------
        x : array-like or sparse matrix, shape (n_samples, n_features)
            Training data; validated with ``check_X_y(accept_sparse=True)``.
        y : array-like
            Target values, validated together with ``x``.

        Returns
        -------
        Whatever ``Lasso.fit`` returns for the rescaled data.

        Raises
        ------
        ValueError
            If ``weakness`` is not a real number in ``(0, 1]``.
        """
        weakness = self.weakness
        # Accept any real scalar (e.g. the integer 1 or a NumPy float) but not
        # booleans, which are technically integers.
        is_real = (isinstance(weakness, numbers.Real)
                   and not isinstance(weakness, bool))
        if not is_real or not (0.0 < weakness <= 1.0):
            # Wrap in a tuple so tuple-valued input is formatted, not unpacked.
            raise ValueError('weakness should be a float in (0, 1], got %s' % (weakness,))

        x, y = check_X_y(x, y, accept_sparse=True)

        n_features = x.shape[1]
        rng = check_random_state(self.random_state)

        # Each column draws 0 or 1.  A drawn 1 yields weight ``1 - weakness``,
        # which ``_rescale_data`` turns into a column factor of ``weakness``;
        # a drawn 0 leaves the column as is.
        weights = (1. - weakness) * rng.randint(0, 1 + 1, size=(n_features,))

        x_rescaled = _rescale_data(x, weights)
        return super(RandomizedLasso, self).fit(x_rescaled, y)
|