---
author: Garrett Smith
title: Parameter recovery with a hierarchical model
---
The goal of this script is to test the recovery of parameters from a hierarchical model with an `fpdistribution` likelihood.
First, we set up the transition rate matrices for the first-passage time distribution we want to fit.
```julia
#T = 4*[-1.0 0 0; 1 -1 1; 0 1 -2]
T = 4*[-1 1.0 0; 1 -2 1; 0 1 -2]
A = 4*[0 0 1.0]
p0 = [1.0, 0, 0]
```
```
3-element Vector{Float64}:
 1.0
 0.0
 0.0
```
Scaling the transition rate matrices by τ = 2.5 should give mean first-passage times of around 400 ms. Generating and fitting the parameters (τ and the separate τᵢ) will be done on the log scale and then exponentiated in order to keep the transition rates positive.
```julia
nparticipants = 50
true_tau = 2.5
true_sd = 0.2
true_tau_i = exp.(rand(Normal(0, true_sd), nparticipants));
```
The data will be saved in wide format: Each participant's data corresponds to a row, and each column is a data point.
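The data-generating step itself is not shown in this extract, so the following is only a minimal sketch of how such a wide-format `data` matrix could be simulated. It assumes that `fpdistribution` supports `rand` like other Distributions.jl distributions, and that each participant's true rate multiplier is `true_tau * true_tau_i[p]`; the name `ndata` and the number of trials are illustrative, not the author's exact code.

```julia
# Hypothetical data generation (assumes `rand(d, n)` is defined for `fpdistribution`):
# one row per participant, one column per simulated first-passage time.
ndata = 200  # illustrative number of trials per participant
data = Matrix{Float64}(undef, nparticipants, ndata)
for p in 1:nparticipants
    mult = true_tau * true_tau_i[p]  # participant-specific rate multiplier
    data[p, :] = rand(fpdistribution(mult*T, mult*A, p0), ndata)
end
```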
Now, we can write the full model including the likelihood. Note that we're using a non-centered parameterization. This is because pilot simulations suggested that sampling was biased in the centered parameterization.
```julia
# Switch to param = exp(tau) + exp(tau_i). This will prevent really small params b/c multiplication.
@model function mod(y, Tmat=T, Amat=A, p0vec=p0)
    np, nd = size(y)

    # Priors
    # Using the non-centered parameterization for τ
    τ ~ Normal()
    τ̂ = 1 + 0.1*τ  # Corresponds to Normal(1, 0.1)
    sd ~ Exponential(0.25)
    τᵢ ~ filldist(Normal(), np)
    τ̂ᵢ = sd .* τᵢ  # Corresponds to MvNormal(0, sd)

    # Likelihood
    mult = exp.(τ̂ .+ τ̂ᵢ)
    y ~ filldist(arraydist([fpdistribution(mult[p]*Tmat, mult[p]*Amat, p0vec) for p in 1:np]), nd)
    return τ̂, τ̂ᵢ, mult
end
```
```
mod (generic function with 5 methods)
```
Here, we'll use the NUTS sampler with a burn-in of 250 samples and a target acceptance rate of 0.7 to draw the posterior. In this run we sample a single chain of 1000 draws; a four-chain run with `MCMCThreads()` is left commented out below. To use it, make sure to execute this script with `julia -t 4 HierarchicalParameterRecovery.jl`.
```julia
#posterior = sample(mod(data), NUTS(250, 0.65), MCMCThreads(), 1000, 4)
posterior = sample(mod(data), NUTS(250, 0.7), 1000);
#posterior_centered = sample(mod_centered(data), NUTS(100, 0.65), 1000);
#posterior_noncentered_tau = sample(mod_noncentered_tau(data), NUTS(100, 0.65), 1000);
#posterior_noncentered_tau_i = sample(mod_noncentered_tau_i(data), NUTS(100, 0.65), 1000);
```
First, we summarize the chains:
```julia
posterior
```
```
Chains MCMC chain (1000×64×1 Array{Float64, 3}):

Iterations        = 251:1:1250
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 178.72 seconds
Compute duration  = 178.72 seconds
parameters        = τ, sd, τᵢ[1], τᵢ[2], τᵢ[3], τᵢ[4], τᵢ[5], τᵢ[6], τᵢ[7], τᵢ[8], τᵢ[9], τᵢ[10], τᵢ[11], τᵢ[12], τᵢ[13], τᵢ[14], τᵢ[15], τᵢ[16], τᵢ[17], τᵢ[18], τᵢ[19], τᵢ[20], τᵢ[21], τᵢ[22], τᵢ[23], τᵢ[24], τᵢ[25], τᵢ[26], τᵢ[27], τᵢ[28], τᵢ[29], τᵢ[30], τᵢ[31], τᵢ[32], τᵢ[33], τᵢ[34], τᵢ[35], τᵢ[36], τᵢ[37], τᵢ[38], τᵢ[39], τᵢ[40], τᵢ[41], τᵢ[42], τᵢ[43], τᵢ[44], τᵢ[45], τᵢ[46], τᵢ[47], τᵢ[48], τᵢ[49], τᵢ[50]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ⋯
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64   ⋯

           τ   -0.7442    0.4553     0.0144    0.0238    316.8694    0.9991   ⋯
          sd    0.2374    0.0538     0.0017    0.0029    398.4406    0.9990   ⋯
       τᵢ[1]   -0.9415    0.7164     0.0227    0.0218    994.0778    0.9991   ⋯
       τᵢ[2]    0.3490    0.7894     0.0250    0.0197   1349.0116    0.9996   ⋯
       τᵢ[3]    0.5201    0.7431     0.0235    0.0226   1250.5871    0.9991   ⋯
       τᵢ[4]    0.4726    0.7791     0.0246    0.0242   1008.9107    1.0006   ⋯
       τᵢ[5]   -0.4701    0.7294     0.0231    0.0197   1978.9287    0.9990   ⋯
       τᵢ[6]   -1.0603    0.7125     0.0225    0.0222   1159.4274    0.9995   ⋯
       τᵢ[7]   -0.1375    0.7535     0.0238    0.0247   1012.2482    0.9991   ⋯
       τᵢ[8]   -0.6295    0.7316     0.0231    0.0198   1250.2784    0.9994   ⋯
       τᵢ[9]    1.0526    0.7401     0.0234    0.0201   1378.9551    0.9993   ⋯
      τᵢ[10]   -0.2141    0.7879     0.0249    0.0228   1210.5560    0.9998   ⋯
      τᵢ[11]   -0.0125    0.7856     0.0248    0.0214    791.1852    0.9992   ⋯
      τᵢ[12]   -0.5175    0.7570     0.0239    0.0205   1256.1725    0.9992   ⋯
      τᵢ[13]   -0.6380    0.7425     0.0235    0.0234   1025.6297    1.0019   ⋯
      τᵢ[14]   -0.4129    0.7522     0.0238    0.0214    986.6031    0.9990   ⋯
      τᵢ[15]   -0.6625    0.7582     0.0240    0.0276   1173.4401    0.9990   ⋯
           ⋮         ⋮         ⋮          ⋮         ⋮           ⋮         ⋮   ⋱
                                                  1 column and 35 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           τ   -1.6643   -1.0467   -0.7308   -0.4321    0.2060
          sd    0.1341    0.2026    0.2363    0.2721    0.3500
       τᵢ[1]   -2.3378   -1.4284   -0.9079   -0.4509    0.4462
       τᵢ[2]   -1.1857   -0.1474    0.3521    0.8555    1.9542
       τᵢ[3]   -0.9300    0.0231    0.5213    1.0259    1.9631
       τᵢ[4]   -1.0176   -0.0829    0.4789    0.9884    2.0412
       τᵢ[5]   -1.9371   -0.9491   -0.4642    0.0242    0.9772
       τᵢ[6]   -2.5182   -1.5295   -1.0533   -0.6082    0.2898
       τᵢ[7]   -1.5248   -0.6553   -0.1586    0.3721    1.3880
       τᵢ[8]   -2.0842   -1.1045   -0.6365   -0.1271    0.8049
       τᵢ[9]   -0.4966    0.5819    1.0523    1.5221    2.5336
      τᵢ[10]   -1.7941   -0.7078   -0.2226    0.3031    1.2916
      τᵢ[11]   -1.5853   -0.5313   -0.0123    0.5043    1.5465
      τᵢ[12]   -1.8783   -1.0141   -0.5442   -0.0220    0.9733
      τᵢ[13]   -2.1333   -1.1272   -0.6287   -0.1420    0.8235
      τᵢ[14]   -1.8759   -0.9211   -0.4014    0.1069    0.9702
      τᵢ[15]   -2.1244   -1.1479   -0.6613   -0.1553    0.8299
           ⋮         ⋮         ⋮         ⋮         ⋮         ⋮
                                                35 rows omitted
```
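Note that the τᵢ reported above are on the standardized, non-centered scale. To inspect the recovered multipliers exp(τ̂ + τ̂ᵢ) directly, the model's return values can be replayed for every posterior draw. The sketch below is illustrative rather than part of the original script; it assumes Turing's `generated_quantities` and `MCMCChains.get_sections` are available in the versions used here.

```julia
# Re-evaluate the model's deterministic return values (τ̂, τ̂ᵢ, mult) for each draw.
# `get_sections` strips the sampler internals so only the parameters are replayed.
param_chain = MCMCChains.get_sections(posterior, :parameters)
gq = generated_quantities(mod(data), param_chain)

# Collect the population-level multiplier exp(τ̂) from each draw
# and compare it to the generating value true_tau = 2.5.
pop_mult = [exp(g[1]) for g in vec(gq)]
println("Posterior mean of exp(τ̂): ", mean(pop_mult))
```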
Let's also look at the Gelman-Rubin statistic for the chains (left commented out below, since only a single chain was sampled in this run):
```julia
#gelmandiag(posterior)
#
#' And the centered version:
#posterior_centered
#
#' τ non-centered, τᵢ centered:
#posterior_noncentered_tau
#
#' τ centered, τᵢ non-centered
#posterior_noncentered_tau_i
```
And plot histograms of the parameters on the millisecond scale:
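The plotting code is not included in this extract, so the following is only a minimal sketch. It uses `pop_mult` from the `generated_quantities` sketch above and plots on the rate-multiplier scale rather than in milliseconds (converting to mean first-passage times would additionally require the phase-type mean formula). The choice of StatsPlots and its `histogram`/`vline!` functions is illustrative.

```julia
using StatsPlots  # assumed plotting package; any Plots-compatible backend works

# Histogram of the posterior of the population-level multiplier exp(τ̂),
# with the generating value true_tau = 2.5 marked for comparison.
histogram(pop_mult; bins=30, label="posterior of exp(τ̂)",
          xlabel="rate multiplier", ylabel="count")
vline!([true_tau]; label="true value", linewidth=2)

# The same check can be repeated for each participant's multiplier mult[p]
# against true_tau * true_tau_i[p].
```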
If the true parameter values fall in regions of high posterior density, we can say the parameters were recovered successfully.