Differential geometry for machine learning

Created by: Roger Grosse
Intended for: machine learning researchers

In machine learning, we're often interested in learning a function or a distribution that we can use to make predictions. In order to learn and apply such a function or distribution, we need to somehow represent it, i.e. put it into a form where we can do computations on it. In parametric modeling, we typically do this by choosing a finite set of parameters which fully characterize the predictor. (Nonparametric modeling, where the model is unbounded in size, often requires coming up with a clever representation in terms of the training examples.) Much of the time, the representation is merely a computational tool, and what we ultimately care about is the function or distribution itself.

As we'll see below, the performance of many learning algorithms is very sensitive to the choice of representation. A lot of work has been devoted to coming up with clever reparameterizations which boost the performance of an algorithm. (As a simple example, consider how often you need to standardize your data to have zero mean and unit variance!) But this seems odd: if you're working with the same class of predictors, why should the algorithm do something different just because you're describing those predictors differently?

In fact, lots of algorithms have been carefully designed to be independent of the representation. Many of the biggest advances in the last 30 years of machine learning research can be seen as languages for talking directly about the predictors themselves without reference to an underlying parameterization. Examples include kernels, Gaussian processes, Bayesian nonparametrics, and VC dimension.

Differential geometry is all about constructing things which are independent of the representation. You treat the space of objects (e.g. distributions) as a manifold, and describe your algorithm in terms of things that are intrinsic to the manifold itself. While you ultimately need to use some coordinate system to do the actual computations, the higher-level abstractions make it easier to check that the objects you're working with are intrinsically meaningful. This roadmap is intended to highlight some examples of models and algorithms from machine learning which can be interpreted in terms of differential geometry.

Most of the content in this roadmap belongs to information geometry, the study of manifolds of probability distributions. The best reference on this topic is probably Amari and Nagaoka's Methods of Information Geometry. If you're working through that book, you will hopefully still find this roadmap useful for some context and to track down background readings.

About this roadmap

So far, Metacademy has mostly been geared towards learning things at a comparable level to a university course. I'm writing this roadmap as kind of an experiment, to see if Metacademy can be used to convey cutting-edge research topics and deep connections between different fields -- things which most experts in the field aren't already familiar with.

Accordingly, it is more advanced and specialized than the other Metacademy roadmaps. If you're looking for a broader understanding of the field, the Bayesian machine learning and deep learning roadmaps are of more general interest. Or, if you're just starting out, check out "Level-up your machine learning."

This roadmap covers a bunch of neat connections between differential geometry and machine learning. As far as I know, it's pretty hard to find them in a textbook or a course, since the number of people who have the requisite background in both fields is fairly small. Teaching basic differential geometry concepts in an ML course would be too much of a detour. But if you already have the ML background, you'll find that the number of additional differential geometry concepts you need to learn is pretty minimal -- only a fraction of a semester course.

If you are new to Metacademy, you can find a bit more about the structure and motivation here. Links to Metacademy concepts are shown in red; these will give you a full learning plan for the concept, assuming only high school calculus. These learning plans automatically get updated as new information is added to Metacademy. External links are shown in green; you're more or less on your own here, though we try to fill in background links where we can. There's no need to go through this roadmap linearly; you can follow whatever you need or find interesting. The learning plans will fill in the background. This roadmap itself just provides enough context and motivation to understand why you might want to learn about the topics; to learn them for real, you will need to follow the links and do the readings.

All of the algorithms and models covered in this roadmap can be implemented without understanding the underlying differential geometry. But this background gives additional insights into the algorithms, and hopefully makes some of the mysterious-looking formulas appear more intuitive. It may also suggest ways to generalize these algorithms which wouldn't be apparent if you've just memorized the formulas.

Motivation: why representation independence?

Why do we care if an algorithm's behavior depends on the representation of the predictor? In order to motivate this, I'll discuss a simple example, one which has nothing to do with differential geometry: polynomial regression. Suppose we wish to fit a cubic polynomial to predict a scalar target y as a function of a scalar input x. We can do this with linear regression, using the basis function expansion phi(x) = (1, x, x^2, x^3). For a small dataset, we might get results such as the following:

[Figures: polynomial regression; polynomial regression, shifted]

The two figures show the same data, except that the one on the right has been shifted so that the x-axis goes from 8 to 12 rather than from -2 to 2. Observe that the fitted polynomial is the same (except shifted by the same amount). This happens because polynomial regression chooses the polynomial which minimizes the squared error, and the class of cubic polynomials is closed under shifting, i.e. replacing p(x) with p(x - 10), where p denotes the polynomial, yields another cubic. Notice that the fit on the right has very large coefficients, because the shifting operation tends to blow up the coefficients when it is explicitly expanded out. This has no impact on the algorithm, at least in exact arithmetic. (Numerical stability issues turn out to be very significant, but I'll ignore them in this roadmap.)

Linear regression is invariant to parameterization, but the story changes once we try to regularize the model. The most common regularized version of linear regression is ridge regression, where we penalize the squared Euclidean norm of the coefficients. Now observe what happens when we try to fit the same data as above:

[Figures: polynomial regression, regularized; polynomial regression, shifted, regularized]

Now the fits are completely different! Ridge regression hates large coefficients, so it tries really hard to shrink the coefficients in the second model, even at the expense of predictive accuracy. This is a bit troubling, since regularization is intended to penalize the "complexity" of the predictor, and it doesn't seem like shifting your data should change the relative complexity of different polynomial models.
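To see this concretely, here is a minimal sketch (using NumPy, with made-up data rather than the data in the figures, and an illustrative penalty strength) that fits a cubic by ordinary least squares and by ridge regression on both the original and the shifted inputs. The unregularized predictions agree up to the shift; the ridge predictions do not, because the penalty depends on the coefficients rather than on the polynomial itself.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-2, 2, size=20)
y = 0.5 * x**3 - x + rng.normal(scale=0.5, size=20)

def design(x):
    # Basis function expansion phi(x) = (1, x, x^2, x^3)
    return np.vander(x, 4, increasing=True)

def fit(x, y, lam=0.0):
    # Ordinary least squares (lam = 0) or ridge regression (lam > 0)
    Phi = design(x)
    return np.linalg.solve(Phi.T @ Phi + lam * np.eye(4), Phi.T @ y)

w_plain, w_plain_shift = fit(x, y), fit(x + 10, y)
w_ridge, w_ridge_shift = fit(x, y, lam=1.0), fit(x + 10, y, lam=1.0)

x_test = np.linspace(-2, 2, 5)
# Unregularized: predictions agree (up to numerical error) after shifting.
print(design(x_test) @ w_plain)
print(design(x_test + 10) @ w_plain_shift)
# Ridge: predictions differ, because the penalty sees only the coefficients.
print(design(x_test) @ w_ridge)
print(design(x_test + 10) @ w_ridge_shift)
```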

The problem of representation invariance when fitting smooth functions has been studied pretty extensively. Smoothing splines are a kind of regression model motivated in terms of a variational optimization problem formulated on the functions themselves. More modern approaches include kernel ridge regression and Gaussian process regression. See this interesting paper for a discussion of how Gaussian process regression incorporates a sophisticated form of Occam's Razor which accounts for the complexity of the functions themselves.

In this particular example, you can get rid of the pathology by centering your data at 0. But there aren't any comparably easy fixes for many of the problems discussed later in this roadmap. In fact, if your algorithm is performing badly, there's often little to suggest that the problem lies with the representation.

The basics

The central object of study in differential geometry is the differentiable manifold. A manifold is essentially a space you can cover with coordinate charts: continuous, invertible mappings (with continuous inverses) to open subsets of R^n. If the charts are chosen such that the mapping between any pair of overlapping charts is differentiable, then it's a differentiable manifold.

Differential geometry gives us a way to construct natural objects, i.e. ones which are intrinsic to the manifold itself. Naturality has a precise meaning in category theory, but effectively it means that they transform in the "right way" when you map them from one space to another. The most important case for us is when the transformation is a change of coordinates: if you compute an object (e.g. a weight update) in some coordinate chart A, and then transform it to another chart B, you get the same result as if you had computed it in B originally. One can verify naturality on a case-by-case basis by crunching through change-of-basis formulas; in fact, this is how some of the basic objects in differential geometry are checked. But this is tedious, and it doesn't suggest how one would come up with the natural objects in the first place.

What differential geometry gives us is a set of high-level abstractions which allow us to construct natural objects directly. These include vector fields, tensor fields, Riemannian metrics, and differential forms, all discussed below. If you build an object out of these primitives in certain well-defined ways, it's automatically natural.

How to start learning about differential geometry? Basically everything depends on the tangent bundle, the cotangent bundle, and tensor fields, so you should learn about these first. The biggest hurdle is making sense of the tangent bundle -- it's not obvious how to define the "tangent space" at a point when the manifold isn't embedded in some Euclidean space. There are several constructions, which can be a bit abstract and non-intuitive. But after that, the definitions of cotangent spaces and tensors follow pretty easily.

In machine learning, we often work with families of probability distributions. We can view a family of distributions as a manifold, called a statistical manifold, which gives a way of constructing natural objects. Most of the remainder of this roadmap focuses on statistical manifolds.

Fisher metric

We often need a notion of distance between two points on a manifold. Unfortunately, there's no unique way to do this if all we have is a manifold. We need to assign an additional structure called a Riemannian metric, which (despite the name) is really an inner product on the tangent space at each point. Recall that inner products (e.g. the dot product) let us define notions like orthogonality, angles, and the length of a vector. Riemannian metrics let us define these things on tangent spaces. As a consequence, they also let us define the length of a path, which is why they are called metrics.
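In coordinates (standard notation, not specific to any particular text), a Riemannian metric assigns to each point theta a positive definite matrix G(theta), which defines inner products of tangent vectors and lengths of paths:

```latex
\langle u, v \rangle_{\theta} = u^\top G(\theta)\, v,
\qquad
\mathrm{length}(\gamma) = \int_0^1 \sqrt{\dot{\gamma}(t)^\top G(\gamma(t))\, \dot{\gamma}(t)}\; dt .
```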

The Riemannian metric which is typically used for statistical manifolds is the Fisher metric. Its coordinate representation is simply the Fisher information matrix. What on earth does Fisher information have to do with distance? Well, KL divergence is the most commonly used measure of dissimilarity between distributions (even though it isn't actually a distance metric). The Fisher information matrix is simply the Hessian of KL divergence at the point where two distributions are equal. Therefore, it gives a quadratic form which acts as a squared distance between two similar distributions -- exactly what we'd want for a Riemannian metric. I wrote a blog post which explains this in more detail.
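In symbols (standard definitions, with theta indexing the family p(x; theta)), the Fisher information matrix and its relation to KL divergence are:

```latex
F_{ij}(\theta) = \mathbb{E}_{x \sim p(\cdot\,;\theta)}\!\left[ \frac{\partial \log p(x;\theta)}{\partial \theta_i}\, \frac{\partial \log p(x;\theta)}{\partial \theta_j} \right],
\qquad
\mathrm{D}_{\mathrm{KL}}\big( p_{\theta} \,\big\|\, p_{\theta + d\theta} \big) \approx \tfrac{1}{2}\, d\theta^\top F(\theta)\, d\theta .
```

The first-order term in the expansion of the KL divergence vanishes, which is why the quadratic form gives the leading-order behavior.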

Natural gradient

Probably the most common application of the Fisher metric in machine learning is the natural gradient, an analogue of the gradient which is independent of the parameterization. This is a useful property to have for models like neural nets where the relationship between the parameters and the distribution can be very complex, making ordinary stochastic gradient descent updates unstable. Essentially, the updates oscillate back and forth in directions where the distribution changes rapidly, and make only plodding progress in directions where it changes slowly.

In differential geometry terms, what we normally call the gradient is really the differential. (See cotangent bundle.) The differential is a covector. A small update to the parameters, on the other hand, is a vector. Trying to equate the two is a type error, since they're completely different mathematical objects. But the Riemannian metric gives a way of converting a covector into a vector, known as "raising indices." The natural gradient is just the differential with its index raised by the Fisher metric. Since it's constructed from natural objects, it's automatically independent of the parameterization.
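As a rough sketch (assuming you can compute or approximate the gradient of the loss and the Fisher matrix at the current parameters; the learning rate and damping term are illustrative), the update looks like:

```python
import numpy as np

def natural_gradient_step(theta, grad_loss, fisher, lr=0.1, damping=1e-4):
    """One natural gradient update: raise the index of the differential
    (the ordinary gradient) using the Fisher metric, i.e. multiply by F^{-1}."""
    F = fisher + damping * np.eye(len(theta))  # damping for numerical stability
    nat_grad = np.linalg.solve(F, grad_loss)
    return theta - lr * nat_grad
```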

Cramer-Rao inequality

One of the first applications of the Fisher metric was the Cramer-Rao inequality. The "Fisher ball" around some distribution p roughly corresponds to the set of distributions which are similar to p, in the sense of having small KL divergence. The Cramer-Rao inequality essentially says that if the data are drawn from p, then no matter how you estimate the distribution, you're probably going to confuse it with something else in the Fisher ball around p.
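For reference, here is the standard statement (for an unbiased estimator computed from n i.i.d. samples, where the ordering is the positive semidefinite one):

```latex
\mathrm{Cov}\big( \hat{\theta} \big) \;\succeq\; \frac{1}{n}\, F(\theta)^{-1} .
```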

Jeffreys prior

In Bayesian parameter estimation, you face the question of how to choose a prior distribution. If you have no knowledge specific to the problem, you might want to choose an uninformative prior. The most obvious choice is the uniform distribution, and this is indeed the appropriate prior for location parameters (such as the mean of a Gaussian). But a uniform distribution in one coordinate system could be very much non-uniform in another one. Consider, e.g., the difference between a uniform prior over the standard deviation of a Gaussian and a uniform prior over its precision. Evidently, uniform priors are not automatically uninformative!

The intuition behind the uniform distribution can be generalized, though. Consider why it seems like the obvious choice: it assigns equal probabilities to equal volumes. On manifolds, volumes are computed by integrating a volume form. There's a natural way to construct a volume form from a Riemannian metric, and when this construction is applied to the Fisher metric, you arrive at the Jeffreys prior. This prior assigns equal probability to equal volumes on the statistical manifold.
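Concretely, the Jeffreys prior is the standard formula

```latex
p(\theta) \;\propto\; \sqrt{ \det F(\theta) } .
```

As a sanity check, for a Bernoulli distribution with success probability theta this works out to p(theta) proportional to theta^(-1/2) (1 - theta)^(-1/2), i.e. a Beta(1/2, 1/2) distribution, which is not uniform in theta.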

Exponential families

Exponential families are a class of probability distributions that includes many widely used ones such as the Gaussian, Bernoulli, multinomial, gamma, and Poisson distributions. What's special about exponential families is that the distributions are log-linear in a vector of sufficient statistics. Every exponential family has two canonical parameterizations: the natural parameters, and the moments, or expected sufficient statistics. (For instance, in a multivariate Gaussian distribution, the moments correspond to the first and second moments of the distribution, and the natural parameters are the information form representation.) Maximum likelihood has a particularly elegant form for exponential families: simply set the model moments equal to the empirical moments.
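In symbols (standard notation, with sufficient statistics s(x), natural parameters theta, and log partition function A(theta)):

```latex
p(x;\theta) = h(x)\, \exp\!\big( \theta^\top s(x) - A(\theta) \big),
\qquad
\eta \;=\; \mathbb{E}_{\theta}\big[ s(x) \big] \;=\; \nabla A(\theta).
```

Maximum likelihood then amounts to solving nabla A(theta) = (1/N) sum_n s(x_n), i.e. matching the model moments to the empirical moments.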

Don't be fooled by these simple examples: Markov random fields (MRFs) are themselves a kind of exponential family. For discrete MRFs, the natural parameters are the log potentials, and the moments are the marginals. Much algorithmic research in MRFs is essentially a matter of converting between these two representations: inference starts with the natural parameters and computes the moments, whereas parameter learning starts with a set of empirical moments (computed from the data) and finds a set of natural parameters which match them. Both of these problems are intractable in the worst case, which shows that converting between parameterizations isn't always as easy as the simple examples like Gaussians and multinomials would suggest.

What's special about natural parameters and moments? Might there be other equally good coordinate systems? In differential geometry terms, natural parameters and moments are both flat coordinate systems. What this means is that there is a natural identification between the tangent spaces at different points. This isn't true for just any manifold: in general, tangent spaces at different points aren't directly relatable, and the closest thing to an equivalence is parallel transport, which depends on the path taken. In flat coordinate systems, you can directly equate the coordinate representations of vectors at two points. Straight lines also have a natural interpretation as geodesics on the manifold; this is discussed in more detail in the next section.

But natural parameters and moments aren't just flat spaces individually; they are dually flat, which is a very special kind of structure. Essentially it means that vectors in the natural parameters correspond to covectors in the moments, and vice versa. This leads to a lot of rather surprising identities, some of which I've used in my own work. The geometry of exponential families is discussed in Chapters 2 and 3 of Amari.

The other example of dually flat spaces (that I'm aware of) is the Legendre transform, which is discussed in Section 8.1 of Amari. (In a sense, this is really the same example, since the moments are the Legendre transform of the natural parameters.)
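In formulas (standard convex duality, continuing the notation above, with A* the convex conjugate of the log partition function):

```latex
\eta = \nabla A(\theta), \qquad \theta = \nabla A^{*}(\eta), \qquad A(\theta) + A^{*}(\eta) = \theta^\top \eta,
```

where the last identity holds when theta and eta describe the same distribution.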

A lot of machine learning problems involve simultaneously fitting a set of related distributions; for instance, topic modeling involves estimating distributions over words for each document in a collection. This can be done by jointly representing all the parameter vectors as a big matrix, and trying to factorize that matrix. But depending on which parameters you choose, you get a different model. Probabilistic latent semantic analysis (pLSA) and its Bayesian successor, latent Dirichlet allocation (LDA), factorize the moments representation, and have the interpretation that different topics combine additively in a document. Exponential family PCA, on the other hand, factorizes the matrix of natural parameters; because the probabilities are log-linear in this representation, different topics combine multiplicatively.

Geodesics

A lot of algorithms require us to interpolate between two probability distributions, i.e. define a sequence of distributions connecting the two such that each distribution is similar to its successor. Probably the most famous example is annealed importance sampling (AIS), but others include parallel tempering, tempered transitions, and path sampling.

The most obvious kind of interpolation is a straight line in the space of model parameters. But unless the parameters have some meaningful relationship to the model's predictions, it's not clear why averaging them should give anything sensible. I wouldn't expect averaging the parameters of a feed-forward neural net to be very meaningful.

On manifolds, you can define a natural analogue of straight lines called geodesics -- roughly, curves which keep heading in the "same direction." Defining which directions count as the "same" requires an additional structure called an affine connection. (A Riemannian metric determines one canonical connection, the Levi-Civita connection, but it is not the only reasonable choice, as we'll see next.)

As discussed in Chapter 2 of Amari, there is a natural family of connections on statistical manifolds called the alpha family, which is parameterized by a scalar alpha. For exponential families, if you plug in alpha = 1, the geodesics correspond to straight lines in the natural parameters. Hence, averaging natural parameters is intrinsically meaningful. In fact, it has a nice interpretation as taking geometric averages of the distributions.
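Concretely, for an exponential family (standard notation, with beta in [0, 1] indexing the path), a straight line in the natural parameters is the same thing as the geometric averages path:

```latex
\theta_\beta = (1-\beta)\, \theta_0 + \beta\, \theta_1
\quad\Longleftrightarrow\quad
p_{\beta}(x) \;\propto\; p_0(x)^{\,1-\beta}\, p_1(x)^{\,\beta} .
```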

In my paper "Annealing between distributions by averaging moments," we look at various schemes for interpolating between exponential family distributions, with the aim of improving AIS. (Here is some background for the paper.) In the paper, we tried to avoid differential geometry terminology for readability, but since this is a differential geometry roadmap, I'll focus on these connections.

While geometric averaging seems pretty intuitive, we show that it has certain pathologies which can seriously degrade performance. For instance, the geometric averages path connecting the following two black Gaussians is the path shown in red, where the dots represent equal time intervals. (The red Gaussian is the geometric mean.) This seems like an odd way to interpolate.

[Figure: geometric averages path]

We show that, under certain idealized assumptions, you can measure the asymptotic performance of AIS in terms of a functional on the path known as the energy. Choosing an annealing schedule corresponds to moving along the path at varying speeds. The optimal (energy-minimizing) schedule covers equal distance in equal time, and the energy is the square of the Riemannian path length. (See Riemannian metrics.) The Riemannian path length (and hence the energy under the optimal schedule) can be visualized in terms of the number of Fisher balls the path crosses. Here is a visualization of path lengths for a univariate Gaussian parameterized in terms of mean and standard deviation:

[Figure: visualization of Riemannian path lengths]
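In symbols (my notation, eliding constant factors that don't affect comparisons between paths): for a path theta(beta) with beta in [0, 1], the energy and the Riemannian path length are

```latex
\mathcal{E}[\theta] = \int_0^1 \dot{\theta}(\beta)^\top F(\theta(\beta))\, \dot{\theta}(\beta)\, d\beta,
\qquad
\ell[\theta] = \int_0^1 \sqrt{\dot{\theta}(\beta)^\top F(\theta(\beta))\, \dot{\theta}(\beta)}\; d\beta .
```

By Cauchy-Schwarz, the energy is at least the squared path length, with equality exactly when the path is traversed at constant Riemannian speed -- which is the "equal distance in equal time" schedule mentioned above.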

The energy-minimizing paths are the geodesics under the Levi-Civita connection; these are shown in green. (The Levi-Civita connection is also the alpha connection for alpha = 0.) Unfortunately, finding these geodesics appears to be hard, likely even harder than the problem AIS was originally intended to solve (partition function estimation).

But there's yet another interesting kind of geodesic for exponential families, corresponding to alpha = -1, which consists of averaging the moments of the distributions. We figured, if this is a natural path, why not try using it in AIS? We found that it generally gives pretty sensible paths (shown in blue in the above figures), and seems to outperform geometric averages for estimating partition functions of RBMs. (The downside is that it is more expensive than geometric averages, since it requires MRF parameter estimation. But since the parameters don't need to be very accurate, you can use cheap approximations and still get good estimates faster than if you just used geometric averages.)
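To make this concrete for the Gaussian case (a minimal sketch; the general case requires exponential family parameter estimation, as noted above), moment averaging interpolates the expected sufficient statistics E[x] and E[x x^T] and then converts back to a mean and covariance:

```python
import numpy as np

def moment_average_gaussian(mu0, Sigma0, mu1, Sigma1, beta):
    """Moment-averaging path between two Gaussians: average the expected
    sufficient statistics E[x] and E[x x^T], then convert back."""
    m = (1 - beta) * mu0 + beta * mu1                       # averaged E[x]
    S = (1 - beta) * (Sigma0 + np.outer(mu0, mu0)) \
        + beta * (Sigma1 + np.outer(mu1, mu1))              # averaged E[x x^T]
    return m, S - np.outer(m, m)                            # mean, covariance
```

Expanding this out gives Sigma_beta = (1 - beta) Sigma_0 + beta Sigma_1 + beta (1 - beta) (mu_1 - mu_0)(mu_1 - mu_0)^T, which is exactly the variance expansion along the direction connecting the two means described below.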

Here's the moment averages path for the same multivariate Gaussians as above. Notice how it expands the variance in the direction connecting the two means, which increases the overlap between successive distributions:

[Figure: moment averages path]

Incidentally, Theorem 2 of the paper is the most surprising mathematical result I've ever encountered in my research. It shows that the geometric and moment average paths, different as they might be, always have exactly the same energy! The proof is very short, but I still have no intuition for why it should be true. There has to be something deep going on here...

Hamiltonian flows

As I mentioned, gradient descent can be unstable because it depends on the parameterization, and natural gradient is one way around this. But natural gradient can be expensive and tricky to implement. A simple and widely used alternative is momentum, where each update is a weighted average of the gradient and the previous update. Effectively, the weight vector is viewed as a particle, and the gradient acts as a "force" which causes it to accelerate in some direction. This method is popular because it requires only a few extra lines of code and adds almost no computational overhead.
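For concreteness, here is the classical momentum update in a few lines (a generic sketch; the learning rate and decay constant are illustrative):

```python
import numpy as np

def momentum_step(theta, velocity, grad, lr=0.01, decay=0.9):
    """Classical momentum: the update is a decaying average of past gradients,
    so the parameters behave like a particle accelerated by the gradient 'force'."""
    velocity = decay * velocity - lr * grad
    return theta + velocity, velocity
```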

In Bayesian machine learning, you often wish to sample from a posterior distribution rather than maximize a function. Markov chain Monte Carlo (MCMC) is a very general class of techniques for drawing approximate samples. Sampling from high-dimensional distributions requires finding the important modes of the distribution, i.e. finding regions of high probability. For this reason, MCMC algorithms are often designed to behave like algorithms from the optimization literature. E.g., annealed importance sampling (as discussed above) was based on simulated annealing.

Hamiltonian Monte Carlo (HMC) is the MCMC analogue of gradient descent with momentum. It alternates between randomly sampling a momentum vector and performing a numerical simulation of Hamiltonian dynamics. HMC is one of the most successful MCMC algorithms, and is the workhorse behind Stan, a probabilistic programming language. The best reference on HMC is probably Radford Neal's tutorial.

Translating an optimization algorithm into the framework of MCMC is rarely straightforward. One major hurdle is to ensure that the dynamics are reversible, i.e. that they satisfy a certain set of equations called detailed balance. This can always be achieved using the Metropolis-Hastings (M-H) algorithm, which accepts or rejects each proposed transition with certain probabilities chosen to achieve reversibility. Unfortunately, unless the proposals are designed very carefully, the sampler will either suffer from extremely small acceptance probabilities or take extremely small steps.
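For reference (standard definitions, with pi the target distribution, T the transition operator, and q the proposal distribution), detailed balance and the M-H acceptance probability are:

```latex
\pi(x)\, T(x \to x') = \pi(x')\, T(x' \to x),
\qquad
a(x \to x') = \min\!\left( 1,\; \frac{\pi(x')\, q(x \mid x')}{\pi(x)\, q(x' \mid x)} \right).
```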

HMC is an M-H algorithm which proposes very large steps, so we might ordinarily expect it to suffer from extremely small acceptance probabilities. But -- and this is the insight that makes the algorithm work -- it can be shown that the acceptance probability is close to 1 if the dynamics are simulated to high enough precision. This result is a consequence of two facts: that Hamiltonian dynamics conserves energy, and that it preserves volume in phase space.
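Here is a bare-bones sketch of a single HMC transition, with an identity mass matrix and illustrative step size and trajectory length; log_p and grad_log_p are assumed to be supplied by the user, and none of this is tied to any particular implementation:

```python
import numpy as np

def hmc_step(q, log_p, grad_log_p, step_size=0.1, n_steps=20, rng=None):
    """One HMC transition: resample the momentum, simulate Hamiltonian
    dynamics with the leapfrog integrator, then Metropolis accept/reject."""
    if rng is None:
        rng = np.random.default_rng()
    p = rng.normal(size=q.shape)                # momentum ~ N(0, I)
    q_new, p_new = q.copy(), p.copy()
    # Leapfrog integration; the potential energy is U(q) = -log_p(q).
    p_new = p_new + 0.5 * step_size * grad_log_p(q_new)
    for _ in range(n_steps - 1):
        q_new = q_new + step_size * p_new
        p_new = p_new + step_size * grad_log_p(q_new)
    q_new = q_new + step_size * p_new
    p_new = p_new + 0.5 * step_size * grad_log_p(q_new)
    # Accept or reject based on the change in total energy H = U(q) + |p|^2 / 2.
    h_old = -log_p(q) + 0.5 * p @ p
    h_new = -log_p(q_new) + 0.5 * p_new @ p_new
    return q_new if rng.uniform() < np.exp(h_old - h_new) else q

# Hypothetical usage: sample from a standard 2D Gaussian.
# q = np.zeros(2)
# for _ in range(1000):
#     q = hmc_step(q, lambda q: -0.5 * q @ q, lambda q: -q)
```

The leapfrog integrator is used here because it is volume-preserving and time-reversible, which is what makes this simple accept/reject rule valid.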

Similarly to the parameter invariance results discussed previously, these two facts can be verified by inspection, but they also follow from higher-level abstractions. In particular, suppose we have a manifold M which gives the position of a particle. The cotangent bundle T*M can be viewed as a manifold in its own right. In general, the cotangent bundle can be parameterized in terms of a vector q which determines the point in M, and a second vector p which identifies a cotangent vector at q. In the context of Hamiltonian dynamics, the cotangent bundle is known as the phase space, and q and p correspond to the position and momentum coordinates, respectively.

When viewed as a manifold, the cotangent bundle can be assigned a particular differential form called a symplectic form. This structure makes it a symplectic manifold.
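In the (q, p) coordinates above, the canonical symplectic form on the cotangent bundle is

```latex
\omega = \sum_i dq_i \wedge dp_i .
```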

Suppose we have a function f on T*M, which we'll refer to as the Hamiltonian. In the context of HMC, this corresponds to the sum of a potential energy (the negative log probability) and a kinetic energy. Recall that the differential of f is a covector field, so in order to get an update direction, we need to raise its index. When we raise it using the symplectic form, we get what's called the Hamiltonian vector field. The flow for this vector field, the Hamiltonian flow, corresponds to Hamiltonian dynamics. Interestingly, raising the index using the symplectic form is very different from raising it using a Riemannian metric; whereas the latter gives an ascent direction, the former gives a direction which preserves f. I.e., the Hamiltonian flow moves along the level sets of f.
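In coordinates, raising the index of the differential with the symplectic form gives Hamilton's equations (here f = H is the Hamiltonian; in HMC, H(q, p) = U(q) + K(p), where U(q) is the negative log probability):

```latex
\frac{dq_i}{dt} = \frac{\partial H}{\partial p_i},
\qquad
\frac{dp_i}{dt} = -\frac{\partial H}{\partial q_i} .
```

Plugging these into dH/dt shows that it vanishes identically, which is the level-set picture just described.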

The fact that Hamiltonian flows preserve the Hamiltonian corresponds to the first property, conservation of energy. Similarly, it preserves volume, because it preserves the symplectic form and the volume form can be constructed from the symplectic form. Since these properties hold exactly for Hamiltonian flows, and HMC approximates Hamiltonian flows, we'd expect them to hold approximately in the context of HMC. (In fact, volume preservation holds exactly, so it's only energy conservation that's approximate.)

All of this is explained in this paper by Michael Betancourt.

While HMC is powerful, it still suffers from its dependence on the parameterization. Is there an analogue which is based on the intrinsic structure of statistical manifolds? Yes there is -- Riemannian manifold HMC is an algorithm that combines the insights of both natural gradient and HMC into a single algorithm.