Lasso JAX
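In this example, we use the PANOC solver to solve a lasso problem, i.e., least squares with \(\ell_1\)-regularization to promote sparsity: \(\min_x \tfrac{1}{2}\|Ax - b\|_2^2 + \lambda\|x\|_1\). Additionally, we impose a nonnegativity constraint \(x \ge 0\) on the solution.
The JAX package is used to compute the gradient of the least-squares loss automatically, and both the loss and its gradient are JIT-compiled.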
1# %% alpaqa lasso example
2
3import numpy as np
4import alpaqa as pa
5import jax
6import jax.numpy as jnp
7from jax import grad, jit
8from pprint import pprint
9
# Enable 64-bit ("x64") floating point arrays in JAX.
jax.config.update("jax_enable_x64", True)

# Problem dimensions: n unknowns and m = 2n measurements. Roughly a
# `sparsity` fraction of the true solution's entries will be nonzero.
scale = 5000
n = scale
m = 2 * scale
sparsity = 0.02
15
# %% Generate some data

rng = np.random.default_rng(seed=123)
# Dense m x n data matrix with entries drawn uniformly from [-1, 1)
A = rng.uniform(low=-1, high=1, size=(m, n))
# Ground truth: candidate values from [-0.1, 1); every entry whose companion
# uniform draw exceeds `sparsity` is zeroed, leaving about a `sparsity`
# fraction of nonzeros. NOTE: some surviving entries can be negative, which
# the x >= 0 bound imposed on the problem below cannot reproduce.
x_exact = rng.uniform(low=-0.1, high=1, size=n)
x_exact[~(rng.uniform(low=0, high=1, size=n) <= sparsity)] = 0
# Measurements: the exact products A @ x_exact plus Gaussian noise with
# standard deviation 0.1
b = A @ x_exact + rng.normal(loc=0, scale=0.1, size=m)
26
# %% Build the problem

# Objective: smooth least-squares loss plus an l1 penalty,
#     minimize ½‖Ax - b‖² + λ‖x‖₁
# where the penalty weight λ is scaled with the number of measurements m.

λ = m * 0.0025
33
34
def _least_squares(x, mat, rhs):
    """Return the least-squares loss ½‖mat @ x - rhs‖².

    The data are explicit arguments so that a jitted version receives them
    as ordinary (traced) inputs instead of closing over them.
    """
    err = mat @ x - rhs
    return 0.5 * jnp.dot(err, err)


def loss(x):
    """Return the lasso loss ½‖Ax - b‖² for the module-level data A and b."""
    return _least_squares(x, A, b)


class LassoProblem(pa.BoxConstrProblem):
    """Nonnegative lasso for alpaqa: minimize ½‖Ax - b‖² + λ‖x‖₁ s.t. x ≥ 0.

    The smooth least-squares loss and its gradient are evaluated with
    jit-compiled JAX functions. The l1 weight and the lower bound of zero
    are set on the alpaqa problem object in ``__init__``.
    """

    def __init__(self):
        # n decision variables; the second argument (presumably the number
        # of general constraints) is 0 -- confirm against alpaqa's API.
        super().__init__(n, 0)
        self.variable_bounds.lower[:] = 0  # Positive lasso: x >= 0
        self.l1_reg = [λ]  # Weight of the l1 regularization term
        # Convert the data to JAX arrays once and hand them to the compiled
        # functions as arguments. Arrays a jitted function closes over are
        # captured as trace-time constants, so the large m x n matrix could
        # otherwise end up embedded in each compiled function (loss AND
        # gradient). Passing it as an argument shares one device copy.
        A_dev = jnp.asarray(A)
        b_dev = jnp.asarray(b)
        loss_jit = jit(_least_squares)
        # grad differentiates with respect to the first argument (x) only.
        grad_jit = jit(grad(_least_squares))
        # Keep the original one-argument call signatures of these attributes.
        self.jit_loss = lambda x: loss_jit(x, A_dev, b_dev)
        self.jit_grad_loss = lambda x: grad_jit(x, A_dev, b_dev)

    def eval_objective(self, x):  # Cost function
        """Return the smooth part of the objective, ½‖Ax - b‖², at x."""
        return self.jit_loss(x)

    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
        """Write the gradient of ½‖Ax - b‖² at x into grad_f in place."""
        grad_f[:] = self.jit_grad_loss(x)
53
54
prob = LassoProblem()

# %% Solve the problem using alpaqa's PANOC solver

opts = {
    "max_iter": 100,
    # Stopping criterion based on the fixed-point residual (FPR) norm.
    "stop_crit": pa.FPRNorm,
    # Looser tolerance for the quadratic upper bound check than the solver's
    # default (presumably a small multiple of machine epsilon), because
    # large problems accumulate more rounding error:
    "quadratic_upperbound_tolerance_factor": 1e-12,
}
# Search direction: structured L-BFGS with memory 5.
# The unstructured alternative would be pa.LBFGSDirection({"memory": 5}).
direction = pa.StructuredLBFGSDirection(
    {"memory": 5},
    {"hessian_vec_factor": 0},
)
# Solver options: a print interval of 5 followed by the settings above.
solver = pa.PANOCSolver({"print_interval": 5, **opts}, direction)
# Wrap the problem so that its function evaluations are counted.
cnt = pa.problem_with_counters(prob)
# Solve the counted problem (tolerance 1e-10).
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
72
# %% Print the results

# Smooth part of the objective at the returned solution. This calls `prob`
# directly instead of `cnt.problem`, presumably so that this extra
# evaluation is not included in the counters printed below.
final_f = prob.eval_objective(sol)
final_h = stats["final_h"]  # regularizer value reported in the solver stats

print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
summary = (
    f"Cost: {final_f + final_h}",
    f"Loss: {final_f}",
    f"Regularizer: {final_h}",
    f"FP Residual: {stats['ε']}",
    f"Run time: {stats['elapsed_time']}",
)
print(*summary, sep="\n")
print(stats["status"])
87
# %% Plot the results

import matplotlib.pyplot as plt

# Figure 1: estimated coefficients compared with the ground truth.
fig_x, ax_x = plt.subplots(figsize=(8, 5))
ax_x.plot(x_exact, ".-", label="True solution")
ax_x.plot(sol, ".-", label="Estimated solution")
ax_x.legend()
ax_x.set_title("PANOC lasso example: solution")
fig_x.tight_layout()

# Figure 2: A @ x for the true and the estimated solution (noise-free
# right-hand sides; the noisy measurements b are not shown).
fig_b, ax_b = plt.subplots(figsize=(8, 5))
ax_b.plot(A @ x_exact, ".-", label="True solution")
ax_b.plot(A @ sol, ".-", label="Estimated solution")
ax_b.legend()
ax_b.set_title("PANOC lasso example: right-hand side")
fig_b.tight_layout()

plt.show()