Nuclear norm
In this example, we use the PANOC solver to solve a least-squares problem with a nuclear norm regularizer, which promotes low-rank solutions.
\[\begin{aligned}
& \minimize_X && \tfrac12 \normsq{AX - B}_F + \lambda \norm{X}_*
\end{aligned}\]
Here, \(X \in \R^{m\times n}\) is a matrix whose rank \(r\) is much smaller than its dimensions.
The JAX package is used to compute gradients (although we could easily do without it, since the loss is quadratic).
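Recall that the nuclear norm \(\norm{X}_*\) is the sum of the singular values of \(X\). Its proximal operator soft-thresholds the singular values, which is what drives the iterates towards low-rank solutions: if \(V = U \Sigma W^\top\) is a singular value decomposition, then
\[\begin{aligned}
\operatorname{prox}_{\gamma\lambda\norm{\cdot}_*}(V) &= U \max(\Sigma - \gamma\lambda I,\, 0)\, W^\top.
\end{aligned}\]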
# %% alpaqa nuclear norm example

import alpaqa as pa
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import grad, jit
from jax import random
from pprint import pprint

# Enable double precision (needed for the tight solver tolerance used below)
jax.config.update("jax_enable_x64", True)

# %% Define the problem functions

# Quadratic loss plus nuclear norm regularization
#
#     minimize  ½‖vec(AX - B)‖² + λ nucl(X)


# Returns the loss function with constant A and B
def loss(A, B):
    def _loss(X):
        err = (A @ X - B).ravel()
        return 0.5 * jnp.dot(err, err)

    return _loss


class MyProblem(pa.UnconstrProblem):
    def __init__(self, A, B, λ):
        self.rows, self.cols = A.shape[1], B.shape[1]
        super().__init__(self.rows * self.cols)
        f = loss(A, B)
        self.jit_loss = jit(f)
        self.jit_grad_loss = jit(grad(f))
        self.reg = pa.functions.NuclearNorm(λ, self.rows, self.cols)

    def eval_objective(self, x):  # Cost function
        # Important: use a consistent order when reshaping or raveling!
        X = jnp.reshape(x, (self.rows, self.cols), order="F")
        return self.jit_loss(X)

    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
        X = jnp.reshape(x, (self.rows, self.cols), order="F")
        grad_f[:] = self.jit_grad_loss(X).ravel(order="F")

    def eval_proximal_gradient_step(self, γ, x, grad, x_hat, p):
        # Use the prox_step helper function to carry out a generalized
        # forward-backward step: x̂ = prox_{γh}(x − γ∇f(x)), p = x̂ − x.
        # It assumes Fortran order (column major), so we have to use
        # order="F" for all reshape/ravel calls.
        return pa.prox_step(self.reg, x, grad, x_hat, p, γ, -γ)


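# For intuition, the proximal operator of the nuclear norm applied in the
# forward-backward step soft-thresholds the singular values. The function
# below is only an illustrative sketch of prox_{γλ‖·‖_*} (it is not how
# alpaqa evaluates the step internally, but it should give the same result):
def reference_nuclear_prox(V, γλ):
    U, σ, Wt = jla.svd(V, full_matrices=False)
    return U @ jnp.diag(jnp.maximum(σ - γλ, 0.0)) @ Wt

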
# %% Generate some data

m, n = 30, 30
r = m // 5

key = random.PRNGKey(0)
key, *subkeys = random.split(key, 6)
# Random rank-r data matrix X = UV, then add some noise
U_true = random.uniform(subkeys[0], (m, r), minval=-1, maxval=1)
V_true = random.uniform(subkeys[1], (r, n), minval=-1, maxval=1)
B_rand = 0.01 * random.normal(subkeys[2], (m, n))
A = random.uniform(subkeys[3], (m, m), minval=-1, maxval=1)
X_true = U_true @ V_true
B = A @ X_true + B_rand  # Add noise

print("cond(A) =", jla.cond(A))
print("inf(noise) =", jla.norm(B_rand.ravel(), jnp.inf))

prob = MyProblem(A, B, λ=1 * m)

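# Optional sanity check (illustrative only): since the loss is quadratic, its
# gradient has the closed form Aᵀ(AX − B), which is what we would implement by
# hand if we did not use JAX. Verify the JAX gradient against it:
X_check = random.normal(subkeys[4], (prob.rows, prob.cols))
assert jnp.allclose(prob.jit_grad_loss(X_check), A.T @ (A @ X_check - B))
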
# %% Solve the problem using alpaqa's PANOC solver

opts = {
    "max_iter": 2000,
    "stop_crit": pa.FPRNorm,
    "quadratic_upperbound_tolerance_factor": 1e-14,
}
direction = pa.LBFGSDirection({"memory": 100})
# direction = pa.AndersonDirection({"memory": 5})
solver = pa.PANOCSolver({"print_interval": 10} | opts, direction)

# Add callback to the solver
residuals = []


def callback(it: pa.PANOCProgressInfo):
    residuals.append(it.ε)


solver.set_progress_callback(callback)

# Add evaluation counters to the problem
cnt = pa.problem_with_counters(prob)
# Solve the problem
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
X_sol = jnp.reshape(sol, (m, n), order="F")

# %% Print the results

final_f = prob.eval_objective(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost: {final_f + stats['final_h']}")
print(f"Loss: {final_f}")
print(f"Inf norm error: {jla.norm((X_sol - X_true).ravel(), jnp.inf)}")
print(f"Regularizer: {stats['final_h']}")
print(f"Rank: {jla.matrix_rank(X_sol)}")
print(f"True rank: {r}")
print(f"FP Residual: {stats['ε']}")
print(f"Run time: {stats['elapsed_time']}")
print(stats["status"])

# %% Plot the results

import matplotlib.pyplot as plt

plt.figure()
plt.semilogy(residuals, ".-")
plt.title("PANOC nuclear norm example: residuals")
plt.xlabel("Iteration")
plt.tight_layout()

if m * n <= 5000:
    plt.figure(figsize=(8, 5))
    plt.plot((A @ X_sol).ravel(), ".-", label="Estimated solution $ A\\tilde X $")
    plt.plot((A @ X_true).ravel(), "x:", label="True solution $ AX $")
    plt.plot(B.ravel(), "*-", label="Constant $ B $")
    plt.legend()
    plt.title("PANOC nuclear norm example: solution")
    plt.tight_layout()

plt.show()