Differential geometry in machine learning

Created by: Roger Grosse
Intended for: machine learning researchers

The basics

The central object of study in differential geometry is the differentiable manifold. A manifold is essentially a space you can cover with coordinate charts, which are invertible, continuous mappings to open subsets of R^n. If the charts are chosen such that the transition map between any pair of overlapping charts is differentiable, then it's a differentiable manifold.
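For concreteness, the standard textbook example of a one-dimensional manifold is the circle, covered by two angle charts (this is just an illustration; any of the usual chart constructions works):

\[
\theta : S^1 \setminus \{(1,0)\} \to (0, 2\pi), \qquad
\varphi : S^1 \setminus \{(-1,0)\} \to (-\pi, \pi),
\]
\[
(\varphi \circ \theta^{-1})(t) =
\begin{cases}
t, & t \in (0, \pi) \\
t - 2\pi, & t \in (\pi, 2\pi)
\end{cases}
\]

Both charts measure the angle from the point (1, 0). The transition map is differentiable (in fact smooth) wherever it's defined, so the two charts are compatible and the circle is a differentiable manifold.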

Why do we care about manifolds? Often, we work with objects such as functions or probability distributions, which we need to parameterize in order to do computations on them. For instance, we might parameterize a Gaussian distribution in terms of its mean and standard deviation. But often there are multiple alternative parameterizations; for instance, Gaussians are often parameterized in information form. In other cases, there might not even be an obvious parameterization.
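As a small illustration, here's a sketch in Python of the two Gaussian charts and the maps between them: the moment form (mean and standard deviation) and the information form (precision and precision-weighted mean). The function names are just made up for this example.

    import numpy as np

    def moment_to_information(mu, sigma):
        # Moment form (mu, sigma) -> information form (eta, lam),
        # where lam = 1 / sigma^2 is the precision and eta = lam * mu.
        lam = 1.0 / sigma**2
        eta = lam * mu
        return eta, lam

    def information_to_moment(eta, lam):
        # Inverse chart transition: information form -> moment form.
        sigma = 1.0 / np.sqrt(lam)
        mu = eta / lam
        return mu, sigma

    # Both charts describe the same distribution; the transition maps are
    # smooth wherever sigma > 0 (equivalently lam > 0).
    eta, lam = moment_to_information(1.5, 2.0)
    print(information_to_moment(eta, lam))   # recovers (1.5, 2.0)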

In these ambiguous cases, we'd like the algorithms we use not to depend on the arbitrary choice of parameterization. In practice, they often do: gradient descent, for instance, is not invariant to parameterization, and much effort has gone into designing clever reparameterizations that make it more efficient.
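Here's a minimal sketch of that parameterization dependence: one step of gradient descent on the same one-dimensional objective, taken either in the original coordinate theta or in u = log(theta). Starting from the same point, the two updates land in different places. (The objective and step size are arbitrary choices for illustration.)

    import numpy as np

    def f(theta):
        return (theta - 2.0) ** 2          # objective, minimized at theta = 2

    lr = 0.1

    # Gradient descent step directly in theta.
    theta = 1.0
    theta_next = theta - lr * 2.0 * (theta - 2.0)

    # Gradient descent step in u = log(theta), i.e. theta = exp(u).
    u = np.log(1.0)                                   # same starting point
    grad_u = 2.0 * (np.exp(u) - 2.0) * np.exp(u)      # chain rule
    theta_next_via_u = np.exp(u - lr * grad_u)

    print(theta_next, theta_next_via_u)   # 1.2 vs. about 1.221: different points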

Differential geometry gives us a way to construct natural quantities, i.e. ones which are parameterization invariant. One can verify naturality on a case-by-case basis by crunching through change-of-basis formulas; in fact, this is how naturality is checked for some of the basic objects in differential geometry. But this is tedious, and natural objects are often highly non-obvious.
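For reference, these are the kinds of change-of-basis formulas involved. Under a change of coordinates x -> x', the components of a vector v, a covector omega, and a metric g transform as (using the summation convention):

\[
v'^i = \frac{\partial x'^i}{\partial x^j} \, v^j, \qquad
\omega'_i = \frac{\partial x^j}{\partial x'^i} \, \omega_j, \qquad
g'_{ij} = \frac{\partial x^k}{\partial x'^i} \frac{\partial x^l}{\partial x'^j} \, g_{kl}.
\]

Checking naturality by hand amounts to verifying that a proposed object obeys the appropriate transformation law under every change of coordinates.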

What differential geometry gives us is a set of high-level abstractions which allow us to construct natural objects directly. These include vector fields, tensor fields, Riemannian metrics, and differential forms, all discussed below. If you build an object out of these primitives, it's automatically natural.

Basically everything in differential geometry depends on the tangent bundle, the cotangent bundle, and tensor fields, so you should learn about these first. The biggest hurdle to clear is wrapping your head around the tangent bundle -- it's not obvious how to define the "tangent space" at a point when the manifold isn't embedded in some Euclidean space. There are various equivalent definitions which can be a bit abstract and non-intuitive. But after that, the definitions of cotangent spaces and tensors follow pretty easily.
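For what it's worth, here are two of the standard equivalent definitions of a tangent vector at a point p, stated compactly. One identifies a tangent vector with an equivalence class of smooth curves passing through p at time 0; the other identifies it with a derivation, i.e. a directional-derivative operator acting on smooth functions:

\[
\gamma_1 \sim \gamma_2 \iff (\varphi \circ \gamma_1)'(0) = (\varphi \circ \gamma_2)'(0) \ \text{ in some (hence any) chart } \varphi,
\]
\[
v(fg) = f(p)\, v(g) + g(p)\, v(f) \quad \text{for all smooth functions } f, g.
\]

Working out why these definitions agree is a good way to get comfortable with the tangent bundle.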

In machine learning, we often work with families of probability distributions. We can view a family of distributions as a manifold, called a statistical manifold, which gives a way of constructing natural objects.
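For example, the univariate Gaussians form a two-dimensional statistical manifold,

\[
M = \{\, \mathcal{N}(\mu, \sigma^2) : \mu \in \mathbb{R},\ \sigma > 0 \,\},
\]

with the moment parameters (mu, sigma) and the information-form parameters from above giving two different charts on the same underlying manifold.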

Fisher metric

We often need a notion of distance between two points on a manifold. Unfortunately, there's no unique way to do this if all we have is a manifold. We need to assign an additional structure called a Riemannian metric, which (despite the name) is really an inner product on the tangent space at each point. Recall that inner products (e.g. the dot product) let us define notions like orthogonality, angles, and the length of a vector. Riemannian metrics let us define these things on tangent spaces.
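In coordinates, a Riemannian metric at the point with parameters theta is represented by a symmetric positive definite matrix G(theta), and the familiar inner-product notions carry over to tangent vectors u and v:

\[
\langle u, v \rangle_\theta = u^\top G(\theta)\, v, \qquad
\|u\|_\theta = \sqrt{u^\top G(\theta)\, u}, \qquad
\cos \angle(u, v) = \frac{\langle u, v \rangle_\theta}{\|u\|_\theta\, \|v\|_\theta}.
\]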

The Riemannian metric which is typically used for statistical manifolds is the Fisher metric. Its coordinate representation is simply the Fisher information matrix. What on earth does Fisher information have to do with distance? Well, KL divergence is the most commonly used measure of dissimilarity between distributions (even though it isn't actually a distance metric). The Fisher information matrix is simply the Hessian of KL divergence, taken with respect to one of the distributions and evaluated at the point where the two distributions are equal. Therefore, it gives a quadratic form which acts as a squared distance between two nearby distributions -- exactly what we'd want for a Riemannian metric. I have a blog post which explains this in more detail.
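Here's a small numerical check of that claim for a univariate Gaussian in (mu, sigma) coordinates: the finite-difference Hessian of KL(N(mu0, sigma0^2) || N(mu, sigma^2)) at (mu, sigma) = (mu0, sigma0) should match the analytic Fisher information matrix diag(1/sigma0^2, 2/sigma0^2). This is just a sketch; the helper functions are made up for the example.

    import numpy as np

    def kl_gauss(mu0, sigma0, mu, sigma):
        # KL( N(mu0, sigma0^2) || N(mu, sigma^2) ), in closed form.
        return (np.log(sigma / sigma0)
                + (sigma0**2 + (mu0 - mu)**2) / (2.0 * sigma**2) - 0.5)

    def numerical_hessian(f, x, eps=1e-4):
        # Central finite-difference Hessian of a scalar function f at x.
        n = len(x)
        H = np.zeros((n, n))
        for i in range(n):
            for j in range(n):
                def shifted(si, sj):
                    y = np.array(x, dtype=float)
                    y[i] += si * eps
                    y[j] += sj * eps
                    return f(y)
                H[i, j] = (shifted(1, 1) - shifted(1, -1)
                           - shifted(-1, 1) + shifted(-1, -1)) / (4.0 * eps**2)
        return H

    mu0, sigma0 = 0.5, 1.3
    kl = lambda th: kl_gauss(mu0, sigma0, th[0], th[1])
    H = numerical_hessian(kl, np.array([mu0, sigma0]))
    F = np.diag([1.0 / sigma0**2, 2.0 / sigma0**2])   # analytic Fisher in (mu, sigma)
    print(np.round(H, 3))
    print(np.round(F, 3))   # the two should agree closely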

Natural gradient

Probably the most common application of the Fisher metric in machine learning is the natural gradient, an analogue of the gradient which is invariant to parameterization. This is a useful property to have for models like neural nets where the relationship between the parameters and the distribution can be very complex, making ordinary gradient descent updates unstable.

In differential geometry terms, what we normally call the gradient is really the differential. The differential is a covector. A small change to the parameters, on the other hand, is a vector. Trying to equate the two is a type error, since they're completely different mathematical objects. But the Riemannian metric gives a way of converting a covector into a vector, known as "raising indices." The natural gradient is just the differential with its index raised by the Fisher metric. Since it's constructed from natural objects, it's automatically invariant to parameterization.
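In code, once you have the ordinary gradient and the Fisher matrix at the current parameters, the natural gradient update is just a linear solve. A minimal sketch (how you compute the gradient and the Fisher matrix depends on the model; here they're placeholders, and the damping term is a common practical addition rather than part of the definition):

    import numpy as np

    def natural_gradient_step(theta, grad, fisher, lr=0.1, damping=1e-4):
        # One natural gradient step: theta <- theta - lr * F^{-1} grad.
        # Solving the linear system "raises the index" of the differential
        # using the Fisher metric; the damping term just keeps the solve
        # well-conditioned.
        F = fisher + damping * np.eye(len(theta))
        nat_grad = np.linalg.solve(F, grad)
        return theta - lr * nat_grad

    # Example with the univariate Gaussian in (mu, sigma) coordinates.
    theta = np.array([0.5, 1.3])                              # (mu, sigma)
    fisher = np.diag([1.0 / theta[1]**2, 2.0 / theta[1]**2])  # Fisher matrix
    grad = np.array([0.2, -0.1])                              # placeholder gradient
    print(natural_gradient_step(theta, grad, fisher))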

Cramer-Rao bound

The Fisher metric also gives a nice intuition for the Cramer-Rao bound, which I used to find pretty mysterious. What it's really saying is that two distributions need to be sufficiently dissimilar for you to be able to distinguish them. When you estimate a distribution from data, you're likely to confuse it with something else within its Fisher ball.
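In symbols, for an unbiased estimator of theta built from N i.i.d. samples, the bound reads

\[
\mathrm{Cov}(\hat{\theta}) \succeq \frac{1}{N}\, F(\theta)^{-1},
\]

where the inequality is in the positive semidefinite sense: no estimator's covariance ellipsoid can shrink below the (scaled) inverse Fisher metric, which is exactly the "you can't distinguish distributions inside a small Fisher ball" intuition.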

Jeffreys prior

In Bayesian parameter estimation, you face the question of how to choose a prior distribution. If you have no knowledge specific to the problem, you might want to choose an uninformative prior. The Jeffreys prior is one such choice, motivated by invariance to parameterization. (Unless you believe your parameters are meaningful, you probably don't want your prior to depend on the parameterization.)

Here's another heuristic motivation for the Jeffreys prior. If you wanted to put an uninformative prior on a Euclidean space, you might choose the uniform distribution because it assigns equal probabilities to equal volumes. On manifolds, volumes are computed by integrating a volume form. There's a natural way to construct a volume form from a Riemannian metric, and the Jeffreys prior is simply the volume form constructed from the Fisher metric. Hence, it assigns equal probabilities to equal volumes. (This automatically explains why it's invariant to parameterization.)
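Concretely, the Jeffreys prior has density proportional to the square root of the determinant of the Fisher information matrix. As a worked example, the Bernoulli distribution with parameter p has Fisher information 1/(p(1-p)), so

\[
\pi(\theta) \propto \sqrt{\det F(\theta)}, \qquad
\pi(p) \propto \frac{1}{\sqrt{p\,(1-p)}},
\]

which is the Beta(1/2, 1/2) distribution. If you reparameterize the Bernoulli family and apply the same recipe, you get the same distribution over distributions.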