Note

The original gist can be found at: https://gist.github.com/da1d0470d6fb54c63e6a913c1ef67a9e

pymc3-demo.ipynb

This is a demo showing how dense mass matrix adaptation can improve PyMC3’s performance in some cases. See this blog post for a more detailed discussion.

First, let’s see how PyMC3 performs on an uncorrelated Gaussian:

import time
import pymc3 as pm

ndim = 5

with pm.Model() as simple_model:
    pm.Normal("x", shape=(ndim,))

strt = time.time()
with simple_model:
    simple_trace = pm.sample(draws=3000, tune=3000, random_seed=42)

    # About half the time is spent in tuning so correct for that
    simple_time = 0.5*(time.time() - strt)

stats = pm.summary(simple_trace)
simple_time_per_eff = simple_time / stats.n_eff.min()
print("time per effective sample: {0:.5f} ms".format(simple_time_per_eff * 1000))
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [x]
Sampling 2 chains, 0 divergences: 100%|██████████| 12000/12000 [00:04<00:00, 2851.12draws/s]
time per effective sample: 0.31932 ms

As discussed in the blog post, PyMC3 doesn’t do so well if there are correlations. But we can use the QuadPotentialFullAdapt potential to get nearly the same performance as above:

import numpy as np

# Generate a random positive definite matrix
np.random.seed(42)
L = np.random.randn(ndim, ndim)
L[np.diag_indices_from(L)] = 0.1*np.exp(L[np.diag_indices_from(L)])
L[np.triu_indices_from(L, 1)] = 0.0
cov = np.dot(L, L.T)
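As a quick sanity check (not part of the original demo), we can confirm that `cov` really is symmetric and positive definite, which `MvNormal` requires. A product `L @ L.T` of a lower-triangular matrix with a strictly positive diagonal is guaranteed to have these properties:

```python
import numpy as np

# Rebuild the same matrix as above (same seed for reproducibility)
np.random.seed(42)
ndim = 5
L = np.random.randn(ndim, ndim)
L[np.diag_indices_from(L)] = 0.1 * np.exp(L[np.diag_indices_from(L)])
L[np.triu_indices_from(L, 1)] = 0.0
cov = np.dot(L, L.T)

# L is lower triangular with a positive diagonal, so cov = L @ L.T
# is symmetric positive definite by construction.
assert np.all(np.diag(L) > 0)
assert np.allclose(cov, cov.T)
assert np.all(np.linalg.eigvalsh(cov) > 0)
```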

with pm.Model() as model:
    pm.MvNormal("x", mu=np.zeros(ndim), chol=L, shape=(ndim,))

    # *** This is the new part ***
    potential = pm.step_methods.hmc.quadpotential.QuadPotentialFullAdapt(
        model.ndim, np.zeros(model.ndim))
    step = pm.NUTS(model=model, potential=potential)
    # *** end new part ***

    strt = time.time()
    full_adapt_trace = pm.sample(draws=10000, tune=5000, random_seed=42, step=step)
    full_adapt_time = 0.5 * (time.time() - strt)

stats = pm.summary(full_adapt_trace)
full_adapt_time_per_eff = full_adapt_time / stats.n_eff.min()
print("time per effective sample: {0:.5f} ms".format(full_adapt_time_per_eff * 1000))
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [x]
Sampling 2 chains, 0 divergences: 100%|██████████| 30000/30000 [00:21<00:00, 1391.33draws/s]
time per effective sample: 0.30707 ms
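To build some intuition for why the full mass matrix helps, here is a numpy-only sketch (my own illustration, not part of the original demo). The mass matrix acts as a preconditioner: a full covariance estimate can whiten a correlated Gaussian into an isotropic one, while a diagonal estimate only rescales the marginals and leaves the correlations behind:

```python
import numpy as np

# Rebuild the same correlated covariance as above
np.random.seed(42)
ndim = 5
L = np.random.randn(ndim, ndim)
L[np.diag_indices_from(L)] = 0.1 * np.exp(L[np.diag_indices_from(L)])
L[np.triu_indices_from(L, 1)] = 0.0
cov = np.dot(L, L.T)

def condition_number(m):
    # Ratio of extreme eigenvalues; 1.0 means perfectly isotropic
    w = np.linalg.eigvalsh(m)
    return w.max() / w.min()

# Diagonal preconditioning: rescale each coordinate by its marginal
# standard deviation; the result is the correlation matrix, which is
# still non-trivial whenever there are off-diagonal correlations.
d = np.sqrt(np.diag(cov))
diag_precond = cov / np.outer(d, d)

# Full preconditioning: whiten with the Cholesky factor of cov; the
# result is (up to round-off) the identity matrix.
A = np.linalg.cholesky(cov)
full_precond = np.linalg.solve(A, np.linalg.solve(A, cov).T)

print("raw condition number:   ", condition_number(cov))
print("diagonal mass matrix:   ", condition_number(diag_precond))
print("full mass matrix:       ", condition_number(full_precond))
```

The full mass matrix drives the condition number of the effective target to 1, which is exactly the regime where NUTS with a diagonal mass matrix already does well, matching the near-identical times per effective sample seen above.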