From 783c063a333fc235c9ec6b49d419e404c2f88ff4 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 7 Sep 2023 14:15:13 +0100 Subject: [PATCH] Update tests, use `link` or `LogDensityFunction` --- src/JuliaBUGS.jl | 2 +- src/model.jl | 8 ++++---- test/logp_tests/binomial.jl | 6 +++--- test/logp_tests/blockers.jl | 8 ++++---- test/logp_tests/bones.jl | 8 ++++---- test/logp_tests/ddirich.jl | 0 test/logp_tests/dogs.jl | 8 ++++---- test/logp_tests/dwish.jl | 0 test/logp_tests/gamma.jl | 4 ++-- test/logp_tests/lkj.jl | 0 test/logp_tests/rats.jl | 21 +++++++++++++++++---- test/runtests.jl | 6 ++---- 12 files changed, 41 insertions(+), 30 deletions(-) delete mode 100644 test/logp_tests/ddirich.jl delete mode 100644 test/logp_tests/dwish.jl delete mode 100644 test/logp_tests/lkj.jl diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 29d6adde4..7092c1a6d 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -156,7 +156,7 @@ function compile(model_def::Expr, data, inits) vars, array_sizes, array_bitmap, link_functions, node_args, node_functions, dependencies = program!( NodeFunctions(vars, array_sizes, array_bitmap), model_def, merged_data ) - g = BUGSGraph(vars, link_functions, node_args, node_functions, dependencies) + g = BUGSGraph(vars, node_args, node_functions, dependencies) sorted_nodes = map(Base.Fix1(label_for, g), topological_sort(g)) return BUGSModel(g, sorted_nodes, vars, array_sizes, merged_data, inits) end diff --git a/src/model.jl b/src/model.jl index 2dd6d8ac7..bf0448bc4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -252,7 +252,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) value_transformed = transform(bijector(dist), value) logp += logpdf(dist, value) + - logabsdetjac(inverse(bijector(dist)), value_transformed) + logabsdetjac(Bijectors.inverse(bijector(dist)), value_transformed) else logp += logpdf(dist, value) end @@ -286,7 +286,7 @@ function AbstractPPL.evaluate!!( current_idx += l # TODO: this use `DynamicPPL.reconstruct`, which needs attention when decoupling from DynamicPPL value, logjac = DynamicPPL.with_logabsdet_jacobian_and_reconstruct( - inverse(bijector(dist)), dist, value_transformed + Bijectors.inverse(bijector(dist)), dist, value_transformed ) logp += logpdf(dist, value) + logjac vi = setindex!!(vi, value, vn) @@ -332,7 +332,7 @@ function AbstractPPL.evaluate!!(model::MarkovBlanketCoveredBUGSModel, ::DefaultC value_transformed = transform(bijector(dist), value) logp += logpdf(dist, value) + - logabsdetjac(inverse(bijector(dist)), value_transformed) + logabsdetjac(Bijectors.inverse(bijector(dist)), value_transformed) else logp += logpdf(dist, value) end @@ -372,7 +372,7 @@ function AbstractPPL.evaluate!!( current_idx += l # TODO: this use `DynamicPPL.reconstruct`, which needs attention when decoupling from DynamicPPL value, logjac = DynamicPPL.with_logabsdet_jacobian_and_reconstruct( - inverse(bijector(dist)), dist, value_transformed + Bijectors.inverse(bijector(dist)), dist, value_transformed ) logp += logpdf(dist, value) + logjac vi = setindex!!(vi, value, vn) diff --git a/test/logp_tests/binomial.jl b/test/logp_tests/binomial.jl index de5bef741..8d6638103 100644 --- a/test/logp_tests/binomial.jl +++ b/test/logp_tests/binomial.jl @@ -13,7 +13,7 @@ dppl_model = dppl_gamma_model() bugs_logp = JuliaBUGS.evaluate!!( - DynamicPPL.settrans!!(bugs_model, false), DefaultBUGSContext() + DynamicPPL.settrans!!(bugs_model, false), DefaultContext() ).logp params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model @@ -29,7 +29,7 @@ t_p = DynamicPPL.LogDensityFunction( @test bugs_logp ≈ LogDensityProblems.logdensity(p, [10.0]) rtol = 1E-6 bugs_logp = - JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultBUGSContext()).logp + JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultContext()).logp @test bugs_logp ≈ LogDensityProblems.logdensity(t_p, [transform(bijector(dbin(0.1, 10)), 10.0)]) rtol = - 1E-6 \ No newline at end of file + 1E-6 diff --git a/test/logp_tests/blockers.jl b/test/logp_tests/blockers.jl index 5f35a6452..382cc29b4 100644 --- a/test/logp_tests/blockers.jl +++ b/test/logp_tests/blockers.jl @@ -3,7 +3,7 @@ data = JuliaBUGS.BUGSExamples.VOLUME_I[:blockers].data inits = JuliaBUGS.BUGSExamples.VOLUME_I[:blockers].inits[1] bugs_model = compile(bugs_model_def, data, inits) -vi = JuliaBUGS.get_varinfo(bugs_model) +vi = bugs_model.varinfo @model function blockers(rc, rt, nc, nt, Num) d ~ dnorm(0.0, 1.0E-6) @@ -36,7 +36,7 @@ dppl_model = blockers(rc, rt, nc, nt, Num) bugs_logp = JuliaBUGS.evaluate!!( - DynamicPPL.settrans!!(bugs_model, false), DefaultBUGSContext() + DynamicPPL.settrans!!(bugs_model, false), DefaultContext() ).logp params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model @@ -50,11 +50,11 @@ dppl_logp = @test bugs_logp ≈ dppl_logp rtol = 1E-6 bugs_logp = - JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultBUGSContext()).logp + JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultContext()).logp dppl_logp = DynamicPPL.evaluate!!( dppl_model, - DynamicPPL.settrans!!(get_params_varinfo(bugs_model), true), + DynamicPPL.link!!(get_params_varinfo(bugs_model), dppl_model), DynamicPPL.DefaultContext(), )[2].logp @test bugs_logp ≈ dppl_logp rtol = 1E-6 diff --git a/test/logp_tests/bones.jl b/test/logp_tests/bones.jl index a384b0b2c..a5a9473a5 100644 --- a/test/logp_tests/bones.jl +++ b/test/logp_tests/bones.jl @@ -3,7 +3,7 @@ data = JuliaBUGS.BUGSExamples.VOLUME_I[:bones].data inits = JuliaBUGS.BUGSExamples.VOLUME_I[:bones].inits[1] bugs_model = compile(bugs_model_def, data, inits) -vi = JuliaBUGS.get_varinfo(bugs_model) +vi = bugs_model.varinfo @model function bones(grade, nChild, nInd, ncat, gamma, delta) theta = Vector{Real}(undef, nChild) @@ -38,7 +38,7 @@ dppl_model = bones(grade, nChild, nInd, ncat, gamma, delta) bugs_logp = JuliaBUGS.evaluate!!( - DynamicPPL.settrans!!(bugs_model, false), DefaultBUGSContext() + DynamicPPL.settrans!!(bugs_model, false), DefaultContext() ).logp params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model @@ -52,11 +52,11 @@ dppl_logp = @test bugs_logp ≈ dppl_logp rtol = 1E-6 bugs_logp = - JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultBUGSContext()).logp + JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultContext()).logp dppl_logp = DynamicPPL.evaluate!!( dppl_model, - DynamicPPL.settrans!!(get_params_varinfo(bugs_model), true), + DynamicPPL.link!!(get_params_varinfo(bugs_model), dppl_model), DynamicPPL.DefaultContext(), )[2].logp @test bugs_logp ≈ dppl_logp rtol = 1E-6 diff --git a/test/logp_tests/ddirich.jl b/test/logp_tests/ddirich.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/logp_tests/dogs.jl b/test/logp_tests/dogs.jl index 95f0e681c..f6eab4c09 100644 --- a/test/logp_tests/dogs.jl +++ b/test/logp_tests/dogs.jl @@ -3,7 +3,7 @@ data = JuliaBUGS.BUGSExamples.VOLUME_I[:dogs].data inits = JuliaBUGS.BUGSExamples.VOLUME_I[:dogs].inits[1] bugs_model = compile(bugs_model_def, data, inits) -vi = JuliaBUGS.get_varinfo(bugs_model) +vi = bugs_model.varinfo @model function dogs(Dogs, Trials, Y, y) # Initialize matrices @@ -41,7 +41,7 @@ dppl_model = dogs(Dogs, Trials, Y, 1 .- Y) bugs_logp = JuliaBUGS.evaluate!!( - DynamicPPL.settrans!!(bugs_model, false), DefaultBUGSContext() + DynamicPPL.settrans!!(bugs_model, false), DefaultContext() ).logp dppl_logp = @@ -52,11 +52,11 @@ dppl_logp = @test bugs_logp ≈ dppl_logp rtol = 1E-6 bugs_logp = - JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultBUGSContext()).logp + JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultContext()).logp dppl_logp = DynamicPPL.evaluate!!( dppl_model, - DynamicPPL.settrans!!(get_params_varinfo(bugs_model), true), + DynamicPPL.link!!(get_params_varinfo(bugs_model), dppl_model), DynamicPPL.DefaultContext(), )[2].logp @test bugs_logp ≈ dppl_logp rtol = 1E-6 diff --git a/test/logp_tests/dwish.jl b/test/logp_tests/dwish.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/logp_tests/gamma.jl b/test/logp_tests/gamma.jl index 10759236f..216163812 100644 --- a/test/logp_tests/gamma.jl +++ b/test/logp_tests/gamma.jl @@ -14,7 +14,7 @@ dppl_model = dppl_gamma_model() bugs_logp = JuliaBUGS.evaluate!!( - DynamicPPL.settrans!!(bugs_model, false), DefaultBUGSContext() + DynamicPPL.settrans!!(bugs_model, false), DefaultContext() ).logp params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model @@ -30,7 +30,7 @@ t_p = DynamicPPL.LogDensityFunction( @test bugs_logp ≈ LogDensityProblems.logdensity(p, [10.0]) rtol = 1E-6 bugs_logp = - JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultBUGSContext()).logp + JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultContext()).logp @test bugs_logp ≈ LogDensityProblems.logdensity(t_p, [transform(bijector(dgamma(0.001, 0.001)), 10.0)]) rtol = 1E-6 diff --git a/test/logp_tests/lkj.jl b/test/logp_tests/lkj.jl deleted file mode 100644 index e69de29bb..000000000 diff --git a/test/logp_tests/rats.jl b/test/logp_tests/rats.jl index 2358cccac..bbdfd68db 100644 --- a/test/logp_tests/rats.jl +++ b/test/logp_tests/rats.jl @@ -23,6 +23,7 @@ model_def = @bugs begin alpha0 = alpha_c - xbar * beta_c end bugs_model = compile(model_def, data, inits); +vi = bugs_model.varinfo @model function rats(Y, x, xbar, N, T) tau_c ~ dgamma(0.001, 0.001) @@ -53,15 +54,27 @@ bugs_model = compile(model_def, data, inits); end dppl_model = rats(Y, x, xbar, N, T) -vi, bugs_logp = get_vi_logp(bugs_model, false) +bugs_model = DynamicPPL.settrans!!(bugs_model, false) +bugs_logp = JuliaBUGS.evaluate!!(bugs_model, DefaultContext()).logp params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model @test params_in_dppl_model(dppl_model) == keys(params_vi) -vi, dppl_logp = get_vi_logp(dppl_model, vi, false) +dppl_logp = + DynamicPPL.evaluate!!( + dppl_model, DynamicPPL.settrans!!(vi, false), DynamicPPL.DefaultContext() + )[2].logp @test bugs_logp ≈ -174029.387 rtol = 1E-6 # reference value from ProbPALA @test bugs_logp ≈ dppl_logp rtol = 1E-6 -vi, bugs_logp = get_vi_logp(bugs_model, true) -vi, dppl_logp = get_vi_logp(dppl_model, vi, true) +dppl_logp = + DynamicPPL.evaluate!!( + dppl_model, + link!!(get_params_varinfo(bugs_model), dppl_model), + DynamicPPL.DefaultContext(), + )[2].logp +bugs_logp = JuliaBUGS.evaluate!!(DynamicPPL.settrans!!(bugs_model, true), DefaultContext()).logp @test bugs_logp ≈ dppl_logp rtol = 1E-6 + +@test bugs_model.param_length == + LogDensityProblems.dimension(DynamicPPL.LogDensityFunction(dppl_model)) diff --git a/test/runtests.jl b/test/runtests.jl index 8075197dc..6b1750b5c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,7 @@ using JuliaBUGS: Stochastic, Logical, evaluate!!, - DefaultBUGSContext, + DefaultContext, BUGSGraph, stochastic_neighbors, stochastic_inneighbors, @@ -26,9 +26,7 @@ using JuliaBUGS: LogDensityContext, ConcreteNodeInfo, SimpleVarInfo, - get_params_varinfo, - get_varinfo, - transformation + get_params_varinfo using JuliaBUGS.BUGSPrimitives using JuliaBUGS.BUGSPrimitives: mean using LogDensityProblems, LogDensityProblemsAD