Nuclear norm#

In this example, we use the PANOC solver to solve a problem with a nuclear norm as regularizer, in order to promote low-rank solutions.

\[\begin{aligned} & \underset{X}{\text{minimize}} && \tfrac12 \left\| AX - B \right\|_F^2 + \lambda \left\| X \right\|_* \end{aligned}\]

Here, \(X \in \mathbb{R}^{m\times n}\) is expected to have a rank \(r\) that is much smaller than its dimensions \(m\) and \(n\).

The JAX package is used to compute gradients (although we could easily do without, since the loss is quadratic and its gradient is simply \(A^\top (AX - B)\)).

  1# %% alpaqa nuclear norm example
  2
  3import alpaqa as pa
  4from jax.config import config
  5import jax.numpy as jnp
  6import jax.numpy.linalg as jla
  7from jax import grad, jit
  8from jax import random
  9from pprint import pprint
 10
 11config.update("jax_enable_x64", True)
 12
 13# %% Define the problem functions
 14
 15# Quadratic loss plus nuclear norm regularization
 16#
 17# minimize  ½‖vec(AX - B)‖² + λ nucl(X)
 18
 19
 20# Returns the loss function with constant A and B
 21def loss(A, B):
 22    def _loss(X):
 23        err = (A @ X - B).ravel()
 24        return 0.5 * jnp.dot(err, err)
 25
 26    return _loss
 27
 28
 29class MyProblem(pa.UnconstrProblem):
 30    def __init__(self, A, B, λ):
 31        self.rows, self.cols = A.shape[1], B.shape[1]
 32        super().__init__(self.rows * self.cols)
 33        f = loss(A, B)
 34        self.jit_loss = jit(f)
 35        self.jit_grad_loss = jit(grad(f))
 36        self.reg = pa.functions.NuclearNorm(λ, self.rows, self.cols)
 37
 38    def eval_f(self, x):  # Cost function
 39        # Important: use consistent order when reshaping or raveling!
 40        X = jnp.reshape(x, (self.rows, self.cols), order="F")
 41        return self.jit_loss(X)
 42
 43    def eval_grad_f(self, x, grad_f):  # Gradient of the cost
 44        X = jnp.reshape(x, (self.rows, self.cols), order="F")
 45        grad_f[:] = self.jit_grad_loss(X).ravel(order="F")
 46
 47    def eval_prox_grad_step(self, γ, x, grad, x_hat, p):
 48        # use the prox_step helper function to carry out a generalized
 49        # forward-backward step. This assumes Fortran order (column major),
 50        # so we have to use order="F" for all reshape/ravel calls
 51        return pa.prox_step(self.reg, x, grad, x_hat, p, γ, -γ)
 52
 53
 54# %% Generate some data
 55
 56m, n = 30, 30
 57r = m // 5
 58
 59key = random.PRNGKey(0)
 60key, *subkeys = random.split(key, 6)
 61# Random rank-r data matrix X = UV, then add some noise
 62U_true = random.uniform(subkeys[0], (m, r), minval=-1, maxval=1)
 63V_true = random.uniform(subkeys[1], (r, n), minval=-1, maxval=1)
 64B_rand = 0.01 * random.normal(subkeys[2], (m, n))
 65A = random.uniform(subkeys[3], (m, m), minval=-1, maxval=1)
 66X_true = U_true @ V_true
 67B = A @ X_true + B_rand  # Add noise
 68
 69print("cond(A) =", jla.cond(A))
 70print("inf(B) =", jla.norm(B_rand.ravel(), jnp.inf))
 71
 72prob = MyProblem(A, B, λ=1 * m)
 73
 74# %% Solve the problem using alpaqa's PANOC solver
 75
 76opts = {
 77    "max_iter": 2000,
 78    "stop_crit": pa.FPRNorm,
 79    "quadratic_upperbound_tolerance_factor": 1e-14,
 80}
 81direction = pa.LBFGSDirection({"memory": 100})
 82# direction = pa.AndersonDirection({"memory": 5})
 83solver = pa.PANOCSolver({"print_interval": 10} | opts, direction)
 84
 85# Add callback to the solver
 86residuals = []
 87
 88
 89def callback(it: pa.PANOCProgressInfo):
 90    residuals.append(it.ε)
 91
 92
 93solver.set_progress_callback(callback)
 94
 95# Add evaluation counters to the problem
 96cnt = pa.problem_with_counters(prob)
 97# Solve the problem
 98sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
 99X_sol = jnp.reshape(sol, (m, n), order="F")
100
101# %% Print the results
102
final_f = prob.eval_f(sol)
final_h = stats["final_h"]
# Largest entrywise deviation of the recovered matrix from the ground truth.
# This is a recovery error, not a value of the loss, and is labeled as such.
err_inf = jla.norm((X_sol - X_true).ravel(), jnp.inf)

print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost:          {final_f + final_h}")
print(f"Loss:          {final_f}")
print(f"Inf norm err:  {err_inf}")
print(f"Regularizer:   {final_h}")
print(f"Rank:          {jla.matrix_rank(X_sol)}")
print(f"True rank:     {r}")
print(f"FP Residual:   {stats['ε']}")
print(f"Run time:      {stats['elapsed_time']}")
print(stats["status"])
118
119# %% Plot the results
120
import matplotlib.pyplot as plt

# Convergence plot: recorded residuals on a logarithmic vertical axis
fig_res, ax_res = plt.subplots()
ax_res.semilogy(residuals, ".-")
ax_res.set_title("PANOC nuclear norm example: residuals")
ax_res.set_xlabel("Iteration")
fig_res.tight_layout()

# Entrywise comparison of A X for the estimate and the ground truth against
# the data B; skipped for large problems (presumably to keep it legible).
if m * n <= 5000:
    fig_sol, ax_sol = plt.subplots(figsize=(8, 5))
    series = [
        ((A @ X_sol).ravel(), ".-", "Estimated solution $ A\\tilde X $"),
        ((A @ X_true).ravel(), "x:", "True solution $ AX $"),
        (B.ravel(), "*-", "Constant $ B $"),
    ]
    for values, fmt, label in series:
        ax_sol.plot(values, fmt, label=label)
    ax_sol.legend()
    ax_sol.set_title("PANOC nuclear norm example: solution")
    fig_sol.tight_layout()

plt.show()