diff --git a/Project.toml b/Project.toml index 6ef056da..523c075e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.7.1" +version = "0.7.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/model.jl b/src/model.jl index 06f39ede..2607eaa9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -431,9 +431,8 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) end function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) - (; evaluation_env, g) = model - vi = deepcopy(evaluation_env) logp = 0.0 + evaluation_env = deepcopy(model.evaluation_env) for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] node_function = model.flattened_graph_node_data.node_function_vals[i] @@ -444,7 +443,16 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) else dist = node_function(model.evaluation_env, loop_vars) value = rand(rng, dist) # just sample from the prior - logp += logpdf(dist, value) + if model.transformed + # see below for why we need to transform the value + value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) + logp += + Distributions.logpdf(dist, value) + Bijectors.logabsdetjac( + Bijectors.inverse(Bijectors.bijector(dist)), value_transformed + ) + else + logp += Distributions.logpdf(dist, value) + end evaluation_env = setindex!!(evaluation_env, value, vn) end end @@ -467,6 +475,8 @@ function AbstractPPL.evaluate!!(model::BUGSModel) if model.transformed # although the values stored in `evaluation_env` are in their original space, # here we behave as accepting a vector of parameters in the transformed space + # this is so that we have consistent logp values between + # (1) set values in original space then evaluate (2) directly evaluate with the values in transformed space value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) logp += Distributions.logpdf(dist, value) + Bijectors.logabsdetjac(