Differential geometry for machine learning

Created by: Roger Grosse
Intended for: machine learning researchers

About this roadmap

Most of Metacademy is geared towards learning things at a comparable level to a university course. I'm writing this roadmap as kind of an experiment, to see if Metacademy can be used to convey cutting-edge research topics, things which most experts in the field aren't already familiar with.

Accordingly, it is more advanced and specialized than the other Metacademy roadmaps. If you're looking for a broader understanding of the field, the Bayesian machine learning and deep learning roadmaps are of more general interest. Or, if you're just starting out, check out "Level-up your machine learning."

This roadmap covers a bunch of neat connections between differential geometry and machine learning. As far as I know, it's pretty hard to find them in a textbook or a course, since the number of people who have the requisite background in both fields is fairly small. Teaching basic differential geometry concepts in an ML course would be too much of a detour. But if you already have the ML background, you'll find that the number of additional differential geometry concepts you need to learn is pretty minimal -- only a fraction of a semester course.

If you are new to Metacademy, you can find a bit more about the structure and motivation here. Links to Metacademy concepts are shown in red; these will give you a full learning plan for the concept, assuming only high school calculus. This roadmap itself just provides enough context and motivation to understand why you might want to learn about the topics; to learn them for real, you will need to follow the links and do the readings. These learning plans automatically get updated as new information is added to Metacademy. External links are shown in green; you're more or less on your own here, though we try to fill in background links where we can. There's no need to go through this roadmap linearly; you can follow whatever you need or find interesting. The learning plans will fill in the background.

This roadmap covers a variety of models and algorithms from machine learning, all of which you can implement without understanding the underlying differential geometry. But this background gives additional insights into the algorithms, and hopefully makes some of the mysterious-looking formulas appear more intuitive. It may also suggest ways to generalize these algorithms which wouldn't be apparent if you've just memorized the formulas.

The basics

The central object of study in differential geometry is the differentiable manifold. A manifold is essentially a space you can cover with coordinate charts, which are continuous, invertible mappings (with continuous inverses) onto open subsets of R^n. If the charts are chosen such that the mapping between any pair of overlapping charts is differentiable, then it's a differentiable manifold.
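
In symbols (standard notation, not tied to any particular text), the chart-compatibility condition reads:

```latex
% Charts \varphi_\alpha : U_\alpha \to \mathbb{R}^n covering the manifold M.
% M is a differentiable manifold if every transition map between overlapping charts is differentiable:
\varphi_\beta \circ \varphi_\alpha^{-1} :
  \varphi_\alpha(U_\alpha \cap U_\beta) \to \varphi_\beta(U_\alpha \cap U_\beta)
\quad \text{is differentiable for all overlapping charts } \varphi_\alpha, \varphi_\beta .
```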

Why do we care about manifolds? Often, we work with objects such as functions or probability distributions, which we need to parameterize in order to do computations on them. For instance, we might parameterize a Gaussian distribution in terms of its mean and standard deviation. But often there are multiple alternative parameterizations; for instance, Gaussians are often parameterized in information form. In other cases, there might not even be an obvious parameterization.
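
To make the Gaussian example concrete, here are the two parameterizations side by side (standard notation; h and lambda are just labels for the information-form parameters):

```latex
% Moment parameterization: mean \mu and standard deviation \sigma
p(x \mid \mu, \sigma) \;\propto\; \exp\!\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)
% Information form: precision \lambda = 1/\sigma^2 and precision-weighted mean h = \mu/\sigma^2
p(x \mid h, \lambda) \;\propto\; \exp\!\left( h\,x - \frac{\lambda}{2} x^2 \right)
```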

In these ambiguous cases, we'd often like to make sure that the algorithm we use doesn't depend on some arbitrary parameterization. This is often a problem in practice. For instance, gradient descent is not invariant to parameterization, and much effort has gone into designing clever reparameterizations that make it more efficient.
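
To see the non-invariance concretely, here is a minimal sketch (the setup and numbers are invented for illustration): the same objective and the same starting distribution, but one gradient step taken in two different parameterizations of the variance lands on two different distributions.

```python
import numpy as np

# Objective: 0.5 * (variance - target)^2 for a Gaussian's variance.
# Parameterization 1 uses the variance v directly; parameterization 2 uses s = log(v).
target = 2.0
lr = 0.1
v0 = 1.0                                       # same starting distribution in both cases

# One gradient step on the variance directly.
grad_v = v0 - target                           # d/dv [0.5 * (v - target)^2]
v_new = v0 - lr * grad_v

# One gradient step on the log-variance.
s0 = np.log(v0)
grad_s = (np.exp(s0) - target) * np.exp(s0)    # chain rule through v = exp(s)
s_new = s0 - lr * grad_s

print("step in v-space lands at variance", v_new)              # 1.1
print("step in log-v space lands at variance", np.exp(s_new))  # ~1.105: a different distribution
```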

Differential geometry gives us a way to construct natural quantities, i.e. ones which are invariant to parameterization. One can verify naturality on a case-by-case basis by crunching through change-of-basis formulas; in fact, this is how some of the basic objects in differential geometry are checked. But this is tedious, and natural objects are often highly non-obvious.

What differential geometry gives us is a set of high-level abstractions which allow us to construct natural objects directly. These include vector fields, tensor fields, Riemannian metrics, and differential forms, all discussed below. If you build an object out of these primitives, it's automatically natural.

Basically everything in differential geometry depends on the tangent bundle, the cotangent bundle, and tensor fields, so you should learn about these first. The biggest hurdle to clear is wrapping your head around the tangent bundle -- it's not obvious how to define the "tangent space" at a point when the manifold isn't embedded in some Euclidean space. There are various equivalent definitions which can be a bit abstract and non-intuitive. But after that, the definitions of cotangent spaces and tensors follow pretty easily.

In machine learning, we often work with families of probability distributions. We can view a family of distributions as a manifold, called a statistical manifold, which gives a way of constructing natural objects.

Fisher metric

We often need a notion of distance between two points on a manifold. Unfortunately, there's no unique way to do this if all we have is a manifold. We need to assign an additional structure called a Riemannian metric, which (despite the name) is really an inner product on the tangent space at each point. Recall that inner products (e.g. the dot product) let us define notions like orthogonality, angles, and the length of a vector. Riemannian metrics let us define these things on tangent spaces.

The Riemannian metric which is typically used for statistical manifolds is the Fisher metric. Its coordinate representation is simply the Fisher information matrix. What on earth does Fisher information have to do with distance? Well, KL divergence is the most commonly used measure of dissimilarity between distributions (even though it isn't actually a distance metric). The Fisher information matrix is simply the Hessian of KL divergence at the point where two distributions are equal. Therefore, it gives a quadratic form which acts as a squared distance between two similar distributions -- exactly what we'd want for a Riemannian metric. I have a blog post which explains this in more detail.
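
As a sanity check of that claim, here is a small sketch (the code and helper names are mine) which computes the Hessian of the closed-form Gaussian KL divergence by finite differences and compares it to the analytic Fisher information:

```python
import numpy as np

def kl_gauss(mu0, s0, mu1, s1):
    """KL( N(mu0, s0^2) || N(mu1, s1^2) ), closed form."""
    return np.log(s1 / s0) + (s0**2 + (mu0 - mu1)**2) / (2 * s1**2) - 0.5

def kl_from_base(theta, base=(0.0, 1.0)):
    mu0, s0 = base
    return kl_gauss(mu0, s0, theta[0], theta[1])

def hessian(f, x, eps=1e-3):
    """Central-difference Hessian of a scalar function."""
    n = len(x)
    H = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            e_i, e_j = np.eye(n)[i] * eps, np.eye(n)[j] * eps
            H[i, j] = (f(x + e_i + e_j) - f(x + e_i - e_j)
                       - f(x - e_i + e_j) + f(x - e_i - e_j)) / (4 * eps**2)
    return H

theta0 = np.array([0.0, 1.0])   # expand the KL divergence around the base distribution itself
print(hessian(kl_from_base, theta0))
# ~ [[1, 0], [0, 2]]: the Fisher information diag(1/sigma^2, 2/sigma^2) at sigma = 1
```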

Natural gradient

Probably the most common application of the Fisher metric in machine learning is the natural gradient, an analogue of the gradient which is invariant to parameterization. This is a useful property to have for models like neural nets where the relationship between the parameters and the distribution can be very complex, making ordinary gradient descent updates unstable.

In differential geometry terms, what we normally call the gradient is really the differential. The differential is a covector. A small change to the parameters, on the other hand, is a vector. Trying to equate the two is a type error, since they're completely different mathematical objects. But the Riemannian metric gives a way of converting a covector into a vector, known as "raising indices." The natural gradient is just the differential with its index raised by the Fisher metric. Since it's constructed from natural objects, it's automatically invariant to parameterization.
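
As a minimal sketch, assuming the analytic Fisher matrix of a univariate Gaussian in (mu, sigma) coordinates, the natural gradient update is just the ordinary gradient preconditioned by the inverse Fisher matrix:

```python
import numpy as np

def fisher_gaussian(mu, sigma):
    """Fisher information of N(mu, sigma^2) in (mu, sigma) coordinates."""
    return np.diag([1.0 / sigma**2, 2.0 / sigma**2])

def natural_gradient_step(theta, grad, lr=0.1):
    """theta <- theta - lr * F(theta)^{-1} grad, i.e. the differential with its index raised."""
    F = fisher_gaussian(*theta)
    return theta - lr * np.linalg.solve(F, grad)

theta = np.array([0.0, 3.0])   # a wide Gaussian
grad = np.array([1.0, 1.0])    # gradient of some loss with respect to (mu, sigma)
print(natural_gradient_step(theta, grad))
# The wide Gaussian takes a much larger parameter step than a narrow one would,
# because the same parameter change moves a wide distribution much less.
```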

Cramer-Rao bound

The Fisher metric also gives a nice intuition for the Cramer-Rao bound, which I used to find pretty mysterious. What it's really saying is that two distributions need to be sufficiently dissimilar for you to be able to distinguish them. When you estimate a distribution from data, you're likely to confuse it with something else within its Fisher ball.

Jeffreys prior

In Bayesian parameter estimation, you face the question of how to choose a prior distribution. If you have no knowledge specific to the problem, you might want to choose an uninformative prior. The Jeffreys prior is one such choice, motivated by invariance to parameterization. (Unless you believe your parameters are meaningful, you probably don't want your prior to depend on the parameterization.)

Here's another heuristic motivation for the Jeffreys prior. If you wanted to put an uninformative prior on a Euclidean space, you might choose the uniform distribution because it assigns equal probabilities to equal volumes. On manifolds, volumes are computed by integrating a volume form. There's a natural way to construct a volume form from a Riemannian metric, and the Jeffreys prior is simply the volume form constructed from the Fisher metric. Hence, it assigns equal probabilities to equal volumes. (This automatically explains why it's invariant to parameterization.)
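
A standard worked example: for a Bernoulli distribution with parameter theta, the Fisher information is 1/(theta(1 - theta)), so the Jeffreys prior (the square root of the determinant of the Fisher metric) is

```latex
p(\theta) \;\propto\; \sqrt{\det I(\theta)} \;=\; \theta^{-1/2} (1 - \theta)^{-1/2},
```

i.e. a Beta(1/2, 1/2) distribution.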

Exponential families

Exponential families are a class of probability distributions that includes many widely used ones such as the Gaussian, Bernoulli, multinomial, gamma, and Poisson distributions. What's special about exponential families is that the distributions are log-linear in a vector of sufficient statistics. Every exponential family has two canonical parameterizations: the natural parameters, and the moments, or expected sufficient statistics. (For instance, in a multivariate Gaussian distribution, the moments correspond to the first and second moments of the distribution, and the natural parameters are the information form representation.) Maximum likelihood has a particularly elegant form for exponential families: simply set the model moments equal to the empirical moments.
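
In symbols (standard notation), an exponential family and its two canonical parameterizations look like this; the last line is the moment-matching characterization of maximum likelihood mentioned above:

```latex
p(x \mid \eta) = h(x) \exp\!\big( \eta^\top T(x) - A(\eta) \big)
  \qquad \text{(natural parameters } \eta \text{, sufficient statistics } T\text{)}

\mu = \mathbb{E}_{\eta}[T(x)] = \nabla A(\eta)
  \qquad \text{(moments)}

\text{maximum likelihood:} \quad \nabla A(\hat\eta) = \frac{1}{N} \sum_{i=1}^{N} T(x_i)
```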

Interestingly, Markov random fields (MRFs) are themselves a kind of exponential family. For discrete MRFs, the natural parameters are the log potentials, and the moments are the marginals. Much algorithmic research in MRFs is essentially a matter of converting between these two representations: inference starts with the natural parameters and computes the moments, whereas parameter learning starts with a set of empirical moments (computed from the data) and finds a set of natural parameters which match them. Both of these problems are intractable in the worst case, which shows that converting between parameterizations isn't always as easy as the simple examples like Gaussians and multinomials would suggest.

What's special about natural parameters and moments? Might there be other equally good coordinate systems? In differential geometry terms, natural parameters and moments are both flat coordinate systems. What this means is that there is a natural identification between the tangent spaces at different points. This isn't true for just any manifold: in general tangent spaces at different points aren't directly relatable, and the closest thing to an equivalence is parallel transport, which depends on the path taken. For flat spaces, you can just equate the coordinate representations of the vectors at two points. In a flat coordinate system, straight lines also have a natural interpretation as geodesics on the manifold; this is discussed in more detail in the next section.

But natural parameters and moments aren't just flat spaces individually; they are dually flat, which is a very special kind of structure. Essentially it means that vectors in the natural parameters correspond to covectors in the moments, and vice versa. This leads to a lot of rather surprising identities, some of which I've used in my own work. The geometry of exponential families is discussed in Chapters 2 and 3 of Amari.

The other example of dually flat spaces (that I'm aware of) is the Legendre transform, which is discussed in Section 8.1 of Amari. (In a sense, this is really the same example, since the moments are the Legendre transform of the natural parameters.)
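
Concretely, the duality is the Legendre-transform relationship between the log partition function A (a function of the natural parameters) and its convex conjugate A* (a function of the moments); the gradients of the two functions invert each other, which is exactly the sense in which the two coordinate systems are dual:

```latex
A^*(\mu) = \sup_{\eta} \big( \eta^\top \mu - A(\eta) \big),
\qquad
\mu = \nabla A(\eta),
\qquad
\eta = \nabla A^*(\mu)
```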

Geodesics

A lot of algorithms require us to interpolate between two probability distributions, i.e. define a sequence of distributions that bridge from one to the other. Probably the most famous example is annealed importance sampling (AIS), but others include parallel tempering, tempered transitions, and path sampling.

The most obvious kind of interpolation is a straight line in the model parameters. But unless the parameters have some meaningful relationship to the model's predictions, it's not clear why averaging them should give anything sensible. I wouldn't expect averaging the parameters of a feed-forward neural net to be very meaningful.

On manifolds, you can define a natural analogue of straight lines called geodesics -- roughly, curves which keep heading in the "same direction." Defining which directions are the same requires more structure than just a Riemannian metric -- it requires an affine connection.

As discussed in Chapter 2 of Amari, there is a natural family of connections on statistical manifolds called the alpha family, which is parameterized by a scalar alpha. For exponential families, if you plug in alpha = 1, the geodesics correspond to straight lines in the natural parameters. Hence, averaging natural parameters actually has a natural meaning (independent of the parameterization). Note that averaging natural parameters also corresponds to taking geometric averages of the distributions.
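
Here is a small sketch of that correspondence for univariate Gaussians (the helper functions are mine): averaging the natural parameters of the two endpoints gives exactly the geometric average of the two densities, up to normalization.

```python
import numpy as np

def to_natural(mu, var):
    """Natural parameters of N(mu, var): (mu / var, -1 / (2 * var))."""
    return np.array([mu / var, -0.5 / var])

def from_natural(eta):
    var = -0.5 / eta[1]
    return float(eta[0] * var), float(var)

def geometric_average(p0, p1, beta):
    """Intermediate distribution proportional to p0^(1-beta) * p1^beta, via natural parameters."""
    eta = (1 - beta) * to_natural(*p0) + beta * to_natural(*p1)
    return from_natural(eta)

p0, p1 = (0.0, 1.0), (10.0, 1.0)       # two unit-variance Gaussians far apart
print(geometric_average(p0, p1, 0.5))  # mean 5.0, variance 1.0: the variance stays small
```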

In my paper "Annealing between distributions by averaging moments," we look at various alternatives for interpolating between exponential family distributions. (Here is some background for the paper.) While geometric averages seem pretty intuitive, we show that they have certain pathologies which can seriously degrade performance. For instance, the geometric average of the two black Gaussians in the figure below is the Gaussian shown in red; this seems like an odd way to interpolate.

[Figure: geometric averages path]

In the paper, we tried to avoid differential geometry terminology for readability, but I'll discuss the connections here. We show that, under certain idealized assumptions, you can measure the asymptotic performance of AIS in terms of a functional on the path known as the energy. Choosing an annealing schedule corresponds to moving along the path at varying speeds. The optimal (energy-minimizing) schedule covers equal distance in equal time, and the energy is the square of the Riemannian path length. (See Riemannian metrics.) The Riemannian path length (and hence the energy under the optimal schedule) can be visualized in terms of the number of Fisher balls the path crosses. Here is a visualization for a univariate Gaussian parameterized in terms of mean and standard deviation:

[Figure: visualization of Riemannian path lengths]
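
Schematically (this is my paraphrase in standard Riemannian notation, not the paper's exact statement), the two quantities being compared are the energy and the length of a path theta(beta) under the Fisher metric G:

```latex
\mathcal{E}[\theta] = \int_0^1 \dot\theta(\beta)^\top G_{\theta(\beta)}\, \dot\theta(\beta)\, d\beta,
\qquad
\ell[\theta] = \int_0^1 \sqrt{\dot\theta(\beta)^\top G_{\theta(\beta)}\, \dot\theta(\beta)}\; d\beta
```

By the Cauchy-Schwarz inequality, the energy is at least the squared path length, with equality exactly when the path is traversed at constant Riemannian speed -- i.e. under the optimal schedule.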

The energy-minimizing paths are the geodesics under the Levi-Civita connection; these are shown in green. (The Levi-Civita connection is also the alpha connection for alpha = 0.) Unfortunately, finding these geodesics seems to be very hard, likely much harder than the problem AIS was originally intended to solve (partition function estimation).

But there's yet another interesting kind of geodesic for exponential families, corresponding to alpha = -1, which consists of averaging the moments of the distributions. We figured, if this is a natural path, why not try using it in AIS? We found that it generally gives pretty sensible paths (shown in blue in the above figures), and seems to outperform geometric averages for estimating partition functions of RBMs. (The downside is that it is more expensive than geometric averages, since it requires MRF parameter estimation. But since the parameters don't need to be very accurate, you can use cheap approximations and still get good estimates faster than if you just used geometric averages.)

Here's the moment averages path for the same multivariate Gaussians as above. Notice how it expands the variance in the direction connecting the two means, which increases the overlap between successive distributions:

[Figure: moment averages path]
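
For a univariate analogue, here is a small sketch (the helper functions are mine) of the moment-averages construction; notice how the intermediate variance grows along the line connecting the two means, in contrast to the geometric-averages sketch above.

```python
import numpy as np

def to_moments(mu, var):
    """Expected sufficient statistics of N(mu, var): (E[x], E[x^2])."""
    return np.array([mu, var + mu**2])

def from_moments(m):
    mu = m[0]
    return float(mu), float(m[1] - mu**2)

def moment_average(p0, p1, beta):
    """Intermediate distribution obtained by averaging the moments of the endpoints."""
    m = (1 - beta) * to_moments(*p0) + beta * to_moments(*p1)
    return from_moments(m)

p0, p1 = (0.0, 1.0), (10.0, 1.0)    # the same two unit-variance Gaussians as before
print(moment_average(p0, p1, 0.5))  # mean 5.0, variance 26.0: the variance expands
```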

Incidentally, Theorem 2 of the paper is the most surprising mathematical result I've ever encountered in my research. It shows that the geometric and moment average paths, different as they might be, always have exactly the same energy! The proof is very short, but I still have no intuition for why it should be true. There has to be something deep going on here...

Hamiltonian flows

As I mentioned, gradient descent can be unstable because it depends on the parameterization, and natural gradient is one way around this. But another trick to speed up gradient descent is to use momentum. This is popular because it's very cheap and simple to implement.

But in Bayesian machine learning, you often need to sample from a posterior rather than maximize a function. Metropolis-Hastings is a very general technique for getting approximate samples from distributions. However, it requires specifying a proposal distribution, and the most naive choices can be extremely inefficient. One of the most successful M-H algorithms is Hamiltonian Monte Carlo (HMC), where the proposal distributions correspond to a numerical simulation of Hamiltonian dynamics -- essentially gradient descent with momentum. HMC is the workhorse behind Stan, a probabilistic programming language. The best reference on HMC is probably Radford Neal's tutorial.

Ordinarily in M-H, there's a tradeoff where small steps have high acceptance probabilities but don't move very far, and large steps have low acceptance probabilities. HMC proposes very large steps, so we might ordinarily expect it to suffer from extremely small acceptance probabilities. Surprisingly, it can be shown that the acceptance probability is close to 1 if the dynamics are simulated to high enough precision -- this is the insight that makes it work. This is a consequence of two facts: that Hamiltonian dynamics conserves energy, and that it preserves volume in phase space.
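
Here is a minimal sketch of the leapfrog integrator at the core of HMC (the names and setup are illustrative; see Neal's tutorial for the full algorithm, including momentum resampling and the Metropolis accept/reject step). The point of the example is the approximate energy conservation:

```python
import numpy as np

def leapfrog(q, p, grad_log_p, step_size, n_steps):
    """Simulate Hamiltonian dynamics for H(q, p) = -log p(q) + 0.5 * p.p."""
    q, p = q.copy(), p.copy()
    p += 0.5 * step_size * grad_log_p(q)        # half step for the momentum
    for _ in range(n_steps - 1):
        q += step_size * p                      # full step for the position
        p += step_size * grad_log_p(q)          # full step for the momentum
    q += step_size * p
    p += 0.5 * step_size * grad_log_p(q)        # final half step for the momentum
    return q, p

# Standard Gaussian target: log p(q) = -0.5 * q.q up to a constant.
def grad_log_p(q):
    return -q

def hamiltonian(q, p):
    return 0.5 * q @ q + 0.5 * p @ p            # negative log joint, up to a constant

q0, p0 = np.array([3.0]), np.array([1.0])
q1, p1 = leapfrog(q0, p0, grad_log_p, step_size=0.1, n_steps=20)
print(hamiltonian(q0, p0), hamiltonian(q1, p1))  # nearly equal: energy is approximately conserved
```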

Similarly to the parameter invariance results discussed previously, these two facts can be verified by inspection, but they also follow from higher-level abstractions. Suppose we have a manifold M which gives the position of a particle. The cotangent bundle T*M can be viewed as a manifold in its own right; in particular, it is a symplectic manifold. The canonical coordinate system for T*M is a combination of position coordinates q, which identify the point in M, and momentum coordinates p, which identify a cotangent vector at q.

Suppose we have a function f on T*M, which we'll refer to as the Hamiltonian. (In the context of HMC, this corresponds to the negative log probability.) Recall that the differential of f is a covector field, so in order to get an update direction, we need to raise its index. When we raise it using the symplectic form, we get what's called the Hamiltonian vector field. The flow for this vector field, the Hamiltonian flow, corresponds to Hamiltonian dynamics. Interestingly, raising the index using the symplectic form is very different from raising it using a Riemannian metric; whereas the latter gives a descent direction, the former gives a direction which preserves f. I.e., the Hamiltonian flow moves along the level sets of f.
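
In canonical coordinates (q, p), raising the index of df with the canonical symplectic form gives the familiar Hamilton's equations (a standard result, stated schematically):

```latex
\omega = \sum_i dq^i \wedge dp_i,
\qquad
\dot q^i = \frac{\partial f}{\partial p_i},
\qquad
\dot p_i = -\frac{\partial f}{\partial q^i}
```

In particular, plugging these into the chain rule gives df/dt = 0 along the flow, which is the sense in which the flow stays on the level sets of f.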

The fact that Hamiltonian flows preserve the Hamiltonian corresponds to the first property, conservation of energy. Similarly, it preserves volume, because it preserves the symplectic form and the volume form can be constructed from the symplectic form. Since these properties hold exactly for Hamiltonian flows, and HMC approximates Hamiltonian flows, we'd expect them to hold approximately in the context of HMC. (In fact, volume preservation holds exactly, so it's only energy conservation that's approximate.)

All of this is explained in this paper by Michael Betancourt.

While HMC is powerful, it still suffers from its dependence on the parameterization. Riemannian manifold HMC is an algorithm that combines the insights of both natural gradient and HMC.