Nuclear norm

In this example, we use the PANOC solver to minimize a least-squares objective with the nuclear norm as regularizer, in order to promote low-rank solutions.

\[\begin{aligned} & \minimize_X && \tfrac12 \normsq{AX - B}_F + \lambda \norm{X}_* \end{aligned}\]

Here, \(X \in \R^{m\times n}\), and the true matrix we wish to recover has a rank \(r\) that is much smaller than its dimensions.
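
The nuclear norm \(\norm{X}_* = \sum_i \sigma_i(X)\) is the sum of the singular values of \(X\) and serves as a convex surrogate for the rank. Its proximal operator soft-thresholds the singular values,

\[\operatorname{prox}_{\gamma\lambda\norm{\cdot}_*}(X) = U \operatorname{diag}\big(\max(\sigma_i - \gamma\lambda, 0)\big) V^\top, \qquad X = U \operatorname{diag}(\sigma_i) V^\top,\]

which is what the proximal gradient steps in the code below rely on.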

The JAX package is used to compute gradients (although we could easily do without, since the loss is quadratic).

# %% alpaqa nuclear norm example

from pprint import pprint

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import grad, jit, random

import alpaqa as pa

jax.config.update("jax_enable_x64", True)
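# (the solver tolerances used below, e.g. 1e-10, are out of reach in JAX's
# default single precision, so 64-bit floating point is enabled explicitly)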

# %% Define the problem functions

# Quadratic loss plus nuclear norm regularization
#
# minimize  ½‖vec(AX - B)‖² + λ nucl(X)


# Returns the loss function with constant A and B
def loss(A, B):
    def _loss(X):
        err = (A @ X - B).ravel()
        return 0.5 * jnp.dot(err, err)

    return _loss

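# Note: since the loss is quadratic, its gradient could also be written out by
# hand as ∇f(X) = Aᵀ (A X − B); here we simply let JAX differentiate it for us.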

class MyProblem(pa.UnconstrProblem):
    def __init__(self, A, B, λ):
        self.rows, self.cols = A.shape[1], B.shape[1]
        super().__init__(self.rows * self.cols)
        f = loss(A, B)
        self.jit_loss = jit(f)
        self.jit_grad_loss = jit(grad(f))
        self.reg = pa.functions.NuclearNorm(λ, self.rows, self.cols)

    def eval_objective(self, x):  # Cost function
        # Important: use consistent order when reshaping or raveling!
        X = jnp.reshape(x, (self.rows, self.cols), order="F")
        return self.jit_loss(X)

    def eval_objective_gradient(self, x, grad_f):  # Gradient of the cost
        X = jnp.reshape(x, (self.rows, self.cols), order="F")
        grad_f[:] = self.jit_grad_loss(X).ravel(order="F")

    def eval_proximal_gradient_step(self, γ, x, grad, x_hat, p):
        # use the prox_step helper function to carry out a generalized
        # forward-backward step. This assumes Fortran order (column major),
        # so we have to use order="F" for all reshape/ravel calls
        return pa.prox_step(self.reg, x, grad, x_hat, p, γ, -γ)

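# Note on the proximal step: for the nuclear norm, the proximal operator
# soft-thresholds the singular values of its argument (singular value
# thresholding), which is what drives the iterates towards low rank.
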
# %% Generate some data

m, n = 30, 30
r = m // 5

key = random.PRNGKey(0)
key, *subkeys = random.split(key, 6)
# Random rank-r data matrix X = UV, then add some noise
U_true = random.uniform(subkeys[0], (m, r), minval=-1, maxval=1)
V_true = random.uniform(subkeys[1], (r, n), minval=-1, maxval=1)
B_rand = 0.01 * random.normal(subkeys[2], (m, n))
A = random.uniform(subkeys[3], (m, m), minval=-1, maxval=1)
X_true = U_true @ V_true
B = A @ X_true + B_rand  # Add noise

print("cond(A) =", jla.cond(A))
print("inf(noise) =", jla.norm(B_rand.ravel(), jnp.inf))

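# The regularization weight λ trades off data fit against the rank of the
# solution: larger values promote lower-rank (but more biased) solutions.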
prob = MyProblem(A, B, λ=1 * m)

# %% Solve the problem using alpaqa's PANOC solver

opts = {
    "max_iter": 2000,
    "stop_crit": pa.FPRNorm,
    "quadratic_upperbound_tolerance_factor": 1e-14,
}
direction = pa.LBFGSDirection({"memory": 100})
# direction = pa.AndersonDirection({"memory": 5})
solver = pa.PANOCSolver({"print_interval": 10} | opts, direction)

# Add callback to the solver
residuals = []


def callback(it: pa.PANOCProgressInfo):
    residuals.append(it.ε)


solver.set_progress_callback(callback)

# Add evaluation counters to the problem
cnt = pa.problem_with_counters(prob)
# Solve the problem
sol, stats = solver(cnt.problem, {"tolerance": 1e-10})
X_sol = jnp.reshape(sol, (m, n), order="F")

# %% Print the results

final_f = prob.eval_objective(sol)
print()
pprint(stats)
print()
print("Evaluations:")
print(cnt.evaluations)
print(f"Cost:           {final_f + stats['final_h']}")
print(f"Loss:           {final_f}")
print(f"Inf norm error: {jla.norm((X_sol - X_true).ravel(), jnp.inf)}")
print(f"Regularizer:    {stats['final_h']}")
print(f"Rank:           {jla.matrix_rank(X_sol)}")
print(f"True rank:      {r}")
print(f"FP Residual:    {stats['ε']}")
print(f"Run time:       {stats['elapsed_time']}")
print(stats["status"])
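# Optionally inspect the singular values of the solution: the nuclear norm
# regularizer should drive most of them to (numerically) zero.
print(f"Singular values:\n{jla.svd(X_sol, compute_uv=False)}")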

# %% Plot the results

import matplotlib.pyplot as plt

plt.figure()
plt.semilogy(residuals, ".-")
plt.title("PANOC nuclear norm example: residuals")
plt.xlabel("Iteration")
plt.tight_layout()

if m * n <= 5000:
    plt.figure(figsize=(8, 5))
    plt.plot((A @ X_sol).ravel(), ".-", label="Estimated solution $ A\\tilde X $")
    plt.plot((A @ X_true).ravel(), "x:", label="True solution $ AX $")
    plt.plot(B.ravel(), "*-", label="Constant $ B $")
    plt.legend()
    plt.title("PANOC nuclear norm example: solution")
    plt.tight_layout()

plt.show()