Skip to content

Commit

Permalink
update the package extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 7, 2024
1 parent 2cdb826 commit 90c6b3e
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSDynamicPPLExt = ["DynamicPPL"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]
JuliaBUGSMCMCChainsExt = ["DynamicPPL", "MCMCChains"]

[compat]
ADTypes = "1.6"
Expand Down
10 changes: 5 additions & 5 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using JuliaBUGS
using JuliaBUGS:
AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Bijectors
Expand Down Expand Up @@ -45,14 +44,15 @@ end
function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC
)
logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
)
t, s = AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
),
logdensitymodel,
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model; transformed=true), # for more advanced usage, probably save the state or transition
initial_params=JuliaBUGS.getparams(cond_model),
)
return JuliaBUGS.setparams!!(cond_model, t.z.θ; transformed=true)
end
Expand Down
12 changes: 6 additions & 6 deletions ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using AdvancedMH
using JuliaBUGS
using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Random
Expand Down Expand Up @@ -40,16 +39,17 @@ end
function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::AdvancedMH.MHSampler
)
logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
)
t, s = AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
),
logdensitymodel,
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model; transformed=true),
initial_params=JuliaBUGS.getparams(cond_model),
)
return JuliaBUGS.setparams!!(cond_model, t.params; transformed=true)
return JuliaBUGS.initialize!(cond_model, t.params)
end

end
40 changes: 40 additions & 0 deletions ext/JuliaBUGSDynamicPPLExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module JuliaBUGSDynamicPPLExt

using DynamicPPL: DynamicPPL, OrderedDict
using JuliaBUGS: JuliaBUGS, Bijectors, VarName

"""
get_params_varinfo(model::JuliaBUGS.BUGSModel[, evaluation_env::NamedTuple])
Returns a `DynamicPPL.SimpleVarInfo` object containing only the parameter values of the model.
If `evaluation_env` is provided, it will be used; otherwise, `model.evaluation_env` will be used.
"""
function get_params_varinfo(
model::JuliaBUGS.BUGSModel, evaluation_env::NT = model.evaluation_env
) where {NT <: NamedTuple}
d = OrderedDict{VarName,Any}()
for v in model.parameters
if !model.transformed
d[v] = AbstractPPL.get(evaluation_env, v)
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
dist = node_function(; args...)
d[v] = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(evaluation_env, v)
)
end
end
logp = JuliaBUGS.evaluate!!(model, JuliaBUGS.DefaultContext())[2]
return DynamicPPL.SimpleVarInfo(
d,
logp,
if model.transformed
DynamicPPL.DynamicTransformation()
else
DynamicPPL.NoTransformation()
end,
)
end

end # module
2 changes: 1 addition & 1 deletion ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, eval
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.DynamicPPL
using DynamicPPL
using AbstractMCMC
using MCMCChains: Chains

Expand Down
37 changes: 16 additions & 21 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ struct MHFromPrior <: AbstractMCMC.AbstractSampler end

abstract type AbstractGibbsState end

struct GibbsState <: AbstractGibbsState
varinfo::DynamicPPL.SimpleVarInfo
struct GibbsState{NT<:NamedTuple} <: AbstractGibbsState
values::NT
conditioning_schedule::Dict
sorted_nodes_cache::Dict
end
Expand All @@ -25,7 +25,7 @@ function AbstractMCMC.step(
model=l_model.logdensity,
kwargs...,
)
vi = deepcopy(model.varinfo)
values = deepcopy(model.evaluation_env)
sorted_nodes_cache = Dict{Any,Any}()

conditioning_schedule = Dict()
Expand All @@ -39,8 +39,8 @@ function AbstractMCMC.step(
sorted_nodes_cache[vs] = ensure_vector(cond_model.sorted_nodes)
end

return getparams(model, vi; transformed=model.transformed),
GibbsState(vi, conditioning_schedule, sorted_nodes_cache)
transition = JuliaBUGS.getparams(model)
return transition, GibbsState(values, conditioning_schedule, sorted_nodes_cache)
end

function AbstractMCMC.step(
Expand All @@ -51,34 +51,29 @@ function AbstractMCMC.step(
model=l_model.logdensity,
kwargs...,
)
vi = state.varinfo
values = state.values
for vs in keys(state.conditioning_schedule)
cond_model = AbstractPPL.condition(model, vs, vi, state.sorted_nodes_cache[vs])
vi = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
cond_model = AbstractPPL.condition(model, vs, values, state.sorted_nodes_cache[vs])
values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
end
return getparams(model, vi; transformed=model.transformed),
GibbsState(vi, state.conditioning_schedule, state.sorted_nodes_cache)
return values, GibbsState(values, state.conditioning_schedule, state.sorted_nodes_cache)
end

function gibbs_internal end

function gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::MHFromPrior
)
transformed_original = getparams(cond_model, cond_model.varinfo; transformed=true)
transformed_proposal = getparams(
cond_model, evaluate!!(cond_model, SamplingContext())[1]; transformed=true
)
function gibbs_internal(rng::Random.AbstractRNG, cond_model::BUGSModel, ::MHFromPrior)
transformed_original = JuliaBUGS.getparams(cond_model)
transformed_proposal = JuliaBUGS.getparams(cond_model)

vi_proposed, logp_proposed = evaluate!!(
values_proposed, logp_proposed = evaluate!!(
cond_model, LogDensityContext(), transformed_proposal
)
vi, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)
values, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)

if logp_proposed - logp > log(rand(rng))
vi = vi_proposed
values = values_proposed
end
return vi
return values
end

function AbstractMCMC.bundle_samples(
Expand Down

0 comments on commit 90c6b3e

Please sign in to comment.