
Thursday, 5 May 2016

Computing Bayes factors with Stan and a path-sampling method

Stan is a great program for MCMC (or HMC, really). Vehtari et al. explain here how to use Stan to compute WAIC. For the Bayes factor, however, I have not found a method yet, and therefore I would like to demonstrate a possible method here. This will obviously not work well for every model; this is merely an experiment.
Recently, I was really intrigued by a paper by Gelman and Meng, where several methods for computing Bayes factors, or normalizing constants, are explained and connected (even the really bad ones). Here, I will use the path sampling method.
Let us implement a simple model in Stan, for which we can explicitly compute the marginal likelihood. Then we can try to estimate this marginal likelihood with the path-sampling method, and compare it with the exact value.
A very simple model is the 'fair coin' example (taken directly from Wikipedia). The Bayes factor between a null model $M_0$ and a model that incorporates a bias, $M_1$, can be computed directly as the quotient of the marginal likelihoods; moreover, the null model does not have any parameters.
Let $n$ denote the number of coin tosses, and $k$ the number of 'heads'. Hence, the data $D$ is given by the pair $(n,k)$. Given a prior $\theta \sim \mathrm{Beta}(\alpha,\beta)$ on the probability of throwing heads, we get the posterior $\theta \sim \mathrm{Beta}(\alpha+k,\,\beta+n-k)$, and we can compute the marginal likelihood exactly:
$$p(D|M_1) = \int_0^1 p(D|\theta)\,\pi(\theta)\,d\theta = \binom{n}{k}\frac{1}{B(\alpha,\beta)}\int_0^1 \theta^{k+\alpha-1}(1-\theta)^{n-k+\beta-1}\,d\theta = \binom{n}{k}\frac{B(k+\alpha,\,n-k+\beta)}{B(\alpha,\beta)},$$
where B denotes the Beta function. Meanwhile,
$$p(D|M_0) = \binom{n}{k}\left(\tfrac{1}{2}\right)^{k}\left(1-\tfrac{1}{2}\right)^{n-k}.$$
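For reference, both marginal likelihoods (and hence the Bayes factor) can be computed directly with scipy. This is only a minimal sketch of the formulas above, using the same data ($n=100$, $k=10$) and uniform prior ($\alpha=\beta=1$) as in the Stan experiment below:

## exact marginal likelihoods and Bayes factor (sketch)
import scipy.special as spl

n, k = 100, 10          ## coin tosses and number of heads (as below)
alpha, beta = 1.0, 1.0  ## parameters of the Beta prior

p_D_M1 = spl.binom(n, k) * spl.beta(k + alpha, n - k + beta) / spl.beta(alpha, beta)
p_D_M0 = spl.binom(n, k) * 0.5**k * 0.5**(n - k)

print("p(D|M_1) = %g" % p_D_M1)
print("p(D|M_0) = %g" % p_D_M0)
print("Bayes factor B_10 = %g" % (p_D_M1 / p_D_M0))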
In this instance of path sampling (and closely following Gelman and Meng), we consider a family of (un-normalized) distributions $Q_T$, indexed by a parameter $T \in [0,1]$, such that $Q_0(\theta) = \pi(\theta)$ and $Q_1(\theta) = p(D|\theta)\,\pi(\theta)$. The normalizing constants are denoted by $z(T)$. Notice that $z(0) = 1$ and $z(1) = p(D|M_1)$.
Let $\Theta = [0,1]$ denote the support of $\theta$. Since
$$\frac{d}{dT}\log z(T) = \frac{1}{z(T)}\frac{d}{dT}z(T) = \frac{1}{z(T)}\frac{d}{dT}\int_\Theta Q_T(\theta)\,d\theta,$$
we get that
$$\frac{d}{dT}\log z(T) = \int_\Theta \frac{1}{z(T)}\frac{d}{dT}Q_T(\theta)\,d\theta,$$
and hence
$$\frac{d}{dT}\log z(T) = \int_\Theta \frac{Q_T(\theta)}{z(T)}\frac{d}{dT}\log(Q_T(\theta))\,d\theta.$$
When we denote by $E_T$ the expectation under the normalized distribution $P_T := Q_T/z(T)$, we get that
$$\frac{d}{dT}\log z(T) = \int_\Theta P_T(\theta)\frac{d}{dT}\log(Q_T(\theta))\,d\theta = E_T\!\left[\frac{d}{dT}\log(Q_T(\theta))\right].$$
We can think of $U(\theta, T) := \frac{d}{dT}\log(Q_T(\theta))$ as a 'potential energy', and we get
$$\int_0^1 E_T[U(\theta,T)]\,dT = \int_0^1 \frac{d}{dT}\log(z(T))\,dT = \log(z(1)) - \log(z(0)) = \log(z(1)/z(0)) =: \lambda.$$
Notice that in our case $\lambda = \log(P(D|M_1))$. We can interpret $\int_0^1 E_T[U(\theta,T)]\,dT$ as the expectation of $U$ over the joint probability density of $T$ (with a uniform prior) and $\theta$:
$$\lambda = E[U(\theta,T)].$$
This suggests an estimator $\hat\lambda$ for $\lambda$:
$$\hat\lambda = \frac{1}{N}\sum_{i=1}^N U(\theta_i, T_i),$$
where $(\theta_i, T_i)_i$ is a sample from the joint distribution of $\theta$ and $T$. A way of creating such a sample is first sampling $T_i$ from its marginal (uniform) distribution, and then sampling $\theta_i$ from $P_{T_i}$. This last step might require some Monte Carlo sampling.
First, we need to choose a 1-parameter family of distributions. A simple choice is the geometric path:
$$Q_T(\theta) = \pi(\theta)^{1-T}\bigl(\pi(\theta)\,p(D|\theta)\bigr)^{T} = \pi(\theta)\,p(D|\theta)^{T}.$$
In this case, $\log Q_T(\theta) = \log\pi(\theta) + T\log p(D|\theta)$, so the potential energy simply equals $\frac{d}{dT}\log(Q_T(\theta)) = \log p(D|\theta)$.
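As a quick sanity check before involving Stan: for this conjugate model the path can be sampled exactly, because $Q_T(\theta) \propto \theta^{\alpha+Tk-1}(1-\theta)^{\beta+T(n-k)-1}$, i.e. $P_T$ is a $\mathrm{Beta}(\alpha+Tk,\,\beta+T(n-k))$ distribution. The following sketch (plain numpy/scipy, independent of the Stan approach used in the rest of this post) estimates $\lambda$ with the estimator above; it should land close to the exact value $\log\frac{1}{n+1} \approx -4.62$ for a uniform prior.

## path-sampling sanity check without MCMC (only possible here because of conjugacy)
import numpy as np
import scipy.stats as sts
import scipy.special as spl

n, k = 100, 10            ## data: coin tosses and heads (as below)
alp, bet = 1.0, 1.0       ## Beta prior parameters
M = 10000                 ## number of (theta_i, T_i) pairs

Ts = np.random.uniform(0.0, 1.0, M)                    ## T_i ~ Uniform(0,1)
thetas = np.random.beta(alp + Ts*k, bet + Ts*(n - k))  ## theta_i ~ P_{T_i}
## for the geometric path, U(theta, T) = log p(D | theta)
U = np.log(spl.binom(n, k)) + k*np.log(thetas) + (n - k)*np.log(1.0 - thetas)

print("estimated lambda = %f +/- %f" % (np.mean(U), sts.sem(U)))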

The Stan model

Using the pystan interface, we can implement the model as follows. The most important parts are the parameter T (declared in the data section), and the "generated quantity" U.
## import some modules
import pystan
import scipy.stats as sts
import scipy.special as spl
import numpy as np
import multiprocessing

## define a Stan model
model = """
data {
    int<lower=0> n;
    int<lower=0, upper=n> k;
    real<lower=0> alpha;
    real<lower=0> beta;
    real<lower=0, upper=1> T; // parameter for path sampling
}
parameters {
    real<lower=0, upper=1> theta;
}
model {
    theta ~ beta(alpha, beta);
    increment_log_prob(T*binomial_log(k, n, theta));
    // replaces sampling statement "k ~ binomial(n, theta)"
}
generated quantities {
    real U;
    U <- binomial_log(k, n, theta);
}
"""

## let Stan translate this into C++, and compile...
sm = pystan.StanModel(model_code=model)

A parallel method

We need to generate samples $T_i$ from $\mathrm{Uniform}(0,1)$, and then, given $T_i$, we generate a sample $\theta_i$ from $P_{T_i}$. The simplest way is just to make a partition $T_i = i/N$ of $[0,1]$ and then, for each $i = 0,\ldots,N$, use the Stan model with $T = T_i$. Notice that for each $i$, we will generate multiple ($K$, say) samples from $P_{T_i}$. This method lends itself well to multi-processing, as all $N+1$ Stan sessions can run in parallel.
## choose some parameters
k = 10 ## heads
n = 100 ## coin tosses
alp = 1 ## determines prior for q
bet = 1 ## determines prior for q
K = 100 ## length of each chain
N = 1000 ## number of Ts

## a function that prepares a data dictionary,
## and then runs the Stan model
def runStanModel(T):
    coin_data = {
        'n' : n, 
        'k' : k, 
        'alpha' : alp,
        'beta' : bet,
        'T' : T
    }
    fit = sm.sampling(data=coin_data, iter=2*K, 
                      warmup=K, chains=1) 
    la = fit.extract(permuted=True)
    return la['U'] ## U is a "generated quantity"

## make a partition of [0,1] 
Ts = np.linspace(0, 1, N+1)
## start a worker pool
pool = multiprocessing.Pool(4) ## 4 worker processes
## for each T in Ts, run the Stan model
Us = np.array(pool.map(runStanModel, Ts))
Let's have a look at the result. Notice that for $\alpha = \beta = 1$, the marginal likelihood does not depend on $k$, as
$$P(D|M_1) = \binom{n}{k}\frac{\Gamma(k+1)\Gamma(n-k+1)}{\Gamma(n+2)} = \binom{n}{k}\frac{k!\,(n-k)!}{(n+1)!} = \frac{1}{n+1}.$$
We could take for $\hat\lambda$ the average of all $(N+1)K$ samples, but in my experience, the standard error is more realistic when I only take one sample per $T_i$.
## take one sample for each T
lamhat = np.mean(Us[:,-1]) 
## we can also compute a standard error!!
se_lamhat = sts.sem(Us[:,-1])

print "extimated lambda = %f +/- %f"%(lamhat, se_lamhat)
print "estimated p(D|M_1) = %f"%np.exp(lamhat)

exactMargLike = spl.beta(k+alp, n-k+bet) * spl.binom(n,k) / spl.beta(alp, bet)
exactMargLoglike = np.log(exactMargLike)

print "exact lambda = %f"%exactMargLoglike
print "exact p(D|M_1) = %f"%exactMargLike
In my case, the result is
estimated lambda = -4.724850 +/- 0.340359
estimated p(D|M_1) = 0.008872
exact lambda = -4.615121
exact p(D|M_1) = 0.009901
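For comparison, the alternative mentioned above, averaging over all $(N+1)K$ samples, is a one-liner; just note that a naive standard error over all samples would look too optimistic, presumably because the $K$ draws within one chain are correlated.

## alternative estimator: average over all (N+1)*K samples
lamhat_all = np.mean(Us)
print("estimated lambda (all samples) = %f" % lamhat_all)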

A serial method

Another method does not use parallel processing, but uses the fact that the distributions $P_{T_i}$ and $P_{T_{i+1}}$ are very similar when $T_{i+1} - T_i = 1/N$ is small. When we have a sample from $P_{T_i}$, we can use it as the initial condition for the Stan run with $T = T_{i+1}$. We then need only very little burn-in (warm-up) time before we are actually sampling from $P_{T_{i+1}}$. We can specify the number of independent chains that Stan computes, and also separate initial parameters for each of the chains. Hence, we can take multiple samples from $P_{T_i}$ as initial values for the chains of the next run. For this very simple model, this "serial" method is much slower than the parallel method, but my guess is that it could be a lot faster for more complicated models. I hope to prove this claim in a future post.
## choose some parameters
k = 10 ## heads
n = 100 ## coin tosses
alp = 1 ## determines prior for q
bet = 1 ## determines prior for q
K = 100 ## size initial chain
N = 200 ## number of Ts

## initially, do a longer run with T=0
coin_data = {
    'n' : n, 
    'k' : k, 
    'alpha' : alp,
    'beta' : bet,
    'T' : 0
}
fit = sm.sampling(data=coin_data, iter=2*K, 
                  warmup=K, chains=1)
la = fit.extract(permuted=True)

## instead of length K,
## now use a much shorter chain (of length L)
L = 10 
chains = 4 
Us = np.zeros(shape=(N+1,L*chains))
Ts = np.linspace(0, 1, N+1)

## now run the 'chain of chains'
for i, Ti in enumerate(Ts):
    coin_data['T'] = Ti ## take another T
    ## take some thetas from the previous sample
    thetas = np.random.choice(la["theta"], chains)
    initial_guesses = [{'theta' : theta} for theta in thetas]
    fit = sm.sampling(data=coin_data, iter=2*L, warmup=L, 
                      chains=chains, init=initial_guesses)
    la = fit.extract(permuted=True)
    Us[i,:] = la['U']
Ok, let us have another look at the result. For this, I used the same code as above:
estimated lambda = -5.277354 +/- 1.120274
estimated p(D|M_1) = 0.005106
exact lambda = -4.615121
exact p(D|M_1) = 0.009901
As you can see, the estimate is less precise, but this is because $N = 200$ instead of $N = 1000$.
I wrote and ran the above code fragments in a Jupyter notebook. As compiling and sampling can take a lot of time, such an interface can be very convenient. Please let me know if this was somehow useful for you, or if you have any questions, and also please tell me if I did something stupid...
