Skip to content

Commit

Permalink
Fix broken case when variables are multi-dimensional and add tests fo…
Browse files Browse the repository at this point in the history
…r AdvancedHMC extension and inference (#89)

Fix #71.

---------

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
sunxd3 and yebai authored Sep 17, 2023
1 parent 35661f6 commit 1f5dd43
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 60 deletions.
8 changes: 8 additions & 0 deletions src/BUGSExamples/BUGSExamples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,12 @@ const VOLUME_I = (

const VOLUME_II = (birats=birats, eyes=eyes)

function has_ground_truth(m::Symbol)
if m in union(keys(VOLUME_I), keys(VOLUME_II))
return haskey(getfield(BUGSExamples, m), :reference_results)
else
return false
end
end

end
5 changes: 0 additions & 5 deletions src/BUGSExamples/Backgrounds/rats.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ This example is taken from section 6 of *Gelfand et al. (1990)*, and concerns 30
weights were measured weekly for five weeks. Part of the data is shown below, where $Y_{ij}$ is the
weight of the $i^{th}$ rat measured at age $x_j$.




<center>

$$\text{Weights } Y_{ij} \text{ of rat } i \text{ on day } x_j$$
Expand Down Expand Up @@ -44,5 +41,3 @@ be independent).

$ a_c $, $ \tau_a $, $ b_c $, $ \tau_b $, $ \tau_c $ are given independent "noninformative" priors. Interest particularly focuses on
the intercept at zero time (birth), denoted $ a_0 = a_c - b_c \cdot \bar{x} $.


6 changes: 6 additions & 0 deletions src/BUGSExamples/Volume_I/Equiv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,10 @@ equiv <- step(theta - 0.8) - step(theta - 1.2)
sign=[1, -1],
),
inits=[(mu=0, phi=0, pi=0, tau1=1, tau2=1), (mu=10, phi=10, pi=10, tau1=0.1, tau2=0.1)],
reference_results=(
equiv=(mean=0.998, std=0.04468),
mu=(mean=1.436, std=0.05751),
phi=(mean=-0.008613, std=0.05187),
sigma1=(mean=0.1102, std=0.03268),
),
)
5 changes: 5 additions & 0 deletions src/BUGSExamples/Volume_I/Rats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,9 @@ rats = (
var"beta.tau"=0.1,
),
],
reference_results=(
alpha0=(mean=106.6, std=3.66),
var"beta.c"=(mean=6.186, std=0.1086),
sigma=(mean=6.093, std=0.4643),
),
)
9 changes: 8 additions & 1 deletion src/BUGSExamples/Volume_I/Seeds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,11 @@ sigma <- 1 / sqrt(tau)
b=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
),
],
)
reference_results=(
alpha0=(mean=-0.5499, std=0.1965),
alpha1=(mean=0.08902, std=0.3124),
alpha12=(mean=-0.841, std=0.4372),
alpha2=(mean=1.356, std=0.2772),
sigma=(mean=0.2922, std=0.1467),
),
)
3 changes: 3 additions & 0 deletions src/BUGSExamples/Volume_I/Stacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,7 @@ sigma <- sqrt(1 / tau) # normal errors
(beta0=10, beta=[0, 0, 0], tau=0.1, phi=0.1),
(beta0=1.0, beta=[1.0, 1.0, 1.0], tau=1.0, phi=1.0),
],
reference_results=(
b0=(mean=-39.64, std=12.63), var"outlier[21]"=(mean=0.3324, std=0.4711)
),
)
108 changes: 54 additions & 54 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,65 @@ end

# AdvancedHMC

# test generation of parameter names
model = compile(
(@bugs begin
x[1:2] ~ dmnorm(mu[:], sigma[:, :])
x[3] ~ dnorm(0, 1)
y = x[1] + x[3]
end), (mu=[0, 0], sigma=[1 0; 0 1]), NamedTuple()
)
@testset "AdvancedHMC" begin
@testset "Generation of parameter names" begin
model = compile(
(@bugs begin
x[1:2] ~ dmnorm(mu[:], sigma[:, :])
x[3] ~ dnorm(0, 1)
y = x[1] + x[3]
end),
(mu=[0, 0], sigma=[1 0; 0 1]),
NamedTuple(),
)

ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
n_samples, n_adapts = 10, 0
D = LogDensityProblems.dimension(model);
initial_θ = rand(D);
samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
n_samples, n_adapts = 10, 0
D = LogDensityProblems.dimension(model)
initial_θ = rand(D)
samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)

@test samples_and_stats.name_map.parameters ==
[Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y]
@test samples_and_stats.name_map.parameters ==
[Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y]
end

# test inference result with Seeds
data = JuliaBUGS.BUGSExamples.VOLUME_I[:seeds].data
inits = JuliaBUGS.BUGSExamples.VOLUME_I[:seeds].inits[1]
model = JuliaBUGS.compile(JuliaBUGS.BUGSExamples.VOLUME_I[:seeds].model_def, data, inits)
@testset "Inference results on examples: $m" for m in [:seeds, :rats, :equiv, :stacks]
data = JuliaBUGS.BUGSExamples.VOLUME_I[m].data
inits = JuliaBUGS.BUGSExamples.VOLUME_I[m].inits[1]
model = JuliaBUGS.compile(JuliaBUGS.BUGSExamples.VOLUME_I[m].model_def, data, inits)

ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))

n_samples, n_adapts = 2000, 1000
n_samples, n_adapts = 2000, 1000

D = LogDensityProblems.dimension(model);
initial_θ = rand(D);
D = LogDensityProblems.dimension(model)
initial_θ = rand(D)

samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)
samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)

@test summarize(samples_and_stats)[:alpha0].nt.mean[1] -0.5499 rtol = 0.1
@test summarize(samples_and_stats)[:alpha0].nt.std[1] 0.1965 rtol = 0.1

@test summarize(samples_and_stats)[:alpha1].nt.mean[1] 0.08902 rtol = 0.1
@test summarize(samples_and_stats)[:alpha1].nt.std[1] 0.3124 rtol = 0.1

@test summarize(samples_and_stats)[:alpha12].nt.mean[1] -0.841 rtol = 0.1
@test summarize(samples_and_stats)[:alpha12].nt.std[1] 0.4372 rtol = 0.1

@test summarize(samples_and_stats)[:alpha2].nt.mean[1] 1.356 rtol = 0.1
@test summarize(samples_and_stats)[:alpha2].nt.std[1] 0.2772 rtol = 0.1

@test summarize(samples_and_stats)[:sigma].nt.mean[1] 0.2922 rtol = 0.1
@test summarize(samples_and_stats)[:sigma].nt.std[1] 0.1467 rtol = 0.1
@assert JuliaBUGS.BUGSExamples.has_ground_truth(m) "No reference inference results for $m"
ref_inference_results = JuliaBUGS.BUGSExamples.VOLUME_I[m].reference_results
@testset "$m: $var" for var in keys(ref_inference_results)
@test summarize(samples_and_stats)[var].nt.mean[1]
ref_inference_results[var].mean rtol = 0.2
@test summarize(samples_and_stats)[var].nt.std[1]
ref_inference_results[var].std rtol = 0.2
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ using JuliaBUGS.BUGSPrimitives: mean
using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using MCMCChains
using Random
using ReverseDiff
using Setfield
using Test
using UnPack

Random.seed!(12345)

@testset "Function Unit Tests" begin
DocMeta.setdocmeta!(
JuliaBUGS,
Expand Down

0 comments on commit 1f5dd43

Please sign in to comment.