---
author: Garrett Smith
title: Parameter recovery with a hierarchical model
---

Parameter recovery test using a hierarchical model

The goal of this script is to test whether the parameters of a hierarchical model with a first-passage time (fpdistribution) likelihood can be recovered.

Generating the true data

First, we set up the transition rate matrices for the first-passage time distribution we want to fit.

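# Alternative transition rate matrix (not used in this run):
#T = 4*[-1.0 0 0; 1 -1 1; 0 1 -2]
# Transition rates among the three transient states; each column of T plus the matching entry of A sums to zero
T = 4*[-1 1.0 0; 1 -2 1; 0 1 -2]
# Absorption rates out of each transient state (absorption happens only from state 3)
A = 4*[0 0 1.0]
# Initial distribution: the process starts in state 1
p0 = [1.0, 0, 0]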
3-element Vector{Float64}:
 1.0
 0.0
 0.0
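As a sanity check on these matrices, the sketch below verifies that probability is conserved and evaluates the first-passage time density. It assumes the column convention implied by the matrices (column j of T holds the rates out of state j, and the columns of T plus the entries of A sum to zero), under which the density is f(t) = A exp(Tt) p0. The helper fpt_density is hypothetical and is not the fpdistribution used in the model.

```julia
using LinearAlgebra

# Matrices copied from the cell above
T  = 4 * [-1 1.0 0; 1 -2 1; 0 1 -2]
A  = 4 * [0 0 1.0]
p0 = [1.0, 0, 0]

# Conservation: every column of T plus the matching absorption rate sums to zero
@assert all(abs.(vec(sum(T; dims = 1)) .+ vec(A)) .< 1e-12)

# Hypothetical helper: first-passage time density f(t) = A * exp(M t) * q
fpt_density(t, M = T, a = A, q = p0) = (a * exp(M * t) * q)[1]

# The density starts at 0 because the process starts in state 1,
# and absorption is only possible from state 3
println(fpt_density(0.0))
println(fpt_density(0.5))
```

Integrating fpt_density numerically over a long enough window should give a value very close to 1.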

Scaling the transition rate matrices by τ = 2.5 should give mean first-passage times of around 400 ms. This figure matches the commented-out matrix; with the active matrix, the mean is about 600 ms, assuming the rates are per second. The parameters (τ and the separate τᵢ) are generated and fitted on the log scale and then exponentiated, which keeps the transition rates positive.
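The timescale can be checked in closed form. Under the column convention assumed above, the mean first-passage time is -1ᵀM⁻¹p0: the total expected time spent in the transient states before absorption. The sketch below assumes rates are per second and compares the active matrix with the commented-out one after scaling by τ = 2.5. mean_fpt is a hypothetical helper name.

```julia
# Mean first-passage time for sub-generator M (column convention):
# -sum(M \ p0) is the total expected occupancy of the transient states
mean_fpt(M, p0) = -sum(M \ p0)

p0     = [1.0, 0, 0]
T_used = 4 * [-1 1.0 0; 1 -2 1; 0 1 -2]   # active matrix
T_alt  = 4 * [-1.0 0 0; 1 -1 1; 0 1 -2]   # commented-out matrix
τ_true = 2.5

# Work on the log scale and exponentiate, so the multiplier is always positive
mult = exp(log(τ_true))

println(mean_fpt(mult * T_used, p0))   # ≈ 0.6 (600 ms if rates are per second)
println(mean_fpt(mult * T_alt, p0))    # ≈ 0.4 (400 ms)
```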

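using Distributions  # provides Normal

nparticipants = 50
true_tau = 2.5   # group-level rate multiplier
true_sd = 0.2    # SD of the participant effects on the log scale
# Participant multipliers: log-normal with median 1
true_tau_i = exp.(rand(Normal(0, true_sd), nparticipants));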

The data are stored in wide format: each participant's data form a row, and each column is one data point (a single first-passage time).
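The code that generates the data is not shown here. The sketch below shows one way to produce a wide-format data set with the same structure using only the standard library: it simulates the Markov chain directly (the Gillespie algorithm) instead of sampling from fpdistribution. The helper simulate_fpt, the number of data points per participant (ndata = 100) and the seed are all assumptions for illustration. Each participant's rates are scaled by true_tau * true_tau_i[p], mirroring mult in the model.

```julia
using Random

# Hypothetical helper (not the package's fpdistribution): one first-passage time
# from simulating the chain. Column convention: M[i, j] is the rate from
# transient state j to state i, and a[j] is the absorption rate from state j.
function simulate_fpt(rng::AbstractRNG, M::AbstractMatrix, a::AbstractVector, p0::AbstractVector)
    n = length(p0)
    # Starting state drawn from the initial distribution
    state = something(findfirst(cumsum(p0) .> rand(rng)), n)
    t = 0.0
    while true
        # Rates to every other transient state (zero for the current one)
        out = [i == state ? 0.0 : M[i, state] for i in 1:n]
        total = sum(out) + a[state]            # total exit rate from `state`
        t += randexp(rng) / total              # exponential holding time
        u = rand(rng) * total
        dest = findfirst(cumsum(out) .> u)     # `nothing` means absorption
        dest === nothing && return t
        state = dest
    end
end

# Illustrative settings: ndata and the seed are assumptions
rng = MersenneTwister(2023)
nparticipants, ndata = 50, 100
true_tau, true_sd = 2.5, 0.2
T  = 4 * [-1 1.0 0; 1 -2 1; 0 1 -2]
a  = 4 * [0, 0, 1.0]
p0 = [1.0, 0, 0]

# Same distribution as exp.(rand(Normal(0, true_sd), nparticipants))
true_tau_i = exp.(true_sd .* randn(rng, nparticipants))

# Wide format: row p holds participant p's data, column d is data point d
data = [simulate_fpt(rng, true_tau * true_tau_i[p] * T, true_tau * true_tau_i[p] * a, p0)
        for p in 1:nparticipants, d in 1:ndata]
```

With these rates, the grand mean of data should be close to 0.6, in line with the closed-form mean of the active matrix.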

Specifying the model

Now we can write the full model, including the likelihood. We use a non-centered parameterization for both τ and the τᵢ, because pilot simulations suggested that sampling was biased under the centered parameterization.

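using Turing  # @model, filldist, arraydist; fpdistribution is assumed to be loaded elsewhere in the script

# TODO: consider param = exp(tau) + exp(tau_i); this would prevent really small params caused by multiplication.
@model function mod(y, Tmat=T, Amat=A, p0vec=p0)
    np, nd = size(y)  # participants × data points (wide format)
    # Priors
    # Non-centered parameterization for τ
    τ ~ Normal()
    τ̂ = 1 + 0.1*τ  # Corresponds to Normal(1, 0.1)
    sd ~ Exponential(0.25)  # scale parameterization: prior mean 0.25
    # Non-centered parameterization for the participant effects
    τᵢ ~ filldist(Normal(), np)
    τ̂ᵢ = sd .* τᵢ  # Corresponds to MvNormal(0, sd)
    # Likelihood: each participant's rates are scaled by a positive multiplier
    mult = exp.(τ̂ .+ τ̂ᵢ)
    y ~ filldist(arraydist([fpdistribution(mult[p]*Tmat, mult[p]*Amat, p0vec) for p in 1:np]), nd)
    return τ̂, τ̂ᵢ, mult
end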
mod (generic function with 5 methods)
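For reference, the model above can be written out as follows. Here p indexes participants (the code's τᵢ becomes τ_p), d indexes data points, and FP is shorthand introduced here for the distribution built by fpdistribution. Distributions.jl parameterizes Exponential(0.25) by its scale, so the prior mean of sd is 0.25.

$$
\begin{aligned}
\tau &\sim \mathcal{N}(0, 1), & \hat{\tau} &= 1 + 0.1\,\tau \;\Rightarrow\; \hat{\tau} \sim \mathcal{N}(1,\, 0.1^2),\\
\mathrm{sd} &\sim \mathrm{Exponential}(\text{scale} = 0.25), & & \\
\tau_p &\sim \mathcal{N}(0, 1), & \hat{\tau}_p &= \mathrm{sd}\cdot\tau_p \;\Rightarrow\; \hat{\tau}_p \sim \mathcal{N}(0,\, \mathrm{sd}^2),\\
\mathrm{mult}_p &= \exp(\hat{\tau} + \hat{\tau}_p), & y_{pd} &\sim \mathrm{FP}(\mathrm{mult}_p\,T,\; \mathrm{mult}_p\,A,\; p_0).
\end{aligned}
$$

The sampler only ever sees the standard-normal τ and τ_p; the rescaling happens afterwards. This is what makes the parameterization non-centered, and it is the usual remedy for the funnel-shaped posteriors that hierarchical models tend to produce when sd and the participant effects are sampled jointly.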

Sampling

Here, we use the NUTS sampler with 250 adaptation (warm-up) steps and a target acceptance rate of 0.7, drawing 1000 samples from a single chain. The commented-out line instead runs four chains of 1000 samples each in parallel with MCMCThreads(); for that version, execute this script with julia -t 4 HierarchicalParameterRecovery.jl.

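# Four chains in parallel (requires starting Julia with four threads):
#posterior = sample(mod(data), NUTS(250, 0.65), MCMCThreads(), 1000, 4)
# Single chain: 250 adaptation steps, target acceptance rate 0.7, 1000 draws
posterior = sample(mod(data), NUTS(250, 0.7), 1000);
# Pilot runs with other parameterizations (these models are not defined in this script):
#posterior_centered = sample(mod_centered(data), NUTS(100, 0.65), 1000);
#posterior_noncentered_tau = sample(mod_noncentered_tau(data), NUTS(100, 0.65), 1000);
#posterior_noncentered_tau_i = sample(mod_noncentered_tau_i(data), NUTS(100, 0.65), 1000);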

Evaluating parameter recovery

First, we summarize the chains:

posterior
Chains MCMC chain (1000×64×1 Array{Float64, 3}):

Iterations        = 251:1:1250
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 178.72 seconds
Compute duration  = 178.72 seconds
parameters        = τ, sd, τᵢ[1], τᵢ[2], τᵢ[3], τᵢ[4], τᵢ[5], τᵢ[6], τᵢ[7],
                    τᵢ[8], τᵢ[9], τᵢ[10], τᵢ[11], τᵢ[12], τᵢ[13], τᵢ[14], τᵢ[15],
                    τᵢ[16], τᵢ[17], τᵢ[18], τᵢ[19], τᵢ[20], τᵢ[21], τᵢ[22], τᵢ[23],
                    τᵢ[24], τᵢ[25], τᵢ[26], τᵢ[27], τᵢ[28], τᵢ[29], τᵢ[30], τᵢ[31],
                    τᵢ[32], τᵢ[33], τᵢ[34], τᵢ[35], τᵢ[36], τᵢ[37], τᵢ[38], τᵢ[39],
                    τᵢ[40], τᵢ[41], τᵢ[42], τᵢ[43], τᵢ[44], τᵢ[45], τᵢ[46], τᵢ[47],
                    τᵢ[48], τᵢ[49], τᵢ[50]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density,
                    hamiltonian_energy, hamiltonian_energy_error,
                    max_hamiltonian_energy_error, tree_depth, numerical_error,
                    step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64

           τ   -0.7442    0.4553     0.0144    0.0238    316.8694    0.9991
          sd    0.2374    0.0538     0.0017    0.0029    398.4406    0.9990
       τᵢ[1]   -0.9415    0.7164     0.0227    0.0218    994.0778    0.9991
       τᵢ[2]    0.3490    0.7894     0.0250    0.0197   1349.0116    0.9996
       τᵢ[3]    0.5201    0.7431     0.0235    0.0226   1250.5871    0.9991
       τᵢ[4]    0.4726    0.7791     0.0246    0.0242   1008.9107    1.0006
       τᵢ[5]   -0.4701    0.7294     0.0231    0.0197   1978.9287    0.9990
       τᵢ[6]   -1.0603    0.7125     0.0225    0.0222   1159.4274    0.9995
       τᵢ[7]   -0.1375    0.7535     0.0238    0.0247   1012.2482    0.9991
       τᵢ[8]   -0.6295    0.7316     0.0231    0.0198   1250.2784    0.9994
       τᵢ[9]    1.0526    0.7401     0.0234    0.0201   1378.9551    0.9993
      τᵢ[10]   -0.2141    0.7879     0.0249    0.0228   1210.5560    0.9998
      τᵢ[11]   -0.0125    0.7856     0.0248    0.0214    791.1852    0.9992
      τᵢ[12]   -0.5175    0.7570     0.0239    0.0205   1256.1725    0.9992
      τᵢ[13]   -0.6380    0.7425     0.0235    0.0234   1025.6297    1.0019
      τᵢ[14]   -0.4129    0.7522     0.0238    0.0214    986.6031    0.9990
      τᵢ[15]   -0.6625    0.7582     0.0240    0.0276   1173.4401    0.9990
      ⋮           ⋮         ⋮         ⋮          ⋮          ⋮          ⋮
                                        1 column and 35 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           τ   -1.6643   -1.0467   -0.7308   -0.4321    0.2060
          sd    0.1341    0.2026    0.2363    0.2721    0.3500
       τᵢ[1]   -2.3378   -1.4284   -0.9079   -0.4509    0.4462
       τᵢ[2]   -1.1857   -0.1474    0.3521    0.8555    1.9542
       τᵢ[3]   -0.9300    0.0231    0.5213    1.0259    1.9631
       τᵢ[4]   -1.0176   -0.0829    0.4789    0.9884    2.0412
       τᵢ[5]   -1.9371   -0.9491   -0.4642    0.0242    0.9772
       τᵢ[6]   -2.5182   -1.5295   -1.0533   -0.6082    0.2898
       τᵢ[7]   -1.5248   -0.6553   -0.1586    0.3721    1.3880
       τᵢ[8]   -2.0842   -1.1045   -0.6365   -0.1271    0.8049
       τᵢ[9]   -0.4966    0.5819    1.0523    1.5221    2.5336
      τᵢ[10]   -1.7941   -0.7078   -0.2226    0.3031    1.2916
      τᵢ[11]   -1.5853   -0.5313   -0.0123    0.5043    1.5465
      τᵢ[12]   -1.8783   -1.0141   -0.5442   -0.0220    0.9733
      τᵢ[13]   -2.1333   -1.1272   -0.6287   -0.1420    0.8235
      τᵢ[14]   -1.8759   -0.9211   -0.4014    0.1069    0.9702
      τᵢ[15]   -2.1244   -1.1479   -0.6613   -0.1553    0.8299
      ⋮           ⋮         ⋮         ⋮         ⋮         ⋮
                                                 35 rows omitted
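Because the model samples a standardized τ, its true value on the sampled scale is (log(2.5) - 1)/0.1 ≈ -0.837, not 2.5. The sketch below maps between the two scales using only the quantiles printed above. Both maps are monotone, so quantiles (unlike means) transform directly. to_multiplier and to_sampled are hypothetical helper names.

```julia
# (2.5%, 50%, 97.5%) quantiles copied from the Quantiles table above
τ_q  = (-1.6643, -0.7308, 0.2060)
sd_q = (0.1341, 0.2363, 0.3500)

true_tau, true_sd = 2.5, 0.2

# The model uses τ̂ = 1 + 0.1τ and a group multiplier of exp(τ̂)
to_multiplier(τ) = exp(1 + 0.1 * τ)
to_sampled(mult) = (log(mult) - 1) / 0.1

τ_nat = to_multiplier.(τ_q)            # ≈ (2.30, 2.53, 2.77)
τ_true_sampled = to_sampled(true_tau)  # ≈ -0.837

println("95% interval for the group multiplier: ", round.(τ_nat; digits = 2))
println("true τ on the sampled scale: ", round(τ_true_sampled; digits = 3))
println("true τ inside 95% interval: ", τ_q[1] <= τ_true_sampled <= τ_q[3])
println("true sd inside 95% interval: ", sd_q[1] <= true_sd <= sd_q[3])
```

On the natural scale, the 95% interval for the group multiplier runs from about 2.30 to 2.77 and contains the true value of 2.5. The true SD of 0.2 also lies inside the interval for sd.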

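Let's also look at the Gelman-Rubin statistic. It compares variation between and within chains, so it needs more than one chain; the call is commented out here because only a single chain was run:

#gelmandiag(posterior)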

# And the centered version:
#posterior_centered

# τ non-centered, τᵢ centered:
#posterior_noncentered_tau

# τ centered, τᵢ non-centered:
#posterior_noncentered_tau_i
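To illustrate what the Gelman-Rubin diagnostic measures, here is a sketch of the classic (non-split) potential scale reduction factor for one parameter, computed from an n × m matrix of draws (n draws from each of m chains). gelman_rubin is a hypothetical helper. MCMCChains' gelmandiag and the rhat column may use refinements such as split chains, so their values need not match this version exactly.

```julia
using Random, Statistics

# Classic potential scale reduction factor; `chains` is n draws × m chains
function gelman_rubin(chains::AbstractMatrix)
    n, m = size(chains)
    chain_means = vec(mean(chains; dims = 1))
    W = mean(vec(var(chains; dims = 1)))   # mean within-chain variance
    B = n * var(chain_means)               # between-chain variance
    var_hat = (n - 1) / n * W + B / n      # pooled estimate of the posterior variance
    return sqrt(var_hat / W)
end

rng  = MersenneTwister(1)
good = randn(rng, 1000, 4)     # four chains sampling the same target
bad  = good .+ [0 0 0 5]       # one chain stuck in a different region

println(gelman_rubin(good))    # close to 1
println(gelman_rubin(bad))     # well above 1.1
```

Values near 1 indicate that the chains agree; a common rule of thumb treats values above about 1.1 as a sign of non-convergence.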

Finally, we plot histograms of the parameters on the millisecond scale:

[Figure: Posterior distribution of τ. The vertical line shows the true value.]
[Figure: Posterior distribution of the SD. The vertical line shows the true value.]
[Figure: Posterior distributions of the τᵢ. Vertical lines show the true values.]
[Figure: Posterior distributions of τ·τᵢ. Vertical lines show the true values.]

If the true value of each parameter falls within its posterior distribution (for example, inside the central 95% credible interval), we can say that the parameters were recovered successfully.
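One way to make this check concrete is to test, parameter by parameter, whether the true value lies inside the central 95% credible interval of the posterior draws. covers is a hypothetical helper, and the draws below are toy values, not the posterior from this run. For the τᵢ, the comparison has to happen on a common scale. Because the sampled τᵢ are standardized, one option is to compare draws of sd .* τᵢ with log.(true_tau_i).

```julia
using Statistics

# Hypothetical helper: for each parameter (column of `draws`), does the central
# `level` credible interval contain the corresponding true value?
function covers(draws::AbstractMatrix, truth::AbstractVector; level = 0.95)
    α = (1 - level) / 2
    return [quantile(view(draws, :, j), α) <= truth[j] <= quantile(view(draws, :, j), 1 - α)
            for j in axes(draws, 2)]
end

# Toy usage: 100 draws of 3 parameters, each column holding the values 1.0 to 100.0
draws = repeat(collect(1.0:100.0), 1, 3)
println(covers(draws, [50.0, 0.0, 99.0]))   # Bool[1, 0, 0]
```

The central 95% interval of each toy column runs from about 3.5 to 97.5, so only the first true value is covered.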