From 2cdb826e62a9a95e0351276601020bd7d861a7e9 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 7 Oct 2024 08:44:11 +0100 Subject: [PATCH] remove dependency on DynamicPPL --- Project.toml | 8 +- docs/Project.toml | 1 - docs/make.jl | 1 - src/JuliaBUGS.jl | 12 +- src/logdensityproblems.jl | 2 +- src/model.jl | 280 +++++++++++++++----------------------- src/utils.jl | 92 ++++++++++--- test/gibbs.jl | 12 -- test/graphs.jl | 1 - test/log_density.jl | 45 ++++++ 10 files changed, 237 insertions(+), 217 deletions(-) diff --git a/Project.toml b/Project.toml index 2381d6a87..77e9913e3 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,11 @@ version = "0.6.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" JuliaSyntax = "70703baa-626e-46a2-a12c-08ffd08c73b4" @@ -36,6 +36,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" [extensions] JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"] JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"] +JuliaBUGSDynamicPPLExt = ["DynamicPPL"] JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"] JuliaBUGSGraphPlotExt = ["GraphPlot"] JuliaBUGSMCMCChainsExt = ["MCMCChains"] @@ -44,19 +45,21 @@ JuliaBUGSMCMCChainsExt = ["MCMCChains"] ADTypes = "1.6" AbstractMCMC = "5" AbstractPPL = "0.8.4" +Accessors = "0.1" AdvancedHMC = "0.6" AdvancedMH = "0.8" BangBang = "0.4.1" Bijectors = "0.13" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" -DynamicPPL = "0.25, 0.26, 0.27, 0.28" +DynamicPPL = "0.25, 0.26, 0.27, 0.28, 0.29" GLMakie = "0.10" GraphMakie = "0.5" GraphPlot = "0.6" Graphs = "1" JSON = "0.21" JuliaSyntax = "0.4" +LinearAlgebra = "1.9" LogDensityProblems = "2" LogDensityProblemsAD = "1" LogExpFunctions = "0.3" @@ -75,6 +78,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/docs/Project.toml b/docs/Project.toml index 931676183..ffff9644c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,5 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" [compat] diff --git a/docs/make.jl b/docs/make.jl index 1e6e6da1f..5831f47f4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,7 +2,6 @@ using Documenter using JuliaBUGS using MetaGraphsNext using JuliaBUGS.BUGSPrimitives -using DynamicPPL: SimpleVarInfo makedocs(; sitename="JuliaBUGS.jl", diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 97ff9476f..2a7677fa3 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -2,11 +2,12 @@ module JuliaBUGS using AbstractMCMC using AbstractPPL +using Accessors using BangBang using Bijectors: Bijectors using Distributions -using DynamicPPL: DynamicPPL using Graphs, MetaGraphsNext +using LinearAlgebra using LogDensityProblems, LogDensityProblemsAD using MacroTools using Random @@ -154,8 +155,7 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N eval_env = semantic_analysis(model_def, data) model_def = concretize_colon_indexing(model_def, eval_env) g = create_graph(model_def, eval_env) - vi = DynamicPPL.SimpleVarInfo( - NamedTuple{keys(eval_env)}( + nonmissing_eval_env = NamedTuple{keys(eval_env)}( map( v -> begin if v === missing @@ -171,10 +171,8 @@ function compile(model_def::Expr, data::NamedTuple, initial_params::NamedTuple=N end, values(eval_env), ), - ), - 0.0, - ) - return BUGSModel(g, vi, initial_params) + ) + return BUGSModel(g, nonmissing_eval_env, initial_params) end """ diff --git a/src/logdensityproblems.jl b/src/logdensityproblems.jl index 041403002..ff600f616 100644 --- a/src/logdensityproblems.jl +++ b/src/logdensityproblems.jl @@ -1,5 +1,5 @@ function LogDensityProblems.logdensity(model::AbstractBUGSModel, x::AbstractArray) - vi, logp = evaluate!!(model, LogDensityContext(), x) + _, logp = evaluate!!(model, LogDensityContext(), x) return logp end diff --git a/src/model.jl b/src/model.jl index 10de064fc..190866958 100644 --- a/src/model.jl +++ b/src/model.jl @@ -9,7 +9,8 @@ abstract type AbstractBUGSModel end The `BUGSModel` object is used for inference and represents the output of compilation. It implements the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface. """ -struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing}} <: AbstractBUGSModel +struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple} <: + AbstractBUGSModel " Indicates whether the model parameters are in the transformed space. " transformed::Bool @@ -22,11 +23,8 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing}} <: AbstractBU "A dictionary mapping the names of the variables to their lengths in the transformed (unconstrained) space." transformed_var_lengths::Dict{<:VarName,Int} - "A `DynamicPPL.SimpleVarInfo` object containing the values of the variables in the model. - Note that the usage of `SimpleVarInfo` in JuliaBUGS is different from that of DynamicPPL: - In JuliaBUGS, `varinfo` contains all the values (DynamicPPL only contains values of model parameters), - and all the values in `varinfo` are always in the constrained space." - varinfo::DynamicPPL.SimpleVarInfo + "A `NamedTuple` containing the values of the variables in the model, all the values are in the constrained space." + evaluation_env::T "A vector containing the names of the model parameters (unobserved stochastic variables)." parameters::Vector{<:VarName} "A vector containing the names of all the variables in the model, sorted in topological order." @@ -56,7 +54,7 @@ function Base.show(io::IO, m::BUGSModel) println(io, " Model parameters:") println(io, " ", join(m.parameters, ", "), "\n") println(io, " Variable values:") - return println(io, "$(m.varinfo.values)") + return println(io, "$(m.evaluation_env)") end """ @@ -74,14 +72,14 @@ Return a vector of `VarName` containing the names of all the variables in the mo variables(m::BUGSModel) = collect(labels(m.g)) function prepare_arg_values( - args::Tuple{Vararg{Symbol}}, vi::DynamicPPL.SimpleVarInfo, loop_vars::NamedTuple{lvars} + args::Tuple{Vararg{Symbol}}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars} ) where {lvars} return NamedTuple{args}(Tuple( map(args) do arg if arg in lvars loop_vars[arg] else - vi[@varname($arg)] + AbstractPPL.get(evaluation_env, @varname($arg)) end end, )) @@ -89,7 +87,7 @@ end function BUGSModel( g::BUGSGraph, - vi::DynamicPPL.SimpleVarInfo, + evaluation_env::NamedTuple, initial_params::NamedTuple=NamedTuple(); is_transformed::Bool=true, ) @@ -101,10 +99,10 @@ function BUGSModel( for vn in sorted_nodes (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, vi, loop_vars) + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = Base.invokelatest(node_function; args...) - vi = DynamicPPL.BangBang.setindex!!(vi, value, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) elseif !is_observed push!(parameters, vn) dist = Base.invokelatest(node_function; args...) @@ -125,7 +123,7 @@ function BUGSModel( missing end if !ismissing(initialization) - vi = DynamicPPL.BangBang.setindex!!(vi, initialization, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, initialization, vn) else init_value = try rand(dist) @@ -134,7 +132,7 @@ function BUGSModel( "Failed to sample from the prior distribution of $vn, consider providing initialization values for $vn or it's parents: $(collect(MetaGraphsNext.inneighbor_labels(g, vn))...).", ) end - vi = DynamicPPL.BangBang.setindex!!(vi, init_value, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, init_value, vn) end end end @@ -144,7 +142,7 @@ function BUGSModel( transformed_param_length, untransformed_var_lengths, transformed_var_lengths, - vi, + evaluation_env, parameters, sorted_nodes, g, @@ -152,6 +150,26 @@ function BUGSModel( ) end +function BUGSModel( + model::BUGSModel, + parameters::Vector{<:VarName}, + sorted_nodes::Vector{<:VarName}, + evaluation_env::NamedTuple=model.evaluation_env, +) + return BUGSModel( + model.transformed, + sum(model.untransformed_var_lengths[v] for v in parameters), + sum(model.transformed_var_lengths[v] for v in parameters), + model.untransformed_var_lengths, + model.transformed_var_lengths, + evaluation_env, + parameters, + sorted_nodes, + model.g, + model.base_model, + ) +end + """ initialize!(model::BUGSModel, initial_params::NamedTuple) @@ -161,10 +179,12 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) for vn in model.sorted_nodes (; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn] - args = prepare_arg_values(node_args, model.varinfo, loop_vars) + args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) if !is_stochastic value = Base.invokelatest(node_function; args...) - BangBang.@set!! model.varinfo = setindex!!(model.varinfo, value, vn) + BangBang.@set!! model.evaluation_env = setindex!!( + model.evaluation_env, value, vn + ) elseif !is_observed initialization = try AbstractPPL.get(initial_params, vn) @@ -172,12 +192,14 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple) missing end if !ismissing(initialization) - BangBang.@set!! model.varinfo = setindex!!( - model.varinfo, initialization, vn + BangBang.@set!! model.evaluation_env = setindex!!( + model.evaluation_env, initialization, vn ) else - BangBang.@set!! model.varinfo = setindex!!( - model.varinfo, rand(Base.invokelatest(node_function; args...)), vn + BangBang.@set!! model.evaluation_env = setindex!!( + model.evaluation_env, + rand(Base.invokelatest(node_function; args...)), + vn, ) end end @@ -191,71 +213,41 @@ end Initialize the model with a vector of initial values, the values can be in transformed space if `model.transformed` is set to true. """ function initialize!(model::BUGSModel, initial_params::AbstractVector) - vi, logp = AbstractPPL.evaluate!!(model, LogDensityContext(), initial_params) + evaluation_env, logp = AbstractPPL.evaluate!!( + model, LogDensityContext(), initial_params + ) return BUGSModel( model.transformed, model.untransformed_param_length, model.transformed_param_length, model.untransformed_var_lengths, model.transformed_var_lengths, - DynamicPPL.setlogp!!(vi, logp), + evaluation_env, model.parameters, model.sorted_nodes, model.g, - model.base_model, + isnothing(model.base_model) ? model : model.base_model, ) end """ - get_params_varinfo(model::BUGSModel[, vi::DynamicPPL.SimpleVarInfo]) + getparams(model::BUGSModel) -Returns a `DynamicPPL.SimpleVarInfo` object containing only the parameter values of the model. -If `vi` is provided, it will be used; otherwise, `model.varinfo` will be used. +Extract the parameter values from the model as a flattened vector, in an order consistent with +the what `LogDensityProblems.logdensity` expects. """ -function get_params_varinfo(model::BUGSModel) - return get_params_varinfo(model, model.varinfo) -end -function get_params_varinfo(model::BUGSModel, vi::DynamicPPL.SimpleVarInfo) - if !model.transformed - d = Dict{VarName,Any}() - for param in model.parameters - d[param] = vi[param] - end - return DynamicPPL.SimpleVarInfo(d, vi.logp, DynamicPPL.NoTransformation()) +function getparams(model::BUGSModel) + param_length = if model.transformed + model.transformed_param_length else - d = Dict{VarName,Any}() - g = model.g - for v in model.sorted_nodes - (; is_stochastic, node_function, node_args, loop_vars) = g[v] - if v in model.parameters - args = prepare_arg_values(node_args, vi, loop_vars) - dist = node_function(; args...) - linked_val = DynamicPPL.link(dist, vi[v]) - d[v] = linked_val - end - end - return DynamicPPL.SimpleVarInfo(d, vi.logp, DynamicPPL.DynamicTransformation()) + model.untransformed_param_length end -end -""" - getparams(model::BUGSModel[, vi::DynamicPPL.SimpleVarInfo]; transformed::Bool=false) - -Extract the parameter values from the model as a flattened vector, ordered topologically. -If `transformed` is set to true, the parameters are provided in the transformed space. -""" -function getparams(model::BUGSModel; transformed::Bool=false) - return getparams(model, model.varinfo; transformed=transformed) -end -function getparams(model::BUGSModel, vi::DynamicPPL.SimpleVarInfo; transformed::Bool=false) - param_vals = Vector{Float64}( - undef, - transformed ? model.transformed_param_length : model.untransformed_param_length, - ) + param_vals = Vector{Float64}(undef, param_length) pos = 1 for v in model.parameters - if !transformed - val = vi[v] + if !model.transformed + val = AbstractPPL.get(model.evaluation_env, v) len = model.untransformed_var_lengths[v] if val isa AbstractArray param_vals[pos:(pos + len - 1)] .= vec(val) @@ -264,14 +256,16 @@ function getparams(model::BUGSModel, vi::DynamicPPL.SimpleVarInfo; transformed:: end else (; node_function, node_args, loop_vars) = model.g[v] - args = prepare_arg_values(node_args, vi, loop_vars) + args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) dist = node_function(; args...) - linked_val = Bijectors.link(dist, vi[v]) + transformed_value = Bijectors.transform( + Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v) + ) len = model.transformed_var_lengths[v] - if linked_val isa AbstractArray - param_vals[pos:(pos + len - 1)] .= vec(linked_val) + if transformed_value isa AbstractArray + param_vals[pos:(pos + len - 1)] .= vec(transformed_value) else - param_vals[pos] = linked_val + param_vals[pos] = transformed_value end end pos += len @@ -280,54 +274,13 @@ function getparams(model::BUGSModel, vi::DynamicPPL.SimpleVarInfo; transformed:: end """ - setparams!!(model::BUGSModel, flattened_values::AbstractVector; transformed::Bool=false) - -Update the parameter values of a `BUGSModel` with new values provided in a flattened vector. - -Only the parameter values are updated, the values of logical variables are kept unchanged. + settrans(model::BUGSModel, bool::Bool=!(model.transformed)) -This function adopts the `BangBang` convention, i.e. it modifies the model in place when possible. +The `BUGSModel` contains information for evaluation in both transformed and untransformed spaces. The `transformed` field +indicates the current "mode" of the model. -# Arguments -- `m::BUGSModel`: The model to update. -- `flattened_values::AbstractVector`: A vector containing the new parameter values in a flattened form. -- `transformed::Bool=false`: Indicates whether the values in `flattened_values` are in the transformed space. - -# Returns -`DynamicPPL.SimpleVarInfo`: The updated `varinfo` with the new parameter values set. +This function enables switching the "mode" of the model. """ -function setparams!!( - model::BUGSModel, flattened_values::AbstractVector; transformed::Bool=false -) - pos = 1 - vi = model.varinfo - for v in model.parameters - (; node_function, node_args, loop_vars) = model.g[v] - args = prepare_arg_values(node_args, vi, loop_vars) - dist = node_function(; args...) - - len = if transformed - model.transformed_var_lengths[v] - else - model.untransformed_var_lengths[v] - end - if transformed - linked_vals = flattened_values[pos:(pos + len - 1)] - sample_val = DynamicPPL.invlink_and_reconstruct(dist, linked_vals) - else - sample_val = flattened_values[pos:(pos + len - 1)] - end - vi = DynamicPPL.setindex!!(vi, sample_val, v) - pos += len - end - return vi -end - -function (model::BUGSModel)() - vi, _ = evaluate!!(model, SamplingContext()) - return get_params_varinfo(model, vi) -end - function settrans(model::BUGSModel, bool::Bool=!(model.transformed)) return BangBang.setproperty!!(model, :transformed, bool) end @@ -337,15 +290,19 @@ function AbstractPPL.condition( d::Dict{<:VarName,<:Any}, sorted_nodes=Nothing, # support cached sorted Markov blanket nodes ) + new_evaluation_env = deepcopy(model.evaluation_env) + for (p, value) in d + new_evaluation_env = setindex!!(new_evaluation_env, value, p) + end return AbstractPPL.condition( - model, collect(keys(d)), update_varinfo(model.varinfo, d); sorted_nodes=sorted_nodes + model, collect(keys(d)), new_evaluation_env; sorted_nodes=sorted_nodes ) end function AbstractPPL.condition( model::BUGSModel, var_group::Vector{<:VarName}, - varinfo::DynamicPPL.SimpleVarInfo=model.varinfo, + evaluation_env::NamedTuple=model.evaluation_env, sorted_nodes=Nothing, ) check_var_group(var_group, model) @@ -362,16 +319,7 @@ function AbstractPPL.condition( end return BUGSModel( - model.transformed, - sum(model.untransformed_var_lengths[v] for v in new_parameters), - sum(model.transformed_var_lengths[v] for v in new_parameters), - model.untransformed_var_lengths, - model.transformed_var_lengths, - varinfo, - new_parameters, - sorted_blanket_with_vars, - model.g, - base_model, + model, new_parameters, sorted_blanket_with_vars, model.g, evaluation_env ) end @@ -386,16 +334,7 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) vn -> vn in union(markov_blanket(model.g, new_parameters)), base_model.sorted_nodes ) return BUGSModel( - model.transformed, - sum(model.untransformed_var_lengths[v] for v in new_parameters), - sum(model.transformed_var_lengths[v] for v in new_parameters), - model.untransformed_var_lengths, - model.transformed_var_lengths, - model.varinfo, - new_parameters, - sorted_blanket_with_vars, - model.g, - base_model, + model, new_parameters, sorted_blanket_with_vars, model.g, model.evaluation_env ) end @@ -408,14 +347,6 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) ) end -function update_varinfo(varinfo::DynamicPPL.SimpleVarInfo, d::Dict{VarName,<:Any}) - new_varinfo = deepcopy(varinfo) - for (p, value) in d - setindex!!(new_varinfo, value, p) - end - return new_varinfo -end - """ DefaultContext @@ -428,10 +359,9 @@ struct DefaultContext <: AbstractPPL.AbstractContext end Do an ancestral sampling of the model parameters. Also accumulate log joint density. """ -struct SamplingContext <: AbstractPPL.AbstractContext - rng::Random.AbstractRNG +@kwdef struct SamplingContext{T<:Random.AbstractRNG} <: AbstractPPL.AbstractContext + rng::T = Random.default_rng() end -SamplingContext() = SamplingContext(Random.default_rng()) """ LogDensityContext @@ -444,45 +374,44 @@ function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG) return evaluate!!(model, SamplingContext(rng)) end function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) - (; varinfo, g, sorted_nodes) = model - vi = deepcopy(varinfo) + (; evaluation_env, g, sorted_nodes) = model + vi = deepcopy(evaluation_env) logp = 0.0 for vn in sorted_nodes (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, vi, loop_vars) + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) - vi = setindex!!(vi, value, vn) + evaluation_env = setindex!!(evaluation_env, value, vn) else dist = node_function(; args...) value = rand(ctx.rng, dist) # just sample from the prior logp += logpdf(dist, value) - vi = setindex!!(vi, value, vn) + evaluation_env = setindex!!(evaluation_env, value, vn) end end - return vi, logp + return evaluation_env, logp end function AbstractPPL.evaluate!!(model::BUGSModel) return AbstractPPL.evaluate!!(model, DefaultContext()) end function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) - (; sorted_nodes, g, varinfo) = model - vi = deepcopy(varinfo) + (; sorted_nodes, g, evaluation_env) = model + vi = deepcopy(evaluation_env) logp = 0.0 for vn in sorted_nodes (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, vi, loop_vars) + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) - vi = setindex!!(vi, value, vn) + evaluation_env = setindex!!(evaluation_env, value, vn) else dist = node_function(; args...) - value = vi[vn] + value = AbstractPPL.get(evaluation_env, vn) if model.transformed - # although the values stored in `vi` are in their original space, - # when `DynamicTransformation`, we behave as accepting a vector of - # parameters in the transformed space + # although the values stored in `evaluation_env` are in their original space, + # here we behave as accepting a vector of parameters in the transformed space value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) logp += Distributions.logpdf(dist, value) + Bijectors.logabsdetjac( @@ -493,7 +422,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) end end end - return vi, logp + return evaluation_env, logp end function AbstractPPL.evaluate!!( @@ -519,38 +448,43 @@ function AbstractPPL.evaluate!!( sorted_nodes = model.sorted_nodes g = model.g - vi = deepcopy(model.varinfo) + evaluation_env = deepcopy(model.evaluation_env) current_idx = 1 logp = 0.0 for vn in sorted_nodes (; is_stochastic, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(node_args, vi, loop_vars) + args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) - vi = setindex!!(vi, value, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) else dist = node_function(; args...) if vn in model.parameters l = var_lengths[vn] if model.transformed - value, logjac = DynamicPPL.with_logabsdet_jacobian_and_reconstruct( - Bijectors.inverse(Bijectors.bijector(dist)), + b = Bijectors.bijector(dist) + b_inv = Bijectors.inverse(b) + reconstructed_value = reconstruct( + b_inv, dist, flattened_values[current_idx:(current_idx + l - 1)], ) + value, logjac = Bijectors.with_logabsdet_jacobian( + b_inv, reconstructed_value + ) else - value = DynamicPPL.reconstruct( + value = reconstruct( dist, flattened_values[current_idx:(current_idx + l - 1)] ) logjac = 0.0 end current_idx += l logp += logpdf(dist, value) + logjac - vi = setindex!!(vi, value, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) else - logp += logpdf(dist, vi[vn]) + logp += logpdf(dist, AbstractPPL.get(evaluation_env, vn)) end end end - return vi, logp + return evaluation_env, logp end diff --git a/src/utils.jl b/src/utils.jl index e2a463a8b..cbe4c248c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -537,25 +537,79 @@ end end # module -# TODO: can't remove even with the `possible` fix in DynamicPPL, still seems to have eltype inference issue causing AD errors -# Resolves: setindex!!([1 2; 3 4], [2 3; 4 5], 1:2, 1:2) # returns 2×2 Matrix{Any} -# Alternatively, can overload BangBang.possible( -# ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg -# ) -# to allow mutation, but the current solution seems create less possible problems, albeit less efficient. -function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...) - T = promote_type(eltype(xs), eltype(v)) - ys = similar(xs, T) - if eltype(xs) !== Union{} - copy!(ys, xs) - end - ys[I...] = v - return ys +# # TODO: can't remove even with the `possible` fix in DynamicPPL, still seems to have eltype inference issue causing AD errors +# # Resolves: setindex!!([1 2; 3 4], [2 3; 4 5], 1:2, 1:2) # returns 2×2 Matrix{Any} +# # Alternatively, can overload BangBang.possible( +# # ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg +# # ) +# # to allow mutation, but the current solution seems create less possible problems, albeit less efficient. +# function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...) +# T = promote_type(eltype(xs), eltype(v)) +# ys = similar(xs, T) +# if eltype(xs) !== Union{} +# copy!(ys, xs) +# end +# ys[I...] = v +# return ys +# end + +function BangBang.setindex!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} + optic = BangBang.prefermutation( + AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() + ) + return Accessors.set(nt, optic, val) end -# this will prefer to mutate `nt` in-place -function BangBang.setindex!!(nt::NamedTuple, value, vn::VarName{sym}) where {sym} - return Accessors.set( - nt, BangBang.prefermutation(PropertyLens{sym}() ⨟ getoptic(vn)), value - ) +""" + reconstruct([f, ]dist, val) + +Reconstruct `val` so that it's compatible with `dist`. + +If `f` is also provided, the reconstruct value will be +such that `f(reconstruct_val)` is compatible with `dist`. +""" +reconstruct(f, dist, val) = reconstruct(dist, val) + +# No-op versions. +reconstruct(::UnivariateDistribution, val::Real) = val +reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) +reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) +function reconstruct( + ::Distribution{ArrayLikeVariate{N}}, val::AbstractArray{<:Real,N} +) where {N} + return copy(val) end + +function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real}) + return reconstruct(dist, Matrix(reshape(val, size(dist)))) +end +reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real}) = Cholesky(val, dist.uplo, 0) +reconstruct(::LKJCholesky, val::Cholesky) = val + +# NOTE: Necessary to handle product distributions of `Dirichlet` and similar. +function reconstruct( + ::Bijectors.Inverse{<:Bijectors.SimplexBijector}, dist, val::AbstractVector +) + (d, ns...) = size(dist) + return reshape(val, d - 1, ns...) +end +function reconstruct( + ::Bijectors.Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector +) + return copy(val) +end +function reconstruct( + ::Bijectors.Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector +) + return copy(val) +end +function reconstruct( + ::Bijectors.Inverse{Bijectors.PDVecBijector}, ::MatrixDistribution, val::AbstractVector +) + return copy(val) +end + +reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) +reconstruct(::Tuple{}, val::AbstractVector) = val[1] +reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) +reconstruct(s::NTuple{2}, val::AbstractVector) = reshape(copy(val), s) diff --git a/test/gibbs.jl b/test/gibbs.jl index 99ad72810..d07428604 100644 --- a/test/gibbs.jl +++ b/test/gibbs.jl @@ -1,5 +1,4 @@ using JuliaBUGS: MHFromPrior, Gibbs -using DynamicPPL: DynamicPPL @testset "Simple gibbs" begin model_def = @bugs begin @@ -29,17 +28,6 @@ using DynamicPPL: DynamicPPL ) model = compile(model_def, data, (;)) - # use NamedTuple for SimpleVarinfo - model = JuliaBUGS.BangBang.setproperty!!( - model, - :varinfo, - begin - vi = model.varinfo - DynamicPPL.SimpleVarInfo( - DynamicPPL.values_as(vi, NamedTuple), vi.logp, vi.transformation - ) - end, - ) # single step p_s, st_init = AbstractMCMC.step( diff --git a/test/graphs.jl b/test/graphs.jl index e99d77081..7762e4c0e 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -14,7 +14,6 @@ test_model = @bugs begin l ~ dnorm(0, 1) end -# construct a SimpleVarInfo inits = ( a=1.0, b=2.0, diff --git a/test/log_density.jl b/test/log_density.jl index 9808208e0..a64617801 100644 --- a/test/log_density.jl +++ b/test/log_density.jl @@ -25,6 +25,11 @@ end # the bijector of dbin is the identity, so the log density should be the same @test _logjoint(untransformed_model) ≈ reference_logp_untransformed rtol = 1E-6 @test _logjoint(transformed_model) ≈ reference_logp_transformed rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + reference_logp_transformed rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + reference_logp_untransformed rtol = 1E-6 end @testset "dgamma (Gamma)" begin @@ -46,6 +51,11 @@ end @test _logjoint(untransformed_model) ≈ reference_logp_untransformed rtol = 1E-6 @test _logjoint(transformed_model) ≈ reference_logp_transformed rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + reference_logp_transformed rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + reference_logp_untransformed rtol = 1E-6 end @testset "ddirich (Dirichlet)" begin @@ -69,6 +79,11 @@ end @test _logjoint(untransformed_model) ≈ reference_logp_untransformed rtol = 1E-6 @test _logjoint(transformed_model) ≈ reference_logp_transformed rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + reference_logp_transformed rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + reference_logp_untransformed rtol = 1E-6 end @testset "dwish (Wishart)" begin @@ -99,6 +114,11 @@ end @test _logjoint(untransformed_model) ≈ reference_logp_untransformed rtol = 1E-6 @test _logjoint(transformed_model) ≈ reference_logp_transformed rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + reference_logp_transformed rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + reference_logp_untransformed rtol = 1E-6 end @testset "lkj (LKJ)" begin @@ -123,6 +143,11 @@ end @test _logjoint(untransformed_model) ≈ reference_logp_untransformed rtol = 1E-6 @test _logjoint(transformed_model) ≈ reference_logp_transformed rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + reference_logp_transformed rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + reference_logp_untransformed rtol = 1E-6 end end end @@ -134,6 +159,11 @@ end untransformed_model = JuliaBUGS.settrans(transformed_model, false) @test _logjoint(untransformed_model) ≈ -174029.38703951868 rtol = 1E-6 @test _logjoint(transformed_model) ≈ -174029.38703951868 rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + -174029.38703951868 rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + -174029.38703951868 rtol = 1E-6 end @testset "blockers" begin @@ -142,6 +172,11 @@ end untransformed_model = JuliaBUGS.settrans(transformed_model, false) @test _logjoint(untransformed_model) ≈ -8418.416388326123 rtol = 1E-6 @test _logjoint(transformed_model) ≈ -8418.416388326123 rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + -8418.416388326123 rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + -8418.416388326123 rtol = 1E-6 end @testset "bones" begin @@ -150,6 +185,11 @@ end untransformed_model = JuliaBUGS.settrans(transformed_model, false) @test _logjoint(untransformed_model) ≈ -161.6492002285034 rtol = 1E-6 @test _logjoint(transformed_model) ≈ -161.6492002285034 rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + -161.6492002285034 rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + -161.6492002285034 rtol = 1E-6 end @testset "dogs" begin @@ -158,6 +198,11 @@ end untransformed_model = JuliaBUGS.settrans(transformed_model, false) @test _logjoint(untransformed_model) ≈ -1243.188922285352 rtol = 1E-6 @test _logjoint(transformed_model) ≈ -1243.3996613167667 rtol = 1E-6 + + @test LogDensityProblems.logdensity(transformed_model, JuliaBUGS.getparams(transformed_model)) ≈ + -1243.3996613167667 rtol = 1E-6 + @test LogDensityProblems.logdensity(untransformed_model, JuliaBUGS.getparams(untransformed_model)) ≈ + -1243.188922285352 rtol = 1E-6 end end