From c5bf173caa9f3b526ce7cc2f1e92a1945fe7c545 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 29 Nov 2024 08:52:43 +0000 Subject: [PATCH 1/2] cleanup and add some comments --- src/model.jl | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 06f39ede..2607eaa9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -431,9 +431,8 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) end function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) - (; evaluation_env, g) = model - vi = deepcopy(evaluation_env) logp = 0.0 + evaluation_env = deepcopy(model.evaluation_env) for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] node_function = model.flattened_graph_node_data.node_function_vals[i] @@ -444,7 +443,16 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) else dist = node_function(model.evaluation_env, loop_vars) value = rand(rng, dist) # just sample from the prior - logp += logpdf(dist, value) + if model.transformed + # see below for why we need to transform the value + value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) + logp += + Distributions.logpdf(dist, value) + Bijectors.logabsdetjac( + Bijectors.inverse(Bijectors.bijector(dist)), value_transformed + ) + else + logp += Distributions.logpdf(dist, value) + end evaluation_env = setindex!!(evaluation_env, value, vn) end end @@ -467,6 +475,8 @@ function AbstractPPL.evaluate!!(model::BUGSModel) if model.transformed # although the values stored in `evaluation_env` are in their original space, # here we behave as accepting a vector of parameters in the transformed space + # this is so that we have consistent logp values between + # (1) set values in original space then evaluate (2) directly evaluate with the values in transformed space value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) logp += Distributions.logpdf(dist, value) + Bijectors.logabsdetjac( From 70da021944e14db6eb091d82b354ca514d62eae7 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Fri, 29 Nov 2024 08:53:04 +0000 Subject: [PATCH 2/2] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6ef056da..523c075e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.7.1" +version = "0.7.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"