Lasso JAX

In this example, we use the PANOC solver to solve a lasso problem, i.e., least squares with \(\ell_1\)-regularization to promote sparsity. Additionally, we impose a nonnegativity constraint (\(x \ge 0\)) on the solution.

The JAX package is used to compute the gradient of the least-squares loss and to JIT-compile both the loss and its gradient.

True and estimated solution.
  1# %% alpaqa lasso example
  2
  3import alpaqa as pa
  4from jax.config import config
  5import jax.numpy as jnp
  6from jax import grad, jit
  7from jax import random
  8from pprint import pprint
  9
 10config.update("jax_enable_x64", True)
 11
 12scale = 5000
 13n, m = scale, scale * 2
 14sparsity = 0.02
 15key = random.PRNGKey(0)
 16
 17# %% Generate some data
 18
 19key, *subkeys = random.split(key, 5)
 20# Random data matrix A
 21A = random.uniform(subkeys[0], (m, n), minval=-1, maxval=1)
 22# Sparse solution x_exact
 23x_exact = random.uniform(subkeys[1], (n,), minval=-0.1, maxval=1)
 24x_exact_zeros = random.uniform(subkeys[2], (n,), minval=0, maxval=1) > sparsity
 25x_exact = x_exact.at[x_exact_zeros].set(0)
 26# Noisy right-hand side b
 27b = A @ x_exact + 0.1 * random.normal(subkeys[3], (m,))
 28
 29# %% Build the problem
 30
 31# Quadratic loss plus l1-regularization
 32# minimize  ½‖Ax - b‖² + λ‖x‖₁
 33
 34λ = 0.0025 * m
 35
 36
 37def loss(x):
 38    err = A @ x - b
 39    return 0.5 * jnp.dot(err, err)
 40
 41
 42class LassoProblem(pa.BoxConstrProblem):
 43    def __init__(self):
 44        super().__init__(n, 0)
 45        self.C.lowerbound[:] = 0  # Positive lasso
 46        self.l1_reg = [λ]  # Regularization
 47        self.jit_loss = jit(loss)
 48        self.jit_grad_loss = jit(grad(loss))
 49
 50    def eval_f(self, x):  # Cost function
 51        return self.jit_loss(x)
 52
 53    def eval_grad_f(self, x, grad_f):  # Gradient of the cost
 54        grad_f[:] = self.jit_grad_loss(x)
 55
 56
 57prob = LassoProblem()
 58
 59# %% Solve the problem using alpaqa's PANOC solver
 60
 61opts = {
 62    "max_iter": 100,
 63    "stop_crit": pa.FPRNorm,
 64    # Use a laxer tolerance because large problems have more numerical errors:
 65    "quadratic_upperbound_tolerance_factor": 1e-12,
 66}
 67direction = pa.StructuredLBFGSDirection({"memory": 5}, {"hessian_vec": False})
 68# direction = pa.LBFGSDirection({"memory": 5})
 69solver = pa.PANOCSolver({"print_interval": 5} | opts, direction)
 70# Add evaluation counters to the problem
 71cnt = pa.problem_with_counters(prob)
 72# Solve the problem
 73sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
 74
 75# %% Print the results
 76
 77final_f = prob.eval_f(sol)
 78print()
 79pprint(stats)
 80print()
 81print("Evaluations:")
 82print(cnt.evaluations)
 83print(f"Cost:          {final_f + stats['final_h']}")
 84print(f"Loss:          {final_f}")
 85print(f"Regularizer:   {stats['final_h']}")
 86print(f"FP Residual:   {stats['ε']}")
 87print(f"Run time:      {stats['elapsed_time']}")
 88print(stats["status"])
 89
 90# %% Plot the results
 91
 92import matplotlib.pyplot as plt
 93
 94plt.figure(figsize=(8, 5))
 95plt.plot(x_exact, ".-", label="True solution")
 96plt.plot(sol, ".-", label="Estimated solution")
 97plt.legend()
 98plt.title("PANOC lasso example: solution")
 99plt.tight_layout()
100plt.figure(figsize=(8, 5))
101plt.plot(A @ x_exact, ".-", label="True solution")
102plt.plot(A @ sol, ".-", label="Estimated solution")
103plt.legend()
104plt.title("PANOC lasso example: right-hand side")
105plt.tight_layout()
106plt.show()