Lasso JAX
In this example, we use the PANOC solver to solve a lasso problem, i.e. least squares with \(\ell_1\)-regularization to promote sparsity. Additionally, we impose a non-negativity constraint (a lower bound of zero) on the solution.
The JAX package is used to compute gradients.
# %% alpaqa lasso example

import alpaqa as pa
import jax
import jax.numpy as jnp
from jax import grad, jit
from jax import random
from pprint import pprint

# Enable 64-bit floats in JAX (it defaults to 32-bit). The option is set
# through `jax.config` directly: the separate `jax.config` submodule that used
# to be imported here has been removed from recent JAX releases.
jax.config.update("jax_enable_x64", True)

scale = 5000
n, m = scale, scale * 2  # A has m rows and n columns; x has n entries
sparsity = 0.02  # Approximate fraction of entries of x_exact left nonzero
key = random.PRNGKey(0)

# %% Generate some data

# One independent subkey per random draw below
key, *subkeys = random.split(key, 5)
# Random data matrix A
A = random.uniform(subkeys[0], (m, n), minval=-1, maxval=1)
# Sparse solution x_exact: draw dense values first, then zero every entry
# whose uniform draw in [0, 1) exceeds `sparsity`
x_exact = random.uniform(subkeys[1], (n,), minval=-0.1, maxval=1)
x_exact_zeros = random.uniform(subkeys[2], (n,), minval=0, maxval=1) > sparsity
x_exact = x_exact.at[x_exact_zeros].set(0)
# Noisy right-hand side b
b = A @ x_exact + 0.1 * random.normal(subkeys[3], (m,))

# %% Build the problem

# Quadratic loss plus l1-regularization
# minimize ½‖Ax - b‖² + λ‖x‖₁

# Regularization weight, proportional to the number of rows m of A
λ = 0.0025 * m


def loss(x):
    """Return the least-squares loss ½‖Ax - b‖² for the module-level A and b."""
    err = A @ x - b
    return 0.5 * jnp.dot(err, err)


class LassoProblem(pa.BoxConstrProblem):
    """Lasso problem with every variable bounded below by zero.

    The smooth part of the cost is the module-level ``loss``; its value and
    gradient are evaluated through JIT-compiled JAX functions. The
    l1-regularization weight is taken from the module-level ``λ``.
    """

    def __init__(self):
        # n variables; the second argument is 0
        # (presumably the number of general constraints -- confirm against
        # the alpaqa BoxConstrProblem documentation)
        super().__init__(n, 0)
        self.C.lowerbound[:] = 0  # Positive lasso
        self.l1_reg = [λ]  # Regularization
        # Compile the loss and its gradient once, then reuse on every call
        self.jit_loss = jit(loss)
        self.jit_grad_loss = jit(grad(loss))

    def eval_f(self, x):  # Cost function
        """Return the smooth loss evaluated at ``x``."""
        return self.jit_loss(x)

    def eval_grad_f(self, x, grad_f):  # Gradient of the cost
        """Write the gradient of the smooth loss at ``x`` into ``grad_f``."""
        # Assign in place: the result is returned through the output argument
        grad_f[:] = self.jit_grad_loss(x)


prob = LassoProblem()

# %% Solve the problem using alpaqa's PANOC solver

opts = {
    "max_iter": 100,
    "stop_crit": pa.FPRNorm,
    # Use a laxer tolerance because large problems have more numerical errors:
    "quadratic_upperbound_tolerance_factor": 1e-12,
}
direction = pa.StructuredLBFGSDirection({"memory": 5}, {"hessian_vec": False})
# Alternative direction:
# direction = pa.LBFGSDirection({"memory": 5})
solver = pa.PANOCSolver({"print_interval": 5} | opts, direction)
# Add evaluation counters to the problem
cnt = pa.problem_with_counters(prob)
# Solve the problem
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})

# %% Print the results

# Smooth loss at the solution; the regularizer value is read from the stats
final_f = prob.eval_f(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost:          {final_f + stats['final_h']}")
print(f"Loss:          {final_f}")
print(f"Regularizer:   {stats['final_h']}")
print(f"FP Residual:   {stats['ε']}")
print(f"Run time:      {stats['elapsed_time']}")
print(stats["status"])

# %% Plot the results

import matplotlib.pyplot as plt

# Figure 1: entries of the true and the estimated solution vectors
plt.figure(figsize=(8, 5))
plt.plot(x_exact, ".-", label="True solution")
plt.plot(sol, ".-", label="Estimated solution")
plt.legend()
plt.title("PANOC lasso example: solution")
plt.tight_layout()
# Figure 2: A applied to the true and to the estimated solution
plt.figure(figsize=(8, 5))
plt.plot(A @ x_exact, ".-", label="True solution")
plt.plot(A @ sol, ".-", label="Estimated solution")
plt.legend()
plt.title("PANOC lasso example: right-hand side")
plt.tight_layout()
plt.show()