Differential geometry in machine learning

Created by: Roger Grosse
Intended for: machine learning researchers

The basics

The central object of study in differential geometry is the differentiable manifold. A manifold is essentially a space you can cover with coordinate charts, which are invertible, continuous mappings from open subsets of the manifold to open subsets of R^n. (For example, the surface of a sphere can be covered by two such charts, even though no single chart covers the whole sphere.) If the charts are chosen such that the transition map between any pair of overlapping charts is differentiable, then it's a differentiable manifold.

Why do we care about manifolds? Often, we work with objects such as functions or probability distributions, which we need to parameterize in order to do computations on them. For instance, we might parameterize a Gaussian distribution in terms of its mean and standard deviation. But often there are multiple alternative parameterizations; for instance, Gaussians are often parameterized in information form, i.e. in terms of the precision and the precision-weighted mean. In other cases, there might not even be an obvious parameterization.

In these ambiguous cases, we'd often like to make sure that the algorithm we use doesn't depend on some arbitrary parameterization. This is often a problem in practice. For instance, gradient descent is not invariant to parameterization, and much effort has gone into designing clever reparameterizations that make it more efficient.
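
As a tiny illustration (not from the original text), here's a JAX sketch comparing a single gradient step on the same one-dimensional Gaussian negative log-likelihood under two parameterizations of the scale, sigma versus log sigma. Even with the same learning rate, the two updates land at different values of sigma:

```python
import jax.numpy as jnp
from jax import grad

# Illustrative toy example: negative log-likelihood of a zero-mean Gaussian
# at a single observation x = 1, under two parameterizations of the scale.
def loss_sigma(sigma, x=1.0):
    return jnp.log(sigma) + x**2 / (2 * sigma**2)

def loss_log_sigma(log_sigma, x=1.0):
    return loss_sigma(jnp.exp(log_sigma), x)

sigma0, lr = 2.0, 0.1
step_a = sigma0 - lr * grad(loss_sigma)(sigma0)
step_b = jnp.exp(jnp.log(sigma0) - lr * grad(loss_log_sigma)(jnp.log(sigma0)))
print(step_a, step_b)  # different points, so the two trajectories diverge
```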

Differential geometry gives us a way to construct natural quantities, i.e. ones which are invariant to the choice of parameterization. In principle, one can verify naturality on a case-by-case basis by crunching through tedious change-of-basis formulas; in fact, this is how some of the basic objects in differential geometry are checked. But this approach quickly becomes unwieldy, and natural objects are often highly non-obvious.

What differential geometry gives us is a set of high-level abstractions which allow us to construct natural objects directly. These include vector fields, tensor fields, Riemannian metrics, and differential forms, all discussed below. If you build an object out of these primitives, it's automatically natural.

Basically everything in differential geometry depends on the tangent bundle, the cotangent bundle, and tensor fields, so you should learn about these first. The biggest hurdle to clear is wrapping your head around the tangent bundle -- it's not obvious how to define the "tangent space" at a point when the manifold isn't embedded in some Euclidean space. There are various equivalent definitions which can be a bit abstract and non-intuitive. But after that, the definitions of cotangent spaces and tensors follow pretty easily.

In machine learning, we often work with families of probability distributions. We can view a family of distributions as a manifold, called a statistical manifold, which gives a way of constructing natural objects.

Fisher metric

We often need a notion of distance between two points on a manifold. Unfortunately, there's no unique way to do this if all we have is a manifold. We need to assign an additional structure called a Riemannian metric, which (despite the name) is really a smoothly varying inner product on the tangent space at each point. Recall that inner products (e.g. the dot product) let us define notions like orthogonality, angles, and the length of a vector. Riemannian metrics let us define these things on tangent spaces; distances between points then come from integrating the lengths of velocity vectors along curves and taking the shortest such curve.

The Riemannian metric which is typically used for statistical manifolds is the Fisher metric. Its coordinate representation is simply the Fisher information matrix. What on earth does Fisher information have to do with distance? Well, KL divergence is the most commonly used measure of dissimilarity between distributions (even though it isn't actually a distance metric). The Fisher information matrix is simply the Hessian of the KL divergence between two distributions in the family, taken with respect to the second distribution's parameters and evaluated at the point where the two distributions are equal. Therefore, it gives a quadratic form which acts as a local squared distance between two nearby distributions -- exactly what we'd want for a Riemannian metric. I have a blog post which explains this in more detail.
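
Here's a minimal sketch of that connection for a one-dimensional Gaussian parameterized by (mu, sigma) (an illustrative choice, not from the text): taking the Hessian of the closed-form KL divergence with respect to the second argument, at the point where the two distributions coincide, recovers the familiar Fisher matrix diag(1/sigma^2, 2/sigma^2).

```python
import jax.numpy as jnp
from jax import hessian

def kl(theta0, theta):
    """KL( N(mu0, s0^2) || N(mu, s^2) ), with theta = (mu, sigma)."""
    mu0, s0 = theta0
    mu, s = theta
    return jnp.log(s / s0) + (s0**2 + (mu0 - mu)**2) / (2 * s**2) - 0.5

theta0 = jnp.array([1.0, 2.0])               # mu = 1, sigma = 2
F = hessian(kl, argnums=1)(theta0, theta0)   # Hessian in the second argument
print(F)  # approximately [[0.25, 0.], [0., 0.5]], i.e. diag(1/sigma^2, 2/sigma^2)
```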

Natural gradient

Probably the most common application of the Fisher metric in machine learning is the natural gradient, an analogue of the gradient which is invariant to parameterization. This is a useful property to have for models like neural nets where the relationship between the parameters and the distribution can be very complex, making ordinary gradient descent updates unstable.

In differential geometry terms, what we normally call the gradient is really the differential. The differential is a covector. A small change to the parameters, on the other hand, is a vector. Trying to equate the two is a type error, since they're completely different mathematical objects. But the Riemannian metric gives a way of converting a covector into a vector, known as "raising indices." The natural gradient is just the differential with its index raised by the Fisher metric. Since it's constructed from natural objects, it's automatically invariant to parameterization.
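
In coordinates, raising the index with the Fisher metric just means multiplying the ordinary gradient by the inverse Fisher matrix. Here's a minimal sketch, using a toy single-observation Gaussian likelihood and its closed-form Fisher matrix purely for illustration:

```python
import jax.numpy as jnp
from jax import grad

# Toy objective: negative log-likelihood of one observation x under N(mu, sigma^2),
# with theta = (mu, sigma). Chosen only so the Fisher matrix has a closed form.
def nll(theta, x=0.5):
    mu, s = theta
    return jnp.log(s) + (x - mu)**2 / (2 * s**2)

theta = jnp.array([1.0, 2.0])
g = grad(nll)(theta)                               # the differential (a covector)
F = jnp.diag(jnp.array([1.0, 2.0]) / theta[1]**2)  # Fisher metric at theta
natural_grad = jnp.linalg.solve(F, g)              # raise the index: F^{-1} g
theta = theta - 0.1 * natural_grad                 # natural gradient descent step
```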

Cramer-Rao bound

The Fisher metric also gives a nice intuition for the Cramer-Rao bound, which I used to find pretty mysterious. What it's really saying is that two distributions need to be sufficiently dissimilar for you to be able to distinguish them. When you estimate a distribution from data, you're likely to confuse it with something else within its Fisher ball.
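
Concretely, for an unbiased estimator of the parameters based on N i.i.d. samples, the bound says the estimator's covariance is at least (1/N) F(theta)^{-1} in the positive semidefinite ordering. Directions in which the Fisher metric is small are directions in which nearby distributions are hard to tell apart, so any estimator must have large variance along them.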

Jeffreys prior

In Bayesian parameter estimation, you face the question of how to choose a prior distribution. If you have no knowledge specific to the problem, you might want to choose an uninformative prior. The Jeffreys prior is one such choice, motivated by invariance to parameterization. (Unless you believe your parameters are meaningful, you probably don't want your prior to depend on the parameterization.)

Here's another heuristic motivation for the Jeffreys prior. If you wanted to put an uninformative prior on a Euclidean space, you might choose the uniform distribution because it assigns equal probabilities to equal volumes. On manifolds, volumes are computed by integrating a volume form. There's a natural way to construct a volume form from a Riemannian metric, and the Jeffreys prior is simply the volume form constructed from the Fisher metric. Hence, it assigns equal probabilities to equal volumes. (This automatically explains why it's invariant to parameterization.)
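
In coordinates, the volume form built from a metric has density equal to the square root of the metric's determinant, so the unnormalized Jeffreys prior is proportional to sqrt(det F(theta)). A quick sketch, reusing the Gaussian KL divergence from above (for the (mu, sigma) family this comes out proportional to 1/sigma^2):

```python
import jax.numpy as jnp
from jax import hessian

def kl(theta0, theta):
    mu0, s0 = theta0
    mu, s = theta
    return jnp.log(s / s0) + (s0**2 + (mu0 - mu)**2) / (2 * s**2) - 0.5

def fisher(theta):
    return hessian(kl, argnums=1)(theta, theta)

def jeffreys_unnormalized(theta):
    # density of the Riemannian volume form in these coordinates
    return jnp.sqrt(jnp.linalg.det(fisher(theta)))

print(jeffreys_unnormalized(jnp.array([0.0, 2.0])))  # sqrt(2)/sigma^2 = sqrt(2)/4
```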

Hamiltonian flows

As I mentioned, gradient descent can be unstable because it depends on the parameterization, and natural gradient is one way around this. But another trick to speed up gradient descent is to use momentum. This is popular because it's very cheap and simple to implement.
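
For reference, a bare-bones sketch of the momentum (heavy-ball) update; the quadratic loss and the hyperparameters are placeholders chosen purely for illustration:

```python
import jax.numpy as jnp
from jax import grad

def loss(theta):
    return 0.5 * jnp.sum(theta**2)   # placeholder objective

theta = jnp.array([1.0, -2.0])
velocity = jnp.zeros_like(theta)
lr, beta = 0.1, 0.9                  # learning rate and momentum coefficient
for _ in range(100):
    velocity = beta * velocity - lr * grad(loss)(theta)
    theta = theta + velocity
```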

But in Bayesian machine learning, you often need to sample from a posterior rather than maximize a function. Metropolis-Hastings is a very general technique for getting approximate samples from distributions. However, it requires specifying a proposal distribution, and the most naive choices can be extremely inefficient. One of the most successful M-H algorithms is Hamiltonian Monte Carlo (HMC), where the proposal distributions correspond to a numerical simulation of Hamiltonian dynamics -- essentially gradient descent with momentum. HMC is the workhorse behind Stan, a probabilistic programming language. The best reference on HMC is probably Radford Neal's tutorial.

Ordinarily in M-H, there's a tradeoff where small steps have high acceptance probabilities but don't move very far, and large steps have low acceptance probabilities. HMC proposes very large steps, so we might ordinarily expect it to suffer from extremely small acceptance probabilities. Surprisingly, it can be shown that the acceptance probability is close to 1 if the dynamics are simulated to high enough precision -- this is the insight that makes it work. This is a consequence of two facts: that Hamiltonian dynamics conserves energy, and that it preserves volume in phase space.
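
To make this concrete, here's a minimal HMC sketch, assuming a unit mass matrix and using a standard 2-D Gaussian target purely for illustration. The leapfrog integrator simulates the dynamics, and the Metropolis correction accepts with probability min(1, exp(H - H_new)); because H is nearly conserved by the integrator, that probability stays close to 1 even for long trajectories.

```python
import jax.numpy as jnp
from jax import grad, random

def log_prob(q):                      # illustrative target: standard 2-D Gaussian
    return -0.5 * jnp.sum(q**2)

grad_log_prob = grad(log_prob)

def leapfrog(q, p, step_size, n_steps):
    p = p + 0.5 * step_size * grad_log_prob(q)    # half step for momentum
    for _ in range(n_steps - 1):
        q = q + step_size * p                     # full step for position
        p = p + step_size * grad_log_prob(q)      # full step for momentum
    q = q + step_size * p
    p = p + 0.5 * step_size * grad_log_prob(q)    # final half step
    return q, p

def hmc_step(key, q, step_size=0.1, n_steps=20):
    key_p, key_u = random.split(key)
    p = random.normal(key_p, q.shape)             # resample momentum
    q_new, p_new = leapfrog(q, p, step_size, n_steps)
    H = -log_prob(q) + 0.5 * jnp.sum(p**2)        # Hamiltonian = potential + kinetic
    H_new = -log_prob(q_new) + 0.5 * jnp.sum(p_new**2)
    accept = random.uniform(key_u) < jnp.exp(H - H_new)
    return jnp.where(accept, q_new, q)

key, q = random.PRNGKey(0), jnp.zeros(2)
for _ in range(100):
    key, subkey = random.split(key)
    q = hmc_step(subkey, q)
```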

Similarly to the parameterization invariance results discussed previously, these two facts can be verified by inspection, but they also follow from higher-level abstractions. In particular, suppose we have a manifold M which gives the position of a particle. The cotangent bundle T*M can be viewed as a manifold in its own right; in particular, it is a symplectic manifold. The canonical coordinate system for T*M is a combination of position coordinates q, which identify the point in M, and momentum coordinates p, which identify a cotangent vector at q.

Suppose we have a function f on T*M, which we'll refer to as the Hamiltonian. (In the context of HMC, this corresponds to the negative log of the joint density over positions and momenta.) Recall that the differential of f is a covector field, so in order to get an update direction, we need to raise its index. When we raise it using the symplectic form, we get what's called the Hamiltonian vector field. The flow for this vector field, the Hamiltonian flow, corresponds to Hamiltonian dynamics. Interestingly, raising the index using the symplectic form is very different from raising it using a Riemannian metric; whereas the latter gives the direction of steepest ascent (which we negate to do gradient descent), the former gives a direction which preserves f. I.e., the Hamiltonian flow moves along the level sets of f.
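
In canonical coordinates, this works out to the familiar Hamilton's equations: dq/dt = ∂f/∂p and dp/dt = -∂f/∂q. Plugging these into df/dt = (∂f/∂q)(dq/dt) + (∂f/∂p)(dp/dt) gives zero, which is exactly the statement that the flow preserves f.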

The fact that Hamiltonian flows preserve the Hamiltonian corresponds to the first property, conservation of energy. Similarly, it preserves volume, because it preserves the symplectic form and the volume form can be constructed from the symplectic form. Since these properties hold exactly for Hamiltonian flows, and HMC approximates Hamiltonian flows, we'd expect them to hold approximately in the context of HMC. (In fact, volume preservation holds exactly, so it's only energy conservation that's approximate.)

All of this is explained in this paper by Michael Betancourt.

While HMC is powerful, it still suffers from its dependence on the parameterization. Riemannian manifold HMC is an algorithm that combines the insights of both natural gradient and HMC.