[Reference] Gradient w.r.t. Indices

November 10, 2013

To optimize indices, we'll need to compute the derivative of the marginal log-likelihood w.r.t. changing indices.

I first tried to derive this using the generalization of the chain rule to matrix expressions (see the Matrix Cookbook, section 2.8.1), but the computation exploded. Since, ultimately, each gradient element is the derivative of a single-input, single-output function, we can use differentials to derive the solution.

Let the marginal log-likelihood as a function of indices be $g(x)$:

$$\frac{\partial g(x)}{\partial x_i} = \frac{\partial}{\partial x_i}\left[-\frac{1}{2}\left(y^\top S\,\left(I + S K(x) S\right)^{-1} S\,y\right)\right]$$
Let $U = I + SK(x)S$, and $V = U^{-1}$. Working inside out, let's find $\partial U / \partial x_i$:

$$U + dU = I + S(K + dK)S = I + SKS + S\,dK\,S$$
$$dU = S\,dK\,S$$
$$U' = SK'S$$
Where $M'$ denotes the element-wise derivative of $M$ w.r.t. $x_i$. Next, $\partial V / \partial x_i$, which comes from the Matrix Cookbook, equation (36):

$$dV = -U^{-1}\,dU\,U^{-1}$$
$$V' = -U^{-1}\,U'\,U^{-1}$$
Finally, $\partial g(x) / \partial x_i$:

$$g + dg = -\tfrac{1}{2}\, y^\top S (V + dV) S y$$
$$g + dg = -\tfrac{1}{2}\, y^\top S V S y - \tfrac{1}{2}\, y^\top S\, dV\, S y$$
$$dg = -\tfrac{1}{2}\, y^\top S\, dV\, S y$$
$$g' = -\tfrac{1}{2}\, y^\top S V' S y$$
Expanding $V'$ gives the final formula:

$$g' = \tfrac{1}{2}\, y^\top S U^{-1} S K' S U^{-1} S y = \tfrac{1}{2}\, y^\top M K' M y = \tfrac{1}{2}\, z^\top K' z$$

Here, $M = S U^{-1} S$ (which is symmetric), and $z = M y$.

This equation gives us a single element of the gradient, namely $\partial g(x)/\partial x_i$. However, once $z$ is computed, we can reuse it when recomputing the formula above for all other $x_j$'s. The cost of each subsequent gradient element becomes $O(n^2)$, making the total gradient $O(n^3)$, which is pretty good. (This assumes the $K'$ matrices can be computed efficiently, which is true; see below.) However, we also observe that $K'$ is sparse with only $O(n)$ nonzeros, so we can use sparse multiplication to reduce each element to linear time, and the full gradient takes $O(n^2)$, assuming $z$ is precomputed. Cool!
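As a sanity check, the single-element formula $g' = \tfrac{1}{2} z^\top K' z$ can be compared against a finite difference. The sketch below uses Python/NumPy rather than this project's MATLAB, picks the linear covariance $k(x_i, x_j) = x_i x_j$ described later for concreteness, and assumes $S$ is a fixed diagonal matrix (the note does not pin $S$ down):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
x = rng.uniform(1.0, 2.0, n)            # indices
y = rng.standard_normal(n)              # observations
S = np.diag(rng.uniform(0.5, 1.5, n))   # assumption: S is a fixed diagonal matrix

def g_of(x):
    # g(x) = -1/2 y^T S (I + S K(x) S)^-1 S y, with linear covariance K = x x^T
    U = np.eye(n) + S @ np.outer(x, x) @ S
    return -0.5 * y @ S @ np.linalg.solve(U, S @ y)

# analytic element: g'_i = 1/2 z^T K' z, where z = M y and M = S U^-1 S
i = 2
U = np.eye(n) + S @ np.outer(x, x) @ S
z = S @ np.linalg.solve(U, S @ y)
dK = np.zeros((n, n))
dK[i, :] = x                            # d k(x_i, x_j)/d x_i = x_j for this kernel
dK[:, i] += x                           # symmetric part; diagonal picks up both terms
analytic = 0.5 * z @ dK @ z

# central finite difference in the i-th coordinate
h = 1e-6
e_i = np.eye(n)[i]
numeric = (g_of(x + h * e_i) - g_of(x - h * e_i)) / (2 * h)
```

The two values agree to well below finite-difference precision, which also confirms the sign bookkeeping in the derivation above.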

Derivatives of K(x)

First, we'll lay out the general form of $K'$, whose elements are the full derivatives of the kernel w.r.t. $x_k$.

$$\frac{\partial K_{ij}}{\partial x_k} = \frac{\partial k(x_i, x_j)}{\partial x_i}\frac{dx_i}{dx_k} + \frac{\partial k(x_i, x_j)}{\partial x_j}\frac{dx_j}{dx_k}$$

The first term is nonzero only when $i = k$, i.e. on the $k$-th row of $K'$, and the second term is nonzero only when $j = k$, i.e. on the $k$-th column. This suggests the following convenient sparse representation for $K'$.

Let $\delta_i$ be the vector whose $j$-th element is $\frac{\partial k(x_i, x_j)}{\partial x_i}$. Using this notation, we can rewrite $K'$ as

$$\frac{\partial K}{\partial x_i} = K' = C + C^\top$$

where $C$ is the matrix whose $i$-th row is $\delta_i^\top$ and whose remaining rows are zero:

$$C = \begin{pmatrix} 0 \\ \vdots \\ \delta_i^\top \\ \vdots \\ 0 \end{pmatrix}$$
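The structure of $K' = C + C^\top$ is easy to verify in code. A minimal sketch, in Python/SciPy rather than this project's MATLAB, with placeholder values standing in for a real kernel's $\delta_i$:

```python
import numpy as np
from scipy import sparse

n = 5
i = 2
delta_i = np.arange(1.0, n + 1.0)   # placeholder values for d k(x_i, x_j)/d x_i

# C holds delta_i^T in row i and zeros elsewhere; K' = C + C^T
C = sparse.lil_matrix((n, n))
C[i, :] = delta_i
dK = (C + C.T).tocsr()

# only row i and column i are populated, so K' stores O(n) entries;
# the (i, i) diagonal entry picks up both contributions
assert dK.nnz == 2 * n - 1
assert dK[i, i] == 2 * delta_i[i]
```

Note the diagonal doubling is correct: $K_{ii} = k(x_i, x_i)$ depends on $x_i$ through both arguments, so its full derivative is twice the partial.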

Below we derive the derivative $\frac{\partial k(x_i, x_j)}{\partial x_i}$ for each of the three covariance expressions.

Cubic covariance

Recall the cubic covariance expression:

$$k(x_i, x_j) = (x_a - x_b)\,x_b^2/2 + x_b^3/3$$

where $x_b = \min(x_i, x_j)$ and $x_a = \max(x_i, x_j)$.

Taking the derivative w.r.t. $x_i$ gives:

$$\frac{\partial k(x_i, x_j)}{\partial x_i} =
\begin{cases}
x_j^2/2 & \text{if } x_i \ge x_j \\
x_i x_j - x_i^2/2 & \text{if } x_i < x_j
\end{cases}
=
\begin{cases}
x_b^2/2 & \text{if } x_i \ge x_j \\
x_a x_b - x_b^2/2 & \text{if } x_i < x_j
\end{cases}$$

Or equivalently

$$\frac{\partial k(x_i, x_j)}{\partial x_i} = x_b\,(x_j - x_b/2)$$
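A quick finite-difference check of this closed form on both branches (a Python sketch; the actual implementation below is MATLAB):

```python
def k_cubic(xi, xj):
    # cubic covariance: (x_a - x_b) x_b^2 / 2 + x_b^3 / 3
    xb, xa = min(xi, xj), max(xi, xj)
    return (xa - xb) * xb**2 / 2 + xb**3 / 3

def dk_cubic(xi, xj):
    # closed form above: x_b (x_j - x_b / 2)
    xb = min(xi, xj)
    return xb * (xj - xb / 2)

h = 1e-6
for xi, xj in [(2.0, 5.0), (5.0, 2.0), (3.0, 3.5)]:
    numeric = (k_cubic(xi + h, xj) - k_cubic(xi - h, xj)) / (2 * h)
    assert abs(dk_cubic(xi, xj) - numeric) < 1e-5
```

The kernel is $C^1$ at $x_i = x_j$ (both branches give $x_j^2/2$ there), so the central difference is well-behaved near the kink.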

Linear Covariance

Recall the linear covariance expression:

$$k(x_i, x_j) = x_i x_j$$

The derivative w.r.t. $x_i$ is simply $x_j$.

Offset Covariance

Recall the offset covariance expression:

$$k(x_i, x_j) = k$$

The derivative w.r.t. $x_i$ is zero.

Implementation

The end-to-end version is implemented in kernel/get_model_kernel_derivative.m; see also its components in kernel/get_spacial_kernel_derivative.m and kernel/cubic_kernel_derivative.m.

These functions return all of the partial derivatives of the matrix with respect to the first input. The i-th row of the result makes up the nonzero values in $\partial K / \partial x_i$. Below is example code that computes all of the partial derivative matrices.

N = 100;
% construct indices
x = linspace(0, 10, N);
% construct derivative rows
d_kernel = get_model_kernel_derivative(...);
d_K = eval_kernel(d_kernel, x, x);
% construct dK/dx_i, for each i = 1..N
d_K_d_x = cell(1,N);
for i = 1:N
    tmp = sparse(N, N);
    tmp(i,:) = d_K(i,:);
    tmp(:,i) = d_K(i,:)';
    d_K_d_x{i} = tmp;
end

Directional Derivatives

I think we can get directional derivatives of K by taking the weighted sum of partial derivatives, where the weights are the component lengths of the direction vector. I have yet to confirm this beyond a hand-wavy hunch, and in practice, this might not even be needed, since computing the full gradient is so efficient.
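The hunch is in fact just linearity of the derivative: $D_v K = \sum_i v_i\, \partial K / \partial x_i$ by the multivariate chain rule. A quick numerical check, again a Python/NumPy sketch using the linear covariance for concreteness:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
x = rng.uniform(1.0, 2.0, n)
v = rng.standard_normal(n)
v /= np.linalg.norm(v)                 # direction vector

def K_of(x):
    return np.outer(x, x)              # linear covariance K = x x^T

def dK_dxi(x, i):
    # sparse partial dK/dx_i = C + C^T, with delta_i[j] = x_j for this kernel
    dK = np.zeros((n, n))
    dK[i, :] = x
    dK[:, i] += x
    return dK

# weighted sum of partials vs. finite difference along v
weighted = sum(v[i] * dK_dxi(x, i) for i in range(n))
h = 1e-6
numeric = (K_of(x + h * v) - K_of(x - h * v)) / (2 * h)
```

The two matrices match to rounding error, confirming the weighted-sum picture, though as noted it may never be needed in practice.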

Full gradient

As we saw earlier, $\partial K / \partial x_i$ is sparse, with the form $K' = C + C^\top$ given above. We can use the sparsity to ultimately compute the entire gradient in a single matrix multiplication.

First we'll rewrite $g'$ in terms of $\delta_i$:

$$g' = \tfrac{1}{2}\, z^\top K' z = \tfrac{1}{2}\left(z^\top C z + z^\top C^\top z\right) = \tfrac{1}{2}\left(z_i\,\delta_i^\top z + z^\top \delta_i\, z_i\right) = z_i\,(\delta_i^\top z)$$

We can generalize this to the entire gradient using matrix operations:

$$\nabla g = z \circ (\Delta z)$$

Where $\Delta$ is the matrix whose $i$-th row is $\delta_i^\top$, and $\circ$ denotes element-wise multiplication.

To handle multiple dimensions, simply apply to each dimension independently and sum the results.
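Putting the pieces together, here is a sketch of the full gradient $\nabla g = z \circ (\Delta z)$ checked against finite differences. As before this is Python/NumPy rather than MATLAB, uses the linear covariance, and assumes a diagonal $S$ since the note leaves $S$ unspecified:

```python
import numpy as np

rng = np.random.default_rng(2)
n = 6
x = rng.uniform(1.0, 2.0, n)
y = rng.standard_normal(n)
S = np.diag(rng.uniform(0.5, 1.5, n))  # assumption: diagonal S
I = np.eye(n)

def g_of(x):
    U = I + S @ np.outer(x, x) @ S     # linear covariance K = x x^T
    return -0.5 * y @ S @ np.linalg.solve(U, S @ y)

U = I + S @ np.outer(x, x) @ S
z = S @ np.linalg.solve(U, S @ y)      # z = M y with M = S U^-1 S

# Delta's i-th row is delta_i; for the linear kernel d k(x_i,x_j)/d x_i = x_j,
# so every row of Delta is just x
Delta = np.tile(x, (n, 1))
grad = z * (Delta @ z)                 # the element-wise product z ∘ (Δ z)

# finite-difference reference, one coordinate at a time
h = 1e-6
numeric = np.array([
    (g_of(x + h * I[i]) - g_of(x - h * I[i])) / (2 * h) for i in range(n)
])
```

Once $z$ is in hand, the whole gradient really is one matrix-vector product plus an element-wise multiply, matching the $O(n^2)$ claim above.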

Posted by Kyle Simek