Lasso

In this example, we use the PANOC solver to solve a lasso problem, i.e. least squares with \(\ell_1\)-regularization to promote sparsity.

Figure: true solution and the solution estimated by PANOC.
 # %% alpaqa lasso example
#
# Solve a randomly generated lasso problem (least squares with an
# l1-regularization term) using alpaqa's PANOC solver, then compare the
# estimated solution with the ground truth used to generate the data.

from pprint import pprint

import alpaqa as pa
import casadi as cs
import numpy as np

scale = 50
n, m = scale, scale * 2  # n unknowns, m measurements (rows of A)
# Probability that an entry of the true solution is nonzero (i.e. the
# expected density of x_exact); all other entries are zeroed below.
sparsity = 0.2
rng = np.random.default_rng(0)  # fixed seed: reproducible problem data

# %% Build the problem (CasADi code, independent of alpaqa)

# Quadratic loss plus l1-regularization
# minimize  ½‖Ax - b‖² + λ‖x‖₁

# NOTE: the order of the rng calls below determines the generated data;
# keep it unchanged to reproduce the same problem instance.
A = rng.uniform(-1, 1, (m, n))
x_exact = rng.uniform(0, 1, n)
x_exact[rng.uniform(0, 1, n) > sparsity] = 0
b = A @ x_exact + rng.normal(0, 0.1, m)  # noisy measurements
λ = 0.025 * m  # regularization weight, scaled with the number of measurements

# Symbolic decision variable x ∈ ℝⁿ
x = cs.MX.sym("x", n)
# Smooth part of the objective: half the squared norm of Ax - b
# (the l1 term is attached separately below as a regularizer)
f = 0.5 * cs.sumsqr(A @ x - b)

# %% Generate and compile C code for the problem using alpaqa

# Compile and load the problem. It has no constraints: only the smooth
# objective f and the l1-regularizer with weight λ are passed to alpaqa.
problem = (
    pa.minimize(f, x)
    .with_l1_regularizer(λ)
).compile(sym=cs.MX.sym)

# %% Solve the problem using alpaqa's PANOC solver

# L-BFGS quasi-Newton directions with a memory of `scale` vector pairs
direction = pa.LBFGSDirection({"memory": scale})
# Print solver progress every 10 iterations
solver = pa.PANOCSolver({"print_interval": 10}, direction)
# Add evaluation counters to the problem
cnt = pa.problem_with_counters(problem)
# Solve the counted problem (so evaluations are tallied) to a tight tolerance
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})

# %% Print the results

# Evaluate the smooth loss f at the solution; stats["final_h"] holds the
# value of the l1-regularizer, so the total cost is their sum.
final_f = problem.eval_f(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost:          {final_f + stats['final_h']}")
print(f"Loss:          {final_f}")
print(f"Regularizer:   {stats['final_h']}")
print(f"FP Residual:   {stats['ε']}")
print(f"Run time:      {stats['elapsed_time']}")
print(stats["status"])

# %% Plot the results

# matplotlib is only needed for this cell, so it is imported here rather
# than at the top of the script.
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 5))
plt.plot(x_exact, ".-", label="True solution")
plt.plot(sol, ".-", label="Estimated solution")
plt.legend()
plt.title("PANOC lasso example")
plt.tight_layout()
plt.show()