Stan's ODE solvers (such as integrate_ode_rk45, used below) have a fixed signature: the state, the parameters and the data all have to be passed as arrays. This can get confusing while a model is being developed, because parameters may turn into constants (i.e. data), the number of parameters may change, and even the state variables may change. Errors creep in easily, since the user has to update every index correctly, both in the ODE function and in, e.g., the model block.
A nice way to solve this is to use constant functions. Constants in Stan are functions without arguments that return the constant value; for instance, pi() returns 3.14159265... Users can define such constants themselves in the functions block.

Suppose that we want to fit the Lotka-Volterra model to data. The system of ODEs is given by
$$\frac{dx}{dt} = ax - bxy, \qquad \frac{dy}{dt} = cbxy - dy$$
and so we have a 2-dimensional state (x, y), and we need a parameter vector of length 4 (a, b, c, d).
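To make the indexing idea concrete, here is the same right-hand side written in Python with named index constants instead of integer literals. This is an illustrative sketch: the constants IDX_A to IDX_Y and the function lv_rhs are hypothetical names, and the indices are 0-based because Python lists are, whereas Stan arrays start at 1.

```python
# Hypothetical 0-based index constants mirroring the Stan idx_* functions
# (Stan arrays are 1-based, Python lists are 0-based).
IDX_A, IDX_B, IDX_C, IDX_D = 0, 1, 2, 3  # parameter indices
IDX_X, IDX_Y = 0, 1                      # state indices


def lv_rhs(t, u, par):
    """Right-hand side (dx/dt, dy/dt) of the Lotka-Volterra system."""
    x, y = u[IDX_X], u[IDX_Y]
    a, b, c, d = par[IDX_A], par[IDX_B], par[IDX_C], par[IDX_D]
    du = [0.0, 0.0]
    du[IDX_X] = a * x - b * x * y        # prey: growth minus predation
    du[IDX_Y] = c * b * x * y - d * y    # predator: conversion minus death
    return du


# Parameter values used in the data-generating script later in the post.
par = [1.0, 0.2, 0.5, 0.5]
print(lv_rhs(0.0, [1.0, 1.0], par))
```

With a = 1, b = 0.2, c = 0.5, d = 0.5 and x = y = 1, the derivatives are 1 - 0.2 = 0.8 for the prey and 0.5*0.2 - 0.5 = -0.4 for the predator, up to floating-point rounding. Because lv_rhs takes the parameters as an extra argument, it can be handed to scipy.integrate.solve_ivp with args=(par,) (SciPy 1.4 or later).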
In the functions block, we define a function int idx_a() { return 1; } that returns the index of the parameter a in the parameter vector, and we define similar functions for the other parameters and for the state variables. The full model can be implemented in Stan as shown below. The data Preys and Predators are assumed to be Poisson-distributed with means Kx and Ky, respectively, for some large constant K.
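The observation model can also be written as a log-likelihood in Python, which can be handy as a cross-check outside Stan. This is a sketch: the function name lv_poisson_loglik is hypothetical, and it assumes the ODE solution at the observation times is stored as an (N, 2) array with the prey in column 0 and the predators in column 1. Stan's ~ statement drops terms that do not depend on the parameters (such as log k!), so this full log-likelihood differs from Stan's internal value by a constant.

```python
import numpy as np
from scipy.stats import poisson


def lv_poisson_loglik(preys, predators, us, K):
    """Log-likelihood of prey/predator counts given ODE states.

    us has shape (N, 2): column 0 is the prey density x(t_i),
    column 1 the predator density y(t_i). The counts are modelled as
    Poisson with means K*x and K*y, as in the Stan model block.
    """
    us = np.asarray(us, dtype=float)
    ll_prey = poisson.logpmf(preys, K * us[:, 0]).sum()
    ll_pred = poisson.logpmf(predators, K * us[:, 1]).sum()
    return ll_prey + ll_pred


# Toy example: two time points with made-up states and counts.
us = np.array([[1.0, 1.0], [2.0, 0.5]])
print(lv_poisson_loglik([9, 22], [11, 4], us, K=10))
```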
I fitted the model to some randomly generated data, which resulted in the figure above.
functions {
  // parameter indices
  int idx_a() { return 1; }
  int idx_b() { return 2; }
  int idx_c() { return 3; }
  int idx_d() { return 4; }
  // state indices
  int idx_x() { return 1; }
  int idx_y() { return 2; }
  // Lotka-Volterra ODEs
  real[] LV_sys(real t, real[] u, real[] par, real[] real_data, int[] int_data) {
    real du[2];
    // use index functions instead of integer literals
    du[idx_x()] = par[idx_a()] * u[idx_x()]
                - par[idx_b()] * u[idx_x()] * u[idx_y()];
    du[idx_y()] = par[idx_c()] * par[idx_b()] * u[idx_x()] * u[idx_y()]
                - par[idx_d()] * u[idx_y()];
    return du;
  }
}
data {
  int N;
  real Times[N];
  int Preys[N];
  int Predators[N];
  real K;
}
parameters {
  // initial conditions
  real<lower=0> x0;
  real<lower=0> y0;
  // parameters
  real<lower=0> a;
  real<lower=0> b;
  real<lower=0, upper=1> c;
  real<lower=0> d;
}
transformed parameters {
  real par[4];
  real u0[2];
  // again use index functions to make a parameter array
  par[idx_a()] = a;
  par[idx_b()] = b;
  par[idx_c()] = c;
  par[idx_d()] = d;
  // make an array for the initial condition
  u0[idx_x()] = x0;
  u0[idx_y()] = y0;
}
model {
  // integrate the ODEs
  real us[N, 2] = integrate_ode_rk45(LV_sys, u0, 0, Times, par, {0.0}, {0});
  // priors on the parameters
  x0 ~ normal(1, 1);
  y0 ~ normal(1, 1);
  a ~ normal(1, 1);
  b ~ normal(1, 1);
  c ~ beta(1, 1);
  d ~ normal(1, 1);
  // likelihood of the data
  Preys ~ poisson(to_array_1d(to_vector(us[:, idx_x()]) * K));
  Predators ~ poisson(to_array_1d(to_vector(us[:, idx_y()]) * K));
}
generated quantities {
  // export the solution of the ODE
  real us_hat[N, 2] = integrate_ode_rk45(LV_sys, u0, 0, Times, par, {0.0}, {0});
  // and simulate noise
  real us_sim[N, 2];
  for ( i in 1:N ) {
    for ( j in 1:2 ) {
      us_sim[i, j] = poisson_rng(us_hat[i, j] * K);
    }
  }
}
Of course, this is still a low-dimensional model with a small number of parameters, but I found that even in a simple model, defining parameter indices in this way keeps everything concise.
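For larger models, even writing the index functions by hand could be automated. A minimal sketch, assuming a hypothetical helper stan_index_functions that emits one idx_<name>() function per name in a list; called with the names a, b, c, d it reproduces the parameter-index lines of the functions block above, and the resulting text could be pasted into (or spliced into) the functions block before compiling.

```python
def stan_index_functions(names, comment=None):
    """Return Stan source defining one idx_<name>() function per name.

    Indices are 1-based and follow the order of `names`, so adding,
    removing or reordering a name renumbers everything consistently.
    """
    lines = []
    if comment is not None:
        lines.append(f"// {comment}")
    for i, name in enumerate(names, start=1):
        # Double braces produce literal { and } in the f-string.
        lines.append(f"int idx_{name}() {{ return {i}; }}")
    return "\n".join(lines)


print(stan_index_functions(["a", "b", "c", "d"], "parameter indices"))
print(stan_index_functions(["x", "y"], "state indices"))
```

Keeping the name lists as the single source of truth means that turning a parameter into data, or adding a new one, only requires editing a list rather than renumbering indices by hand.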
I used the following Python script to generate the data and interface with Stan.
# Uses the PyStan 2 API (pystan < 3); PyStan 3 has a different interface.
import numpy as np
import matplotlib.pyplot as plt
import pystan
from scipy.integrate import solve_ivp
from scipy.stats import poisson

## choose nice parameter values
a = 1
b = 0.2
c = 0.5
d = 0.5

## define the system
def LV_sys(t, u):
    return [a*u[0] - b*u[0]*u[1], c*b*u[0]*u[1] - d*u[1]]

## observation times and initial conditions
N = 50
Times = np.linspace(1, 25, N)
x0 = 1
y0 = 1
u0 = [x0, y0]
K = 10

## generate random data: solve the ODE and add Poisson noise with mean K*state
sol = solve_ivp(LV_sys, (0, max(Times)), u0, t_eval=Times)
Preys = [poisson.rvs(x*K) for x in sol.y[0]]
Predators = [poisson.rvs(x*K) for x in sol.y[1]]

## compile the Stan model (the program above, saved as lotka-volterra.stan)
sm = pystan.StanModel(file="lotka-volterra.stan")

## prepare a data dictionary and initial parameter values for Stan
data_dict = {
    'N' : N,
    'Times' : Times,
    'Preys' : Preys,
    'Predators' : Predators,
    'K' : K
}

def init_dict_gen():
    # start each chain at the true parameter values
    return {
        'a' : a,
        'b' : b,
        'c' : c,
        'd' : d,
        'x0' : x0,
        'y0' : y0
    }

## sample from posterior
sam = sm.sampling(data=data_dict, init=init_dict_gen, thin=10, chains=2, iter=5000)

## make a figure with data and fit
chain_dict = sam.extract(permuted=True)
fig, ax = plt.subplots(1, 1, figsize=(7,5))
ax.scatter(Times, Preys, color='tab:blue', edgecolors='k', zorder=2)
ax.scatter(Times, Predators, color='tab:orange', edgecolors='k', zorder=2)
pcts = [2.5, 97.5] ## percentiles
colors = ['tab:blue', 'tab:orange']

## plot trajectories (95% band of K times the ODE solution)
for j, color in enumerate(colors):
    range_hat = [K*np.percentile(us, pcts) for us in chain_dict["us_hat"][:,:,j].T]
    ax.fill_between(Times, *np.array(range_hat).T, color=color,
                    alpha=0.5, linewidth=0)

## plot simulations (95% band of the simulated counts)
for j, color in enumerate(colors):
    range_sim = [np.percentile(us, pcts) for us in chain_dict["us_sim"][:,:,j].T]
    ax.fill_between(Times, *np.array(range_sim).T, color=color,
                    alpha=0.3, linewidth=0)

ax.set_ylabel("Prey (blue), Predator (orange)\ndata and fit")
ax.set_xlabel("Time")
fig.savefig("LV-model-fit.png", bbox_inches='tight', dpi=200)

## plot parameter estimates
fig, ax = plt.subplots(1, 1, figsize=(7,3))
parnames = ["a", "b", "c", "d", "x0", "y0"]
real_par_vals = [a, b, c, d, x0, y0]

## make violinplots of estimates
pos = range(len(parnames))
ax.violinplot([chain_dict[x] for x in parnames], pos)
ax.set_xticks(pos)
ax.set_xticklabels(parnames)

## plot real parameter values
ax.scatter(pos, real_par_vals, color='k')
ax.set_ylabel("parameter estimate (blue)\nreal parameter value (black)")
ax.set_xlabel("parameter name")
fig.savefig("LV-model-estimates.png", bbox_inches='tight', dpi=200)
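The percentile bands in the script above are computed one time point at a time with a list comprehension. np.percentile also accepts an axis argument, so each band can be computed in a single call. A sketch, assuming the draws array has shape (number of draws, N, 2), as extract(permuted=True) returns for us_hat and us_sim; percentile_band and the IDX_X/IDX_Y constants are hypothetical names that mirror idx_x() and idx_y() on the 0-based Python side.

```python
import numpy as np

# Hypothetical 0-based state indices, mirroring idx_x()/idx_y() in Stan.
IDX_X, IDX_Y = 0, 1


def percentile_band(draws, j, pcts=(2.5, 97.5), scale=1.0):
    """Pointwise percentile band for state j.

    draws has shape (n_draws, N, 2). Percentiles are taken over the
    draws axis, giving an array of shape (len(pcts), N), which is
    unpacked into the lower and upper curves (each of length N).
    """
    lower, upper = scale * np.percentile(draws[:, :, j], pcts, axis=0)
    return lower, upper


# Usage with synthetic draws in place of real posterior samples.
rng = np.random.default_rng(1)
draws = rng.normal(size=(1000, 50, 2))
lo, hi = percentile_band(draws, IDX_X, scale=10)
print(lo.shape, hi.shape)  # (50,) (50,)
```

In the script, ax.fill_between(Times, *percentile_band(chain_dict["us_hat"], j, pcts, scale=K), color=color, alpha=0.5, linewidth=0) would then replace the list comprehension for the fitted trajectories, and the same call with chain_dict["us_sim"] and the default scale=1.0 would cover the simulated counts.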
The following figure shows the parameter estimates together with the "real" parameter values.