Lasso JAX
In this example, we use the PANOC solver to solve a lasso problem, i.e., least squares with \(\ell_1\)-regularization to promote sparsity. Additionally, we impose a nonnegativity constraint \(x \ge 0\) on the solution.
JAX is used to compute the gradient of the least-squares loss; both the loss and its gradient are JIT-compiled.
# %% alpaqa lasso example

from pprint import pprint

import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit

import alpaqa as pa

# Enable 64-bit floating point in JAX (it uses 32-bit floats by default).
jax.config.update("jax_enable_x64", True)

# Problem dimensions: n unknowns and m = 2n measurements.
scale = 5000
n, m = scale, scale * 2
# Probability that an entry of the ground-truth vector is kept nonzero.
sparsity = 0.02

# %% Generate some data

rng = np.random.default_rng(seed=123)
# Dense random data matrix A with entries drawn uniformly from [-1, 1).
A = rng.uniform(-1, 1, (m, n))
# Sparse ground truth: draw every entry, then zero out each one whose
# companion uniform sample exceeds `sparsity`.
# NOTE(review): entries are drawn from [-0.1, 1), so a few surviving entries
# may be negative, which the nonnegativity-constrained estimate cannot
# reproduce; presumably intentional — confirm.
x_exact = rng.uniform(-0.1, 1, n)
x_exact = np.where(rng.uniform(0, 1, n) > sparsity, 0.0, x_exact)
# Measured right-hand side: exact product plus Gaussian noise (std 0.1).
b = A @ x_exact + rng.normal(0, 0.1, m)

# %% Build the problem

# Objective: quadratic loss plus l1-regularization (with x >= 0, see below)
#   minimize ½‖Ax - b‖² + λ‖x‖₁
# The regularization weight is proportional to the number of measurements.
λ = 0.0025 * m


def loss(x):
    """Return the least-squares loss ½‖Ax - b‖² for the module-level data.

    Uses ``jax.numpy`` so the function can be traced by ``jax.jit`` and
    differentiated by ``jax.grad``. ``A`` and ``b`` are read from module
    globals.
    """
    residual = A @ x - b
    return 0.5 * jnp.dot(residual, residual)


class LassoProblem(pa.BoxConstrProblem):
    """Nonnegative lasso problem exposed to alpaqa.

    The smooth part of the objective is ``loss`` (the least-squares term),
    evaluated and differentiated through JIT-compiled JAX functions. The
    l1 weight ``λ`` and the lower bound of zero on every variable are set
    through attributes provided by ``pa.BoxConstrProblem``.
    """

    def __init__(self):
        # n variables; the second argument (0) is presumably the number of
        # general constraints — confirm against the alpaqa documentation.
        super().__init__(n, 0)
        # Nonnegativity ("positive lasso"): lower bound 0 on every variable.
        self.variable_bounds.lower[:] = 0
        # Weight of the l1-regularization term.
        self.l1_reg = [λ]
        # JIT-wrapped loss and its gradient with respect to x.
        self.jit_loss = jit(loss)
        self.jit_grad_loss = jit(grad(loss))

    def eval_objective(self, x):
        """Return the smooth loss evaluated at ``x``."""
        value = self.jit_loss(x)
        return value

    def eval_objective_gradient(self, x, grad_f):
        """Store the gradient of the smooth loss at ``x`` into ``grad_f`` in place."""
        gradient = self.jit_grad_loss(x)
        grad_f[:] = gradient


prob = LassoProblem()

# %% Solve the problem using alpaqa's PANOC solver

opts = {
    "max_iter": 100,
    "stop_crit": pa.FPRNorm,
    # Relax this tolerance: large problems accumulate more rounding error.
    "quadratic_upperbound_tolerance_factor": 1e-12,
}
# Structured L-BFGS direction with memory 5 and `hessian_vec_factor` set to 0.
# Alternative: direction = pa.LBFGSDirection({"memory": 5})
direction = pa.StructuredLBFGSDirection({"memory": 5}, {"hessian_vec_factor": 0})
solver = pa.PANOCSolver({"print_interval": 5, **opts}, direction)
# Wrap the problem with evaluation counters (reported below).
cnt = pa.problem_with_counters(prob)
# Run PANOC on the wrapped problem with a stopping tolerance of 1e-10.
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})

# %% Print the results

# Evaluated on `prob` rather than `cnt.problem`, presumably so that this
# extra call does not show up in the evaluation counts.
final_f = prob.eval_objective(sol)
print()
pprint(stats)
print()
print("Evaluations:", cnt.evaluations, sep="\n")
print(f"Cost: {final_f + stats['final_h']}")
print(f"Loss: {final_f}")
print(f"Regularizer: {stats['final_h']}")
print(f"FP Residual: {stats['ε']}")
print(f"Run time: {stats['elapsed_time']}")
print(stats["status"])

# %% Plot the results

import matplotlib.pyplot as plt


def _plot_comparison(true_values, estimated_values, title):
    """Draw the true and estimated series on a new 8x5 figure."""
    plt.figure(figsize=(8, 5))
    plt.plot(true_values, ".-", label="True solution")
    plt.plot(estimated_values, ".-", label="Estimated solution")
    plt.legend()
    plt.title(title)
    plt.tight_layout()


_plot_comparison(x_exact, sol, "PANOC lasso example: solution")
_plot_comparison(A @ x_exact, A @ sol, "PANOC lasso example: right-hand side")
plt.show()