Skip to content

Commit

Permalink
remove dep on DynamicPPL
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 23, 2024
1 parent ff49d61 commit 9ae8182
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 48 deletions.
7 changes: 2 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[weakdeps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
Expand All @@ -38,10 +37,9 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
[extensions]
JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"]
JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSDynamicPPLExt = ["DynamicPPL"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["DynamicPPL", "MCMCChains"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]

[compat]
ADTypes = "1.6"
Expand All @@ -54,7 +52,6 @@ BangBang = "0.4.1"
Bijectors = "0.13"
Distributions = "0.23.8, 0.24, 0.25"
Documenter = "0.27, 1"
DynamicPPL = "0.25, 0.26, 0.27, 0.28, 0.29"
GLMakie = "0.10"
GraphMakie = "0.5"
GraphPlot = "0.6"
Expand Down Expand Up @@ -86,4 +83,4 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AbstractMCMC", "ADTypes", "AdvancedHMC", "AdvancedMH", "DynamicPPL", "MCMCChains", "LogDensityProblemsAD", "ReverseDiff", "Test"]
test = ["AbstractMCMC", "ADTypes", "AdvancedHMC", "AdvancedMH", "MCMCChains", "LogDensityProblemsAD", "ReverseDiff", "Test"]
40 changes: 0 additions & 40 deletions ext/JuliaBUGSDynamicPPLExt.jl

This file was deleted.

27 changes: 24 additions & 3 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using JuliaBUGS.AbstractPPL
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using DynamicPPL
using AbstractMCMC
using MCMCChains: Chains

Expand Down Expand Up @@ -50,6 +49,28 @@ function JuliaBUGS.gen_chains(
)
end

# copied from DynamicPPL
varname_leaves(vn::VarName, ::Real) = [vn]
function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(VarName(vn, Accessors.IndexLens(Tuple(I)) getoptic(vn)), val[I])
for I in CartesianIndices(val)
)
end
function varname_leaves(vn::VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = Accessors.PropertyLens{sym}()
varname_leaves(VarName(vn, optic getoptic(vn)), optic(val))
end
return Iterators.flatten(iter)
end

function JuliaBUGS.gen_chains(
model::JuliaBUGS.BUGSModel,
samples,
Expand Down Expand Up @@ -84,13 +105,13 @@ function JuliaBUGS.gen_chains(

param_name_leaves = collect(
Iterators.flatten([
collect(DynamicPPL.varname_leaves(vn, param_vals[1][i])) for
collect(varname_leaves(vn, param_vals[1][i])) for
(i, vn) in enumerate(param_vars)
],),
)
generated_varname_leaves = collect(
Iterators.flatten([
collect(DynamicPPL.varname_leaves(vn, generated_quantities[1][i])) for
collect(varname_leaves(vn, generated_quantities[1][i])) for
(i, vn) in enumerate(generated_vars)
],),
)
Expand Down
17 changes: 17 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,23 @@ function getparams(model::BUGSModel)
return param_vals
end

function getparams_as_ordereddict(model::BUGSModel)
d = OrderedDict{VarName,Any}()
for v in model.parameters
if !model.transformed
d[v] = AbstractPPL.get(model.evaluation_env, v)
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
dist = node_function(; args...)
d[v] = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
)
end
end
return d
end

"""
settrans(model::BUGSModel, bool::Bool=!(model.transformed))
Expand Down

0 comments on commit 9ae8182

Please sign in to comment.