Lasso JAX
In this example, we use the PANOC solver to solve a lasso problem, i.e., least squares with \(\ell_1\)-regularization to promote sparsity. Additionally, we impose a non-negativity constraint (\(x \ge 0\)) on the solution.
The JAX package is used to compute the gradient of the least-squares loss, and to JIT-compile both the loss and its gradient.
  1# %% alpaqa lasso example
  2
  3import numpy as np
  4import alpaqa as pa
  5import jax
  6import jax.numpy as jnp
  7from jax import grad, jit
  8from pprint import pprint
  9
 10jax.config.update("jax_enable_x64", True)
 11
 12scale = 5000
 13n, m = scale, scale * 2
 14sparsity = 0.02
 15
 16# %% Generate some data
 17
 18rng = np.random.default_rng(seed=123)
 19# Random data matrix A
 20A = rng.uniform(-1, 1, (m, n))
 21# Sparse solution x_exact
 22x_exact = rng.uniform(-0.1, 1, n)
 23x_exact[rng.uniform(0, 1, n) > sparsity] = 0
 24# Noisy right-hand side b
 25b = A @ x_exact + rng.normal(0, 0.1, m)
 26
 27# %% Build the problem
 28
 29# Quadratic loss plus l1-regularization
 30# minimize  ½‖Ax - b‖² + λ‖x‖₁
 31
 32λ = 0.0025 * m
 33
 34
 35def loss(x):
 36    err = A @ x - b
 37    return 0.5 * jnp.dot(err, err)
 38
 39
 40class LassoProblem(pa.BoxConstrProblem):
 41    def __init__(self):
 42        super().__init__(n, 0)
 43        self.variable_bounds.lower[:] = 0  # Positive lasso
 44        self.l1_reg = [λ]  # Regularization
 45        self.jit_loss = jit(loss)
 46        self.jit_grad_loss = jit(grad(loss))
 47
 48    def eval_objective(self, x):  # Cost function
 49        return self.jit_loss(x)
 50
 51    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
 52        grad_f[:] = self.jit_grad_loss(x)
 53
 54
 55prob = LassoProblem()
 56
 57# %% Solve the problem using alpaqa's PANOC solver
 58
 59opts = {
 60    "max_iter": 100,
 61    "stop_crit": pa.FPRNorm,
 62    # Use a laxer tolerance because large problems have more numerical errors:
 63    "quadratic_upperbound_tolerance_factor": 1e-12,
 64}
 65direction = pa.StructuredLBFGSDirection({"memory": 5}, {"hessian_vec_factor": 0})
 66# direction = pa.LBFGSDirection({"memory": 5})
 67solver = pa.PANOCSolver({"print_interval": 5} | opts, direction)
 68# Add evaluation counters to the problem
 69cnt = pa.problem_with_counters(prob)
 70# Solve the problem
 71sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
 72
 73# %% Print the results
 74
 75final_f = prob.eval_objective(sol)
 76print()
 77pprint(stats)
 78print()
 79print("Evaluations:")
 80print(cnt.evaluations)
 81print(f"Cost:          {final_f + stats['final_h']}")
 82print(f"Loss:          {final_f}")
 83print(f"Regularizer:   {stats['final_h']}")
 84print(f"FP Residual:   {stats['ε']}")
 85print(f"Run time:      {stats['elapsed_time']}")
 86print(stats["status"])
 87
 88# %% Plot the results
 89
 90import matplotlib.pyplot as plt
 91
 92plt.figure(figsize=(8, 5))
 93plt.plot(x_exact, ".-", label="True solution")
 94plt.plot(sol, ".-", label="Estimated solution")
 95plt.legend()
 96plt.title("PANOC lasso example: solution")
 97plt.tight_layout()
 98plt.figure(figsize=(8, 5))
 99plt.plot(A @ x_exact, ".-", label="True solution")
100plt.plot(A @ sol, ".-", label="Estimated solution")
101plt.legend()
102plt.title("PANOC lasso example: right-hand side")
103plt.tight_layout()
104plt.show()