Lasso JAX

In this example, we use the PANOC solver to solve a lasso problem, i.e., least squares with \(\ell_1\)-regularization to promote sparsity. Additionally, we impose a nonnegativity constraint (\(x \ge 0\)) on the solution.

The JAX package is used to compute the gradient of the least-squares loss.

True and estimated solution.
  1# %% alpaqa lasso example
  2
  3from pprint import pprint
  4
  5import jax
  6import jax.numpy as jnp
  7import numpy as np
  8from jax import grad, jit
  9
 10import alpaqa as pa
 11
 12jax.config.update("jax_enable_x64", True)
 13
 14scale = 5000
 15n, m = scale, scale * 2
 16sparsity = 0.02
 17
 18# %% Generate some data
 19
 20rng = np.random.default_rng(seed=123)
 21# Random data matrix A
 22A = rng.uniform(-1, 1, (m, n))
 23# Sparse solution x_exact
 24x_exact = rng.uniform(-0.1, 1, n)
 25x_exact[rng.uniform(0, 1, n) > sparsity] = 0
 26# Noisy right-hand side b
 27b = A @ x_exact + rng.normal(0, 0.1, m)
 28
 29# %% Build the problem
 30
 31# Quadratic loss plus l1-regularization
 32# minimize  ½‖Ax - b‖² + λ‖x‖₁
 33
 34λ = 0.0025 * m
 35
 36
 37def loss(x):
 38    err = A @ x - b
 39    return 0.5 * jnp.dot(err, err)
 40
 41
 42class LassoProblem(pa.BoxConstrProblem):
 43    def __init__(self):
 44        super().__init__(n, 0)
 45        self.variable_bounds.lower[:] = 0  # Positive lasso
 46        self.l1_reg = [λ]  # Regularization
 47        self.jit_loss = jit(loss)
 48        self.jit_grad_loss = jit(grad(loss))
 49
 50    def eval_objective(self, x):  # Cost function
 51        return self.jit_loss(x)
 52
 53    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
 54        grad_f[:] = self.jit_grad_loss(x)
 55
 56
 57prob = LassoProblem()
 58
 59# %% Solve the problem using alpaqa's PANOC solver
 60
 61opts = {
 62    "max_iter": 100,
 63    "stop_crit": pa.FPRNorm,
 64    # Use a laxer tolerance because large problems have more numerical errors:
 65    "quadratic_upperbound_tolerance_factor": 1e-12,
 66}
 67direction = pa.StructuredLBFGSDirection({"memory": 5}, {"hessian_vec_factor": 0})
 68# direction = pa.LBFGSDirection({"memory": 5})
 69solver = pa.PANOCSolver({"print_interval": 5} | opts, direction)
 70# Add evaluation counters to the problem
 71cnt = pa.problem_with_counters(prob)
 72# Solve the problem
 73sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
 74
 75# %% Print the results
 76
 77final_f = prob.eval_objective(sol)
 78print()
 79pprint(stats)
 80print()
 81print("Evaluations:")
 82print(cnt.evaluations)
 83print(f"Cost:          {final_f + stats['final_h']}")
 84print(f"Loss:          {final_f}")
 85print(f"Regularizer:   {stats['final_h']}")
 86print(f"FP Residual:   {stats['ε']}")
 87print(f"Run time:      {stats['elapsed_time']}")
 88print(stats["status"])
 89
 90# %% Plot the results
 91
 92import matplotlib.pyplot as plt
 93
 94plt.figure(figsize=(8, 5))
 95plt.plot(x_exact, ".-", label="True solution")
 96plt.plot(sol, ".-", label="Estimated solution")
 97plt.legend()
 98plt.title("PANOC lasso example: solution")
 99plt.tight_layout()
100plt.figure(figsize=(8, 5))
101plt.plot(A @ x_exact, ".-", label="True solution")
102plt.plot(A @ sol, ".-", label="Estimated solution")
103plt.legend()
104plt.title("PANOC lasso example: right-hand side")
105plt.tight_layout()
106plt.show()