Nuclear norm
In this example, we use the PANOC solver to solve a least-squares problem with a nuclear norm as regularizer, in order to promote low-rank solutions.
\[\begin{aligned}
& \minimize_X && \tfrac12 \normsq{AX - B}_F + \lambda \norm{X}_*
\end{aligned}\]
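Here, \(X \in \R^{m\times n}\) is the unknown matrix, \(\norm{X}_*\) is its nuclear norm (the sum of the singular values of \(X\)), and \(\lambda > 0\) is the regularization weight. We look for a solution whose rank \(r\) is much smaller than its dimensions \(m\) and \(n\).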
The JAX package is used to compute gradients (although we could easily do without it, since the loss is quadratic and its gradient \(A^\top (AX - B)\) is simple to write by hand).
1# %% alpaqa nuclear norm example
2
3import alpaqa as pa
4from jax.config import config
5import jax.numpy as jnp
6import jax.numpy.linalg as jla
7from jax import grad, jit
8from jax import random
9from pprint import pprint
10
# JAX computes in single precision unless told otherwise; switch it to
# 64-bit floats at startup so the data, the loss and its gradient below are
# all evaluated in double precision.
# NOTE(review): recent JAX releases dropped the `jax.config` submodule used by
# `from jax.config import config`; `from jax import config` is the current
# spelling -- confirm against the installed JAX version.
config.update("jax_enable_x64", True)
12
13# %% Define the problem functions
14
15# Quadratic loss plus nuclear norm regularization
16#
17# minimize ½‖vec(AX - B)‖² + λ nucl(X)
18
19
20# Returns the loss function with constant A and B
def loss(A, B):
    """Return the least-squares data-fit term for fixed ``A`` and ``B``.

    The returned function maps a matrix ``X`` to ½‖AX − B‖²_F. The squared
    Frobenius norm is computed as the squared Euclidean norm of the
    flattened residual. Only ``jax.numpy`` operations are used, so the
    result can be wrapped with ``jax.jit`` and differentiated with
    ``jax.grad``.
    """

    def _loss(X):
        residual = jnp.ravel(A @ X - B)
        return jnp.dot(residual, residual) / 2

    return _loss
27
28
class MyProblem(pa.UnconstrProblem):
    """Unconstrained alpaqa problem  ½‖AX − B‖²_F + λ‖X‖_*.

    The decision variable is the matrix ``X`` of shape
    ``(A.shape[1], B.shape[1])``. alpaqa passes it around as a flat vector
    in column-major (Fortran) order. The smooth term and its gradient are
    JIT-compiled JAX functions. The nuclear norm is handled by
    ``pa.functions.NuclearNorm`` through ``pa.prox_step``.

    Parameters
    ----------
    A : 2-D array of shape (p, rows)
        Left factor of the data-fit term.
    B : 2-D array of shape (p, cols)
        Right-hand side of the data-fit term.
    λ : float
        Weight of the nuclear-norm regularizer.

    Raises
    ------
    ValueError
        If ``A`` or ``B`` is not 2-D, or if they have a different number of
        rows (``A @ X - B`` would then be ill-defined).
    """

    def __init__(self, A, B, λ):
        # Validate up front: otherwise a mismatch only surfaces later, deep
        # inside the JIT-compiled loss, with a much less helpful message.
        if A.ndim != 2 or B.ndim != 2:
            raise ValueError(
                f"A and B must be 2-D matrices, got shapes {A.shape} and {B.shape}"
            )
        if A.shape[0] != B.shape[0]:
            raise ValueError(
                "A and B must have the same number of rows, "
                f"got {A.shape[0]} and {B.shape[0]}"
            )
        self.rows, self.cols = A.shape[1], B.shape[1]
        super().__init__(self.rows * self.cols)
        f = loss(A, B)
        self.jit_loss = jit(f)
        self.jit_grad_loss = jit(grad(f))
        self.reg = pa.functions.NuclearNorm(λ, self.rows, self.cols)

    def _as_matrix(self, x):
        """Reshape the flat iterate ``x`` into the ``(rows, cols)`` matrix X."""
        # Important: pa.prox_step treats the vector as a column-major matrix,
        # so every reshape/ravel in this class must use order="F". Keeping the
        # reshape in one place guarantees that consistency.
        return jnp.reshape(x, (self.rows, self.cols), order="F")

    def eval_f(self, x):  # Cost function
        """Return the data-fit cost ½‖AX − B‖²_F at the flat iterate ``x``."""
        return self.jit_loss(self._as_matrix(x))

    def eval_grad_f(self, x, grad_f):  # Gradient of the cost
        """Write the gradient of the cost at ``x`` into ``grad_f`` in place.

        The matrix gradient is flattened in column-major order to match ``x``.
        """
        grad_f[:] = self.jit_grad_loss(self._as_matrix(x)).ravel(order="F")

    def eval_prox_grad_step(self, γ, x, grad, x_hat, p):
        """Carry out a generalized forward-backward step via ``pa.prox_step``.

        Delegates to alpaqa's helper with the nuclear-norm regularizer, step
        size ``γ`` and forward step ``-γ``. Presumably ``x_hat`` receives
        prox_{γh}(x − γ·grad) and ``p`` receives ``x_hat − x``, and the return
        value is the regularizer value at ``x_hat``. Confirm against the
        alpaqa documentation.
        """
        # prox_step assumes Fortran (column-major) order, which is why
        # _as_matrix and eval_grad_f use order="F".
        return pa.prox_step(self.reg, x, grad, x_hat, p, γ, -γ)
52
53
54# %% Generate some data
55
# Problem dimensions. The ground truth X is m×n with rank r. The rank of
# U @ V can never exceed min(m, n), so r is derived from the smaller dimension.
# This is identical to m // 5 for the square case used here.
m, n = 30, 30
r = min(m, n) // 5

key = random.PRNGKey(0)
key, *subkeys = random.split(key, 6)
# Random rank-r ground truth X_true = U_true @ V_true. The observations
# B = A @ X_true are then perturbed by small Gaussian noise B_rand.
U_true = random.uniform(subkeys[0], (m, r), minval=-1, maxval=1)
V_true = random.uniform(subkeys[1], (r, n), minval=-1, maxval=1)
B_rand = 0.01 * random.normal(subkeys[2], (m, n))
A = random.uniform(subkeys[3], (m, m), minval=-1, maxval=1)
X_true = U_true @ V_true
B = A @ X_true + B_rand  # Add noise

# Diagnostics: conditioning of A, and the infinity norm of the added noise
# (this is B_rand, not B itself, so the label says so).
print("cond(A) =", jla.cond(A))
print("inf(B_rand) =", jla.norm(B_rand.ravel(), jnp.inf))

# The regularization weight λ is set to the number of rows m.
prob = MyProblem(A, B, λ=1 * m)
73
74# %% Solve the problem using alpaqa's PANOC solver
75
# PANOC options: iteration cap, fixed-point-residual stopping criterion and
# a very tight tolerance on the quadratic upper bound check.
opts = {
    "max_iter": 2000,
    "stop_crit": pa.FPRNorm,
    "quadratic_upperbound_tolerance_factor": 1e-14,
}
# Quasi-Newton direction: L-BFGS with a long memory.
# Alternative: pa.AndersonDirection({"memory": 5}).
direction = pa.LBFGSDirection({"memory": 100})
# Print progress every 10 iterations; entries from opts take precedence.
solver = pa.PANOCSolver({"print_interval": 10, **opts}, direction)

# History of the residual ε, filled in by the progress callback below.
residuals = []


def callback(it: pa.PANOCProgressInfo):
    """Record the residual ``it.ε`` of the current PANOC iteration."""
    residuals.append(it.ε)


solver.set_progress_callback(callback)

# Wrap the problem with evaluation counters and solve to tolerance 1e-10.
cnt = pa.problem_with_counters(prob)
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
# Undo the column-major vectorization to recover the solution matrix.
X_sol = jnp.reshape(sol, (m, n), order="F")
100
101# %% Print the results
102
# Evaluate the smooth loss through prob directly rather than cnt.problem,
# presumably so this extra evaluation is not added to the counters below.
final_f = prob.eval_f(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
# Total objective: smooth loss plus the regularizer value reported in stats.
print(f"Cost: {final_f + stats['final_h']}")
print(f"Loss: {final_f}")
# Entrywise estimation error ‖X_sol − X_true‖_∞ against the noise-free ground
# truth. This is not the loss being minimized, hence the "error" label.
print(f"Inf norm error: {jla.norm((X_sol - X_true).ravel(), jnp.inf)}")
print(f"Regularizer: {stats['final_h']}")
print(f"Rank: {jla.matrix_rank(X_sol)}")
print(f"True rank: {r}")
print(f"FP Residual: {stats['ε']}")
print(f"Run time: {stats['elapsed_time']}")
print(stats["status"])
118
119# %% Plot the results
120
import matplotlib.pyplot as plt

# Convergence history of the residuals collected by the progress callback.
_, ax_res = plt.subplots()
ax_res.semilogy(residuals, ".-")
ax_res.set_title("PANOC nuclear norm example: residuals")
ax_res.set_xlabel("Iteration")
plt.tight_layout()

# Entrywise comparison of A·X_sol, A·X_true and B. Only drawn for small
# problems (m·n ≤ 5000), presumably to keep the per-entry plot readable.
if m * n <= 5000:
    _, ax_sol = plt.subplots(figsize=(8, 5))
    series = (
        (A @ X_sol, ".-", "Estimated solution $ A\\tilde X $"),
        (A @ X_true, "x:", "True solution $ AX $"),
        (B, "*-", "Constant $ B $"),
    )
    for values, fmt, label in series:
        ax_sol.plot(values.ravel(), fmt, label=label)
    ax_sol.legend()
    ax_sol.set_title("PANOC nuclear norm example: solution")
    plt.tight_layout()

plt.show()