Nuclear norm

In this example, we use the PANOC solver to solve a problem with a nuclear norm as regularizer, in order to promote low-rank solutions.

\[\begin{aligned} & \underset{X}{\operatorname{minimize}} && \tfrac12 \left\| AX - B \right\|_F^2 + \lambda \left\| X \right\|_* \end{aligned}\]

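Here, \(X \in \mathbb{R}^{m\times n}\) is the unknown matrix, with a rank \(r\) that’s much smaller than its dimensions, \(A\) and \(B\) are given data matrices, \(\|\cdot\|_F\) is the Frobenius norm and \(\|\cdot\|_*\) is the nuclear norm (the sum of the singular values).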

The JAX package is used to compute gradients (although we could easily do without it, since the loss is quadratic and its gradient is simply \(A^\top (AX - B)\)).

  1# %% alpaqa nuclear norm example
  2
  3import alpaqa as pa
  4import jax
  5import jax.numpy as jnp
  6import jax.numpy.linalg as jla
  7from jax import grad, jit
  8from jax import random
  9from pprint import pprint
 10
 11jax.config.update("jax_enable_x64", True)
 12
 13# %% Define the problem functions
 14
 15# Quadratic loss plus nuclear norm regularization
 16#
 17# minimize  ½‖vec(AX - B)‖² + λ nucl(X)
 18
 19
 20# Returns the loss function with constant A and B
 21def loss(A, B):
 22    def _loss(X):
 23        err = (A @ X - B).ravel()
 24        return 0.5 * jnp.dot(err, err)
 25
 26    return _loss
 27
 28
 29class MyProblem(pa.UnconstrProblem):
 30    def __init__(self, A, B, λ):
 31        self.rows, self.cols = A.shape[1], B.shape[1]
 32        super().__init__(self.rows * self.cols)
 33        f = loss(A, B)
 34        self.jit_loss = jit(f)
 35        self.jit_grad_loss = jit(grad(f))
 36        self.reg = pa.functions.NuclearNorm(λ, self.rows, self.cols)
 37
 38    def eval_objective(self, x):  # Cost function
 39        # Important: use consistent order when reshaping or raveling!
 40        X = jnp.reshape(x, (self.rows, self.cols), order="F")
 41        return self.jit_loss(X)
 42
 43    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
 44        X = jnp.reshape(x, (self.rows, self.cols), order="F")
 45        grad_f[:] = self.jit_grad_loss(X).ravel(order="F")
 46
 47    def eval_proximal_gradient_step(self, γ, x, grad, x_hat, p):
 48        # use the prox_step helper function to carry out a generalized
 49        # forward-backward step. This assumes Fortran order (column major),
 50        # so we have to use order="F" for all reshape/ravel calls
 51        return pa.prox_step(self.reg, x, grad, x_hat, p, γ, -γ)
 52
 53
 54# %% Generate some data
 55
 56m, n = 30, 30
 57r = m // 5
 58
 59key = random.PRNGKey(0)
 60key, *subkeys = random.split(key, 6)
 61# Random rank-r data matrix X = UV, then add some noise
 62U_true = random.uniform(subkeys[0], (m, r), minval=-1, maxval=1)
 63V_true = random.uniform(subkeys[1], (r, n), minval=-1, maxval=1)
 64B_rand = 0.01 * random.normal(subkeys[2], (m, n))
 65A = random.uniform(subkeys[3], (m, m), minval=-1, maxval=1)
 66X_true = U_true @ V_true
 67B = A @ X_true + B_rand  # Add noise
 68
 69print("cond(A) =", jla.cond(A))
 70print("inf(B) =", jla.norm(B_rand.ravel(), jnp.inf))
 71
 72prob = MyProblem(A, B, λ=1 * m)
 73
 74# %% Solve the problem using alpaqa's PANOC solver
 75
 76opts = {
 77    "max_iter": 2000,
 78    "stop_crit": pa.FPRNorm,
 79    "quadratic_upperbound_tolerance_factor": 1e-14,
 80}
 81direction = pa.LBFGSDirection({"memory": 100})
 82# direction = pa.AndersonDirection({"memory": 5})
 83solver = pa.PANOCSolver({"print_interval": 10} | opts, direction)
 84
 85# Add callback to the solver
 86residuals = []
 87
 88
 89def callback(it: pa.PANOCProgressInfo):
 90    residuals.append(it.ε)
 91
 92
 93solver.set_progress_callback(callback)
 94
 95# Add evaluation counters to the problem
 96cnt = pa.problem_with_counters(prob)
 97# Solve the problem
 98sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
 99X_sol = jnp.reshape(sol, (m, n), order="F")
100
101# %% Print the results
102
# Smooth part of the cost at the solution; the regularizer value is read from
# the solver statistics (stats['final_h'])
final_f = prob.eval_objective(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost:          {final_f + stats['final_h']}")
print(f"Loss:          {final_f}")
# Largest absolute entry of the difference between the estimate and X_true
print(f"Inf norm loss: {jla.norm((X_sol - X_true).ravel(), jnp.inf)}")
print(f"Regularizer:   {stats['final_h']}")
# Compare the numerical rank of the estimate with the rank used to build
# the ground truth
print(f"Rank:          {jla.matrix_rank(X_sol)}")
print(f"True rank:     {r}")
print(f"FP Residual:   {stats['ε']}")
print(f"Run time:      {stats['elapsed_time']}")
print(stats["status"])
118
119# %% Plot the results
120
import matplotlib.pyplot as plt

# Values recorded by the progress callback, on a logarithmic axis
plt.figure()
plt.semilogy(residuals, ".-")
plt.title("PANOC nuclear norm example: residuals")
plt.xlabel("Iteration")
plt.tight_layout()

# Entry-by-entry comparison of the fit; skipped for large problems, where
# the plot would have too many points
if m * n <= 5000:
    plt.figure(figsize=(8, 5))
    # All three matrices are flattened the same way, so the curves line up
    plt.plot((A @ X_sol).ravel(), ".-", label="Estimated solution $ A\\tilde X $")
    plt.plot((A @ X_true).ravel(), "x:", label="True solution $ AX $")
    plt.plot(B.ravel(), "*-", label="Constant $ B $")
    plt.legend()
    plt.title("PANOC nuclear norm example: solution")
    plt.tight_layout()

plt.show()