Nuclear norm
In this example, we use the PANOC solver to solve a least-squares problem regularized with the nuclear norm (the sum of the singular values of \(X\)), in order to promote low-rank solutions.
\[\begin{aligned}
& \minimize_X && \tfrac12 \normsq{AX - B}_F + \lambda \norm{X}_*
\end{aligned}\]
Here, \(X \in \R^{m\times n}\) is the unknown matrix, whose rank \(r\) is expected to be much smaller than its dimensions \(m\) and \(n\), and \(\lambda > 0\) is the regularization weight.
The JAX package is used to compute the gradient of the loss (although we could easily do without it, since the loss is quadratic and its gradient is simply \(A^\top (AX - B)\)).
1# %% alpaqa nuclear norm example
2
3from pprint import pprint
4
5import jax
6import jax.numpy as jnp
7import jax.numpy.linalg as jla
8from jax import grad, jit, random
9
10import alpaqa as pa
11
# JAX computes in 32-bit floats unless told otherwise; switch to 64-bit so the
# loss and gradient match the double-precision accuracy requested from the
# solver further down (tolerance 1e-10).
jax.config.update("jax_enable_x64", True)

# %% Define the problem functions

# The objective consists of a quadratic (least-squares) loss and a nuclear
# norm regularizer, i.e. the sum of the singular values of X:
#
#     minimize ½‖vec(AX - B)‖² + λ nucl(X)
# Factory for the smooth part of the objective, with the data A and B fixed.
def loss(A, B):
    """Return a closure computing the least-squares loss ``½‖AX − B‖_F²``.

    The matrices ``A`` and ``B`` are captured by the returned function, which
    takes only ``X``; this makes it directly usable with ``jit`` and ``grad``.

    :param A: Matrix multiplying the unknown ``X``.
    :param B: Target matrix that ``A @ X`` is compared against.
    :return: Callable mapping ``X`` to ``0.5 * ‖vec(A @ X - B)‖²``.
    """

    def _loss(X):
        # Flatten the residual so the squared Frobenius norm is a plain dot
        # product of the vector with itself.
        residual = jnp.ravel(A @ X - B)
        return 0.5 * jnp.dot(residual, residual)

    return _loss
28
29
class MyProblem(pa.UnconstrProblem):
    """Unconstrained alpaqa problem ``½‖AX − B‖_F² + λ‖X‖_*``.

    The solver works on a flat vector ``x``, which is the column-major
    (Fortran-order) vectorization of the matrix ``X`` of shape
    ``(A.shape[1], B.shape[1])``. The proximal step helper used below assumes
    column-major order, so every reshape/ravel in this class uses
    ``order="F"``.

    :param A: Matrix multiplying the unknown ``X``.
    :param B: Target matrix; must have as many rows as ``A``.
    :param λ: Weight of the nuclear-norm regularizer.
    :raises ValueError: If ``A`` and ``B`` have a different number of rows.
    """

    def __init__(self, A, B, λ):
        # Catch shape mismatches here instead of inside a solver callback.
        if A.shape[0] != B.shape[0]:
            raise ValueError(
                "A and B must have the same number of rows, "
                f"got {A.shape[0]} and {B.shape[0]}"
            )
        self.rows, self.cols = A.shape[1], B.shape[1]
        super().__init__(self.rows * self.cols)
        f = loss(A, B)
        self.jit_loss = jit(f)
        self.jit_grad_loss = jit(grad(f))
        self.reg = pa.functions.NuclearNorm(λ, self.rows, self.cols)

    def _to_matrix(self, x):
        """Reshape the solver's flat vector ``x`` into the matrix ``X``.

        Centralised so the column-major ordering is guaranteed to be the same
        everywhere the vector is converted.
        """
        return jnp.reshape(x, (self.rows, self.cols), order="F")

    def eval_objective(self, x):  # Cost function
        """Return the smooth loss ``f(x) = ½‖AX − B‖_F²`` as a Python float."""
        # Convert the 0-d JAX array explicitly before handing it to the solver.
        return float(self.jit_loss(self._to_matrix(x)))

    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
        """Write ``∇f(x)`` into the preallocated output ``grad_f`` (in place)."""
        # Ravel back in the same column-major order used by _to_matrix.
        grad_f[:] = self.jit_grad_loss(self._to_matrix(x)).ravel(order="F")

    def eval_proximal_gradient_step(self, γ, x, grad, x_hat, p):
        """Carry out a generalized forward-backward step with step size ``γ``.

        Delegates to ``pa.prox_step`` with forward step ``-γ``. Per alpaqa's
        ``prox_step`` documentation (assumed here — confirm against the
        installed version), this writes ``x_hat = prox_{γh}(x - γ·grad)`` and
        ``p = x_hat - x`` in place and returns ``h(x_hat)``.

        Note: the ``grad`` parameter shadows ``jax.grad`` inside this method
        only; its name is kept for interface compatibility.
        """
        return pa.prox_step(self.reg, x, grad, x_hat, p, γ, -γ)
53
54
# %% Generate some data

m, n = 30, 30
r = m // 5  # rank of the ground-truth matrix X_true

key = random.PRNGKey(0)
# Only 4 of the 5 subkeys are used; the split count stays at 6 so the
# generated data (and thus the example's output) remains reproducible.
key, *subkeys = random.split(key, 6)
# Random rank-r ground truth X_true = U V and a random square A. The
# measurements B = A X_true are then perturbed by small Gaussian noise B_rand.
U_true = random.uniform(subkeys[0], (m, r), minval=-1, maxval=1)
V_true = random.uniform(subkeys[1], (r, n), minval=-1, maxval=1)
B_rand = 0.01 * random.normal(subkeys[2], (m, n))
A = random.uniform(subkeys[3], (m, m), minval=-1, maxval=1)
X_true = U_true @ V_true
B = A @ X_true + B_rand  # Add noise

print("cond(A) =", jla.cond(A))
# Largest absolute entry of the added noise (not of B itself).
print("inf(B_rand) =", jla.norm(B_rand.ravel(), jnp.inf))

# Regularization weight λ = m.
prob = MyProblem(A, B, λ=1 * m)
74
# %% Solve the problem using alpaqa's PANOC solver

opts = {
    "max_iter": 2000,
    "stop_crit": pa.FPRNorm,  # stop on the fixed-point residual norm
    "quadratic_upperbound_tolerance_factor": 1e-14,
}
direction = pa.LBFGSDirection({"memory": 100})
# Alternative direction that can be swapped in:
# direction = pa.AndersonDirection({"memory": 5})
solver = pa.PANOCSolver({"print_interval": 10} | opts, direction)

# Add callback to the solver
residuals = []


def callback(it: pa.PANOCProgressInfo):
    """Record the residual ``ε`` reported at each PANOC iteration."""
    residuals.append(it.ε)


solver.set_progress_callback(callback)

# Add evaluation counters to the problem
cnt = pa.problem_with_counters(prob)
# Solve the problem
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
# Use the problem's own dimensions (not the globals m, n) so the reshape stays
# correct even if A is not square.
X_sol = jnp.reshape(sol, (prob.rows, prob.cols), order="F")
101
# %% Print the results

# Smooth part f of the objective at the solution; the regularizer value h is
# taken from the solver statistics (stats["final_h"]).
final_f = prob.eval_objective(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost: {final_f + stats['final_h']}")
print(f"Loss: {final_f}")
# Largest entry-wise deviation from the ground truth: an estimation error,
# not a value of the loss.
print(f"Inf norm error: {jla.norm((X_sol - X_true).ravel(), jnp.inf)}")
print(f"Regularizer: {stats['final_h']}")
print(f"Rank: {jla.matrix_rank(X_sol)}")
print(f"True rank: {r}")
print(f"FP Residual: {stats['ε']}")
print(f"Run time: {stats['elapsed_time']}")
print(stats["status"])
119
# %% Plot the results

import matplotlib.pyplot as plt

# Convergence history collected by the progress callback.
fig_res, ax_res = plt.subplots()
ax_res.semilogy(residuals, ".-")
ax_res.set_title("PANOC nuclear norm example: residuals")
ax_res.set_xlabel("Iteration")
fig_res.tight_layout()

# The flattened solution comparison is only drawn for small problems
# (presumably to keep the plot readable).
if m * n <= 5000:
    fig_sol, ax_sol = plt.subplots(figsize=(8, 5))
    curves = [
        ((A @ X_sol).ravel(), ".-", "Estimated solution $ A\\tilde X $"),
        ((A @ X_true).ravel(), "x:", "True solution $ AX $"),
        (B.ravel(), "*-", "Constant $ B $"),
    ]
    for values, fmt, label in curves:
        ax_sol.plot(values, fmt, label=label)
    ax_sol.legend()
    ax_sol.set_title("PANOC nuclear norm example: solution")
    fig_sol.tight_layout()

plt.show()