Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broken case when variables are multi-dimensional and add tests for AdvancedHMC extension and inference #89

Merged
merged 23 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/BUGSExamples/BUGSExamples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,12 @@

const VOLUME_II = (birats=birats, eyes=eyes)

function has_ground_truth(m::Symbol)
if m in union(keys(VOLUME_I), keys(VOLUME_II))
return haskey(getfield(BUGSExamples, m), :reference_results)

Check warning on line 57 in src/BUGSExamples/BUGSExamples.jl

View check run for this annotation

Codecov / codecov/patch

src/BUGSExamples/BUGSExamples.jl#L55-L57

Added lines #L55 - L57 were not covered by tests
else
return false

Check warning on line 59 in src/BUGSExamples/BUGSExamples.jl

View check run for this annotation

Codecov / codecov/patch

src/BUGSExamples/BUGSExamples.jl#L59

Added line #L59 was not covered by tests
end
end

end
5 changes: 0 additions & 5 deletions src/BUGSExamples/Backgrounds/rats.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ This example is taken from section 6 of *Gelfand et al. (1990)*, and concerns 30
weights were measured weekly for five weeks. Part of the data is shown below, where $Y_{ij}$ is the
weight of the $i^{th}$ rat measured at age $x_j$.




<center>

$$\text{Weights } Y_{ij} \text{ of rat } i \text{ on day } x_j$$
Expand Down Expand Up @@ -44,5 +41,3 @@ be independent).

$ a_c $, $ \tau_a $, $ b_c $, $ \tau_b $, $ \tau_c $ are given independent "noninformative" priors. Interest particularly focuses on
the intercept at zero time (birth), denoted $ a_0 = a_c - b_c \cdot \bar{x} $.


6 changes: 6 additions & 0 deletions src/BUGSExamples/Volume_I/Equiv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,10 @@ equiv <- step(theta - 0.8) - step(theta - 1.2)
sign=[1, -1],
),
inits=[(mu=0, phi=0, pi=0, tau1=1, tau2=1), (mu=10, phi=10, pi=10, tau1=0.1, tau2=0.1)],
reference_results=(
equiv=(mean=0.998, std=0.04468),
mu=(mean=1.436, std=0.05751),
phi=(mean=-0.008613, std=0.05187),
sigma1=(mean=0.1102, std=0.03268),
),
)
5 changes: 5 additions & 0 deletions src/BUGSExamples/Volume_I/Rats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,9 @@ rats = (
var"beta.tau"=0.1,
),
],
reference_results=(
alpha0=(mean=106.6, std=3.66),
var"beta.c"=(mean=6.186, std=0.1086),
sigma=(mean=6.093, std=0.4643),
),
)
9 changes: 8 additions & 1 deletion src/BUGSExamples/Volume_I/Seeds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,11 @@ sigma <- 1 / sqrt(tau)
b=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
),
],
)
reference_results=(
alpha0=(mean=-0.5499, std=0.1965),
alpha1=(mean=0.08902, std=0.3124),
alpha12=(mean=-0.841, std=0.4372),
alpha2=(mean=1.356, std=0.2772),
sigma=(mean=0.2922, std=0.1467),
),
)
3 changes: 3 additions & 0 deletions src/BUGSExamples/Volume_I/Stacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,7 @@ sigma <- sqrt(1 / tau) # normal errors
(beta0=10, beta=[0, 0, 0], tau=0.1, phi=0.1),
(beta0=1.0, beta=[1.0, 1.0, 1.0], tau=1.0, phi=1.0),
],
reference_results=(
b0=(mean=-39.64, std=12.63), var"outlier[21]"=(mean=0.3324, std=0.4711)
),
)
108 changes: 54 additions & 54 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,65 @@ end

# AdvancedHMC

# test generation of parameter names
model = compile(
(@bugs begin
x[1:2] ~ dmnorm(mu[:], sigma[:, :])
x[3] ~ dnorm(0, 1)
y = x[1] + x[3]
end), (mu=[0, 0], sigma=[1 0; 0 1]), NamedTuple()
)
@testset "AdvancedHMC" begin
@testset "Generation of parameter names" begin
model = compile(
(@bugs begin
x[1:2] ~ dmnorm(mu[:], sigma[:, :])
x[3] ~ dnorm(0, 1)
y = x[1] + x[3]
end),
(mu=[0, 0], sigma=[1 0; 0 1]),
NamedTuple(),
)

ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
n_samples, n_adapts = 10, 0
D = LogDensityProblems.dimension(model);
initial_θ = rand(D);
samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
n_samples, n_adapts = 10, 0
D = LogDensityProblems.dimension(model)
initial_θ = rand(D)
samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)

@test samples_and_stats.name_map.parameters ==
[Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y]
@test samples_and_stats.name_map.parameters ==
[Symbol("x[3]"), Symbol("x[1:2][1]"), Symbol("x[1:2][2]"), :y]
end

# test inference result with Seeds
data = JuliaBUGS.BUGSExamples.VOLUME_I[:seeds].data
inits = JuliaBUGS.BUGSExamples.VOLUME_I[:seeds].inits[1]
model = JuliaBUGS.compile(JuliaBUGS.BUGSExamples.VOLUME_I[:seeds].model_def, data, inits)
@testset "Inference results on examples: $m" for m in [:seeds, :rats, :equiv, :stacks]
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved
data = JuliaBUGS.BUGSExamples.VOLUME_I[m].data
inits = JuliaBUGS.BUGSExamples.VOLUME_I[m].inits[1]
model = JuliaBUGS.compile(JuliaBUGS.BUGSExamples.VOLUME_I[m].model_def, data, inits)

ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))

n_samples, n_adapts = 2000, 1000
n_samples, n_adapts = 2000, 1000

D = LogDensityProblems.dimension(model);
initial_θ = rand(D);
D = LogDensityProblems.dimension(model)
initial_θ = rand(D)

samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)
samples_and_stats = AbstractMCMC.sample(
ad_model,
NUTS(0.8),
n_samples;
chain_type=Chains,
n_adapts=n_adapts,
init_params=initial_θ,
discard_initial=n_adapts,
)

@test summarize(samples_and_stats)[:alpha0].nt.mean[1] ≈ -0.5499 rtol = 0.1
@test summarize(samples_and_stats)[:alpha0].nt.std[1] ≈ 0.1965 rtol = 0.1

@test summarize(samples_and_stats)[:alpha1].nt.mean[1] ≈ 0.08902 rtol = 0.1
@test summarize(samples_and_stats)[:alpha1].nt.std[1] ≈ 0.3124 rtol = 0.1

@test summarize(samples_and_stats)[:alpha12].nt.mean[1] ≈ -0.841 rtol = 0.1
@test summarize(samples_and_stats)[:alpha12].nt.std[1] ≈ 0.4372 rtol = 0.1

@test summarize(samples_and_stats)[:alpha2].nt.mean[1] ≈ 1.356 rtol = 0.1
@test summarize(samples_and_stats)[:alpha2].nt.std[1] ≈ 0.2772 rtol = 0.1

@test summarize(samples_and_stats)[:sigma].nt.mean[1] ≈ 0.2922 rtol = 0.1
@test summarize(samples_and_stats)[:sigma].nt.std[1] ≈ 0.1467 rtol = 0.1
@assert JuliaBUGS.BUGSExamples.has_ground_truth(m) "No reference inference results for $m"
ref_inference_results = JuliaBUGS.BUGSExamples.VOLUME_I[m].reference_results
@testset "$m: $var" for var in keys(ref_inference_results)
@test summarize(samples_and_stats)[var].nt.mean[1] ≈
ref_inference_results[var].mean rtol = 0.2
@test summarize(samples_and_stats)[var].nt.std[1] ≈
ref_inference_results[var].std rtol = 0.2
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ using JuliaBUGS.BUGSPrimitives: mean
using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using MCMCChains
using Random
using ReverseDiff
using Setfield
using Test
using UnPack

Random.seed!(12345)

@testset "Function Unit Tests" begin
DocMeta.setdocmeta!(
JuliaBUGS,
Expand Down
Loading