
L1-constrained regression using Frank-Wolfe #43

Open · wants to merge 10 commits into master
59 changes: 59 additions & 0 deletions examples/plot_l1_reg.py
@@ -0,0 +1,59 @@

"""
========================================
L1 regression: regularization paths
========================================

Shows that the regularization paths obtained by coordinate descent (penalized)
and Frank-Wolfe (constrained) are equivalent.
"""
print(__doc__)
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

from lightning.regression import CDRegressor
from lightning.regression import FWRegressor

diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target

X_tr, X_te, y_tr, y_te = train_test_split(X, y, train_size=0.75, random_state=0)

plt.figure()

betas = np.logspace(-2, 5, 50)
alphas = np.logspace(-4, 4, 50)

fw_n_nz = []
fw_error = []
cd_n_nz = []
cd_error = []

for beta in betas:
    reg = FWRegressor(beta=beta, max_iter=1000, tol=1e-3, verbose=0)
    reg.fit(X_tr, y_tr)
    y_pred = reg.predict(X_te)
    fw_n_nz.append(np.sum(reg.coef_ != 0))
    fw_error.append(np.sqrt(np.mean((y_te - y_pred) ** 2)))

for alpha in alphas:
    reg = CDRegressor(alpha=alpha, penalty="l1", max_iter=1000, tol=1e-3,
                      verbose=0)
    reg.fit(X_tr, y_tr)
    y_pred = reg.predict(X_te)
    cd_n_nz.append(np.sum(reg.coef_ != 0))
    cd_error.append(np.sqrt(np.mean((y_te - y_pred) ** 2)))

plt.plot(fw_n_nz, fw_error, label="Frank-Wolfe", linewidth=3)
plt.plot(cd_n_nz, cd_error, label="Coordinate Descent", linewidth=3, linestyle="--")

plt.xlabel("Number of non-zero coefficients")
plt.ylabel("RMSE")
plt.xlim((0, X_tr.shape[1]))
#plt.ylim((160, 170))
plt.legend()

plt.show()
114 changes: 114 additions & 0 deletions lightning/impl/fw.py
@@ -0,0 +1,114 @@
import numpy as np
import scipy.sparse as sp

from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.utils.extmath import safe_sparse_dot


def _frank_wolfe(w_init, X, y, beta, max_iter=50, tol=1e-3, max_nz=None,
                 verbose=0):
    """
    Solve

        0.5 * ||np.dot(X, w) - y||^2 s.t. ||w||_1 <= beta

    by the Frank-Wolfe method.

    The method can be seen as a greedy coordinate descent: it adds at most one
    non-zero coefficient per iteration.
    """
    n_samples, n_features = X.shape

    if sp.issparse(X):
        X = X.tocsc()

    w = w_init.copy()

    for it in range(max_iter):
        y_pred = safe_sparse_dot(X, w)
        resid = beta * y_pred - y
        neg_grad = -safe_sparse_dot(X.T, beta * resid)

        atom = np.argmax(np.abs(neg_grad))
        s = np.sign(neg_grad[atom])

        error = np.dot(resid, resid)
        # Frank-Wolfe duality gap <grad, w - s * e_atom>: an upper bound on the
        # current suboptimality, used by the stopping criterion below.
        dgap = s * neg_grad[atom] - np.dot(w, neg_grad)

        if it == 0:
            error_init = error
            dgap_init = dgap

        if verbose:
            print("iter", it + 1)
            print("duality gap", dgap / dgap_init)
            print("error reduction", error / error_init)
            print("l1 norm", beta * np.sum(np.abs(w)))
            print("n_nz", np.sum(w != 0))
            print()

        # Find optimal step size by exact line search.
        Xs = s * X[:, atom]
        if sp.issparse(Xs):
            Xs_sq = np.dot(Xs.data, Xs.data)
        else:
            Xs_sq = np.dot(Xs, Xs)

Contributor: Since the sign doesn't affect this, you could precompute all column
square norms outside of the loop, right? (But I guess it's a tradeoff for
high-dimensional X.)

Member Author: Good idea, I'll do that. An O(n_features) memory cache is not a
big deal.
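
A minimal sketch of that precomputation (assuming X is dense or already
converted to CSC as above); the squared norm is sign-independent, so the loop
could then use col_sq_norms[atom] in place of Xs_sq:

    # Sketch only: squared l2 norms of all columns, computed once before the loop.
    if sp.issparse(X):
        col_sq_norms = np.asarray(X.multiply(X).sum(axis=0)).ravel()
    else:
        col_sq_norms = (X ** 2).sum(axis=0)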

        y_pred_sq = np.dot(y_pred, y_pred)
        b = (Xs - y_pred)

Contributor: b seems unused.

        # Exact line search: minimizing 0.5 * ||beta * ((1 - gamma) * y_pred + gamma * Xs) - y||^2
        # over gamma gives gamma = <resid, y_pred - Xs> / (beta * ||Xs - y_pred||^2), computed below.
        gamma = np.dot(resid, y_pred) - safe_sparse_dot(resid, Xs)
        gamma /= beta * (Xs_sq - 2 * safe_sparse_dot(Xs.T, y_pred) + y_pred_sq)

Contributor: When running the example I get a warning at this line, because
Xs - y_pred is zero. The line below fixes gamma, but I thought it was worth
pointing out.

Member Author: Yep, I get a warning too. I need to check more thoroughly what
the right thing to do is in this case: set gamma=1, set gamma=0, or stop the
algorithm?
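
One possible guard, sketched here only (the choice between gamma=0, gamma=1,
and stopping is left open above); it assumes a vanishing denominator means
Xs == y_pred, so any step length leaves the prediction unchanged and a zero
step is harmless:

        # Sketch only: avoid the divide-by-zero warning when Xs == y_pred.
        numer = np.dot(resid, y_pred) - safe_sparse_dot(resid, Xs)
        denom = beta * (Xs_sq - 2 * safe_sparse_dot(Xs.T, y_pred) + y_pred_sq)
        gamma = numer / denom if denom > np.finfo(float).tiny else 0.0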

        gamma = max(0, min(1, gamma))

        # Update parameters.
        w *= (1 - gamma)
        w[atom] += gamma * s

        # Stop if maximum number of non-zero coefficients is reached.
        if max_nz is not None and np.sum(w != 0) == max_nz:
            break

        # Stop if desired duality gap tolerance is reached.
        if dgap / dgap_init <= tol:
            if verbose:
                print("Converged")
            break

    w *= beta
    return w


class FWRegressor(BaseEstimator, RegressorMixin):

    def __init__(self, beta=1.0, max_iter=50, tol=1e-3, max_nz=None, verbose=0):
        self.beta = beta
        self.max_iter = max_iter
        self.tol = tol
        self.max_nz = max_nz
        self.verbose = verbose

    def fit(self, X, y):
        n_features = X.shape[1]
        coef = np.zeros(n_features)
        self.coef_ = _frank_wolfe(coef, X, y, beta=self.beta,
                                  max_iter=self.max_iter, tol=self.tol,
                                  max_nz=self.max_nz, verbose=self.verbose)
        return self

    def predict(self, X):
        return safe_sparse_dot(X, self.coef_)


if __name__ == '__main__':
    from sklearn.datasets import load_diabetes
    from sklearn.preprocessing import StandardScaler

    diabetes = load_diabetes()
    X, y = diabetes.data, diabetes.target
    X = StandardScaler().fit_transform(X)
    #X = sp.csr_matrix(X)

    reg = FWRegressor(beta=100, max_iter=1000, tol=1e-2, verbose=1)
    reg.fit(X, y)
    y_pred = reg.predict(X)
    error = np.mean((y - y_pred) ** 2)
    print(error)
1 change: 1 addition & 0 deletions lightning/regression.py
@@ -2,6 +2,7 @@
from .impl.dual_cd import LinearSVR
from .impl.primal_cd import CDRegressor
from .impl.fista import FistaRegressor
from .impl.fw import FWRegressor
from .impl.sag import SAGRegressor
from .impl.sag import SAGARegressor
from .impl.sdca import SDCARegressor