Skip to content

Commit

Permalink
some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 7, 2024
1 parent 4fb1eab commit 91d0f0c
Showing 1 changed file with 56 additions and 45 deletions.
101 changes: 56 additions & 45 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,38 @@
abstract type AbstractBUGSModel end

"""
EvalCache{TNF,TNA,TV}
FlattenedGraphNodeData{TNF,TNA,TV}
Pre-compute the values of the nodes in the model to avoid lookups from MetaGraph.
"""
struct EvalCache{TNF,TNA,TV}
struct FlattenedGraphNodeData{TNF,TV}
sorted_nodes::Vector{<:VarName}
is_stochastic_vals::Vector{Bool}
is_observed_vals::Vector{Bool}
node_function_vals::TNF
node_args_vals::TNA
loop_vars_vals::TV
end

function EvalCache(sorted_nodes::Vector{<:VarName}, g::BUGSGraph)
function FlattenedGraphNodeData(
g::BUGSGraph,
sorted_nodes::Vector{<:VarName}=[label_for(g, node) for node in topological_sort(g)],
)
is_stochastic_vals = Array{Bool}(undef, length(sorted_nodes))
is_observed_vals = Array{Bool}(undef, length(sorted_nodes))
node_function_vals = []
node_args_vals = []
loop_vars_vals = []
node_function_vals = Array{Any}(undef, length(sorted_nodes))
loop_vars_vals = Array{Any}(undef, length(sorted_nodes))
for (i, vn) in enumerate(sorted_nodes)
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
(; is_stochastic, is_observed, node_function, loop_vars) = g[vn]
is_stochastic_vals[i] = is_stochastic
is_observed_vals[i] = is_observed
push!(node_function_vals, node_function)
push!(node_args_vals, Val(node_args))
push!(loop_vars_vals, loop_vars)
node_function_vals[i] = node_function
loop_vars_vals[i] = loop_vars
end
return EvalCache(
return FlattenedGraphNodeData(
sorted_nodes,
is_stochastic_vals,
is_observed_vals,
node_function_vals,
node_args_vals,
loop_vars_vals,
)
end
Expand All @@ -47,9 +46,8 @@ end
The `BUGSModel` object is used for inference and represents the output of compilation. It implements the
[`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface.
"""
struct BUGSModel{
base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TNA,TV
} <: AbstractBUGSModel
struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <:
AbstractBUGSModel
" Indicates whether the model parameters are in the transformed space. "
transformed::Bool

Expand All @@ -66,8 +64,8 @@ struct BUGSModel{
evaluation_env::T
"A vector containing the names of the model parameters (unobserved stochastic variables)."
parameters::Vector{<:VarName}
"An `EvalCache` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed."
eval_cache::EvalCache{TNF,TNA,TV}
"An `FlattenedGraphNodeData` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed."
flattened_graph_node_data::FlattenedGraphNodeData{TNF,TV}

"An instance of `BUGSGraph`, representing the dependency graph of the model."
g::BUGSGraph
Expand Down Expand Up @@ -116,14 +114,19 @@ function BUGSModel(
initial_params::NamedTuple=NamedTuple();
is_transformed::Bool=true,
)
sorted_nodes = VarName[label_for(g, node) for node in topological_sort(g)]
flattened_graph_node_data = FlattenedGraphNodeData(g)
parameters = VarName[]
untransformed_param_length, transformed_param_length = 0, 0
untransformed_var_lengths, transformed_var_lengths = Dict{VarName,Int}(),
Dict{VarName,Int}()

for vn in sorted_nodes
(; is_stochastic, is_observed, node_function, loop_vars) = g[vn]
for (vn, is_stochastic, is_observed, node_function, loop_vars) in zip(
flattened_graph_node_data.sorted_nodes,
flattened_graph_node_data.is_stochastic_vals,
flattened_graph_node_data.is_observed_vals,
flattened_graph_node_data.node_function_vals,
flattened_graph_node_data.loop_vars_vals,
)
if !is_stochastic
value = Base.invokelatest(node_function, evaluation_env, loop_vars)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
Expand Down Expand Up @@ -168,7 +171,7 @@ function BUGSModel(
transformed_var_lengths,
evaluation_env,
parameters,
EvalCache(sorted_nodes, g),
flattened_graph_node_data,
g,
nothing,
)
Expand All @@ -189,7 +192,7 @@ function BUGSModel(
model.transformed_var_lengths,
evaluation_env,
parameters,
EvalCache(sorted_nodes, g),
FlattenedGraphNodeData(g, sorted_nodes),
g,
isnothing(model.base_model) ? model : model.base_model,
)
Expand All @@ -202,11 +205,13 @@ Initialize the model with a NamedTuple of initial values, the values are expecte
"""
function initialize!(model::BUGSModel, initial_params::NamedTuple)
check_input(initial_params)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
is_observed = model.eval_cache.is_observed_vals[i]
node_function = model.eval_cache.node_function_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
for (vn, is_stochastic, is_observed, node_function, loop_vars) in zip(
model.flattened_graph_node_data.sorted_nodes,
model.flattened_graph_node_data.is_stochastic_vals,
model.flattened_graph_node_data.is_observed_vals,
model.flattened_graph_node_data.node_function_vals,
model.flattened_graph_node_data.loop_vars_vals,
)
if !is_stochastic
value = Base.invokelatest(node_function, model.evaluation_env, loop_vars)
BangBang.@set!! model.evaluation_env = setindex!!(
Expand Down Expand Up @@ -343,11 +348,11 @@ function AbstractPPL.condition(
new_parameters = setdiff(model.parameters, var_group)

sorted_blanket_with_vars = if sorted_nodes isa Nothing
model.eval_cache.sorted_nodes
model.flattened_graph_node_data.sorted_nodes
else
filter(
vn -> vn in union(markov_blanket(model.g, new_parameters), new_parameters),
model.eval_cache.sorted_nodes,
model.flattened_graph_node_data.sorted_nodes,
)
end

Expand Down Expand Up @@ -375,14 +380,14 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})

new_parameters = [
v for
v in base_model.eval_cache.sorted_nodes if v in union(model.parameters, var_group)
v in base_model.flattened_graph_node_data.sorted_nodes if v in union(model.parameters, var_group)
] # keep the order

markov_blanket_with_vars = union(
markov_blanket(base_model.g, new_parameters), new_parameters
)
sorted_blanket_with_vars = filter(
vn -> vn in markov_blanket_with_vars, base_model.eval_cache.sorted_nodes
vn -> vn in markov_blanket_with_vars, base_model.flattened_graph_node_data.sorted_nodes
)

new_model = BUGSModel(
Expand All @@ -405,10 +410,12 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
(; evaluation_env, g) = model
vi = deepcopy(evaluation_env)
logp = 0.0
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
node_function = model.eval_cache.node_function_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
for (vn, is_stochastic, node_function, loop_vars) in zip(
model.flattened_graph_node_data.sorted_nodes,
model.flattened_graph_node_data.is_stochastic_vals,
model.flattened_graph_node_data.node_function_vals,
model.flattened_graph_node_data.loop_vars_vals,
)
if !is_stochastic
value = node_function(model.evaluation_env, loop_vars)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -425,10 +432,12 @@ end
function AbstractPPL.evaluate!!(model::BUGSModel)
logp = 0.0
evaluation_env = deepcopy(model.evaluation_env)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
node_function = model.eval_cache.node_function_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
for (vn, is_stochastic, node_function, loop_vars) in zip(
model.flattened_graph_node_data.sorted_nodes,
model.flattened_graph_node_data.is_stochastic_vals,
model.flattened_graph_node_data.node_function_vals,
model.flattened_graph_node_data.loop_vars_vals,
)
if !is_stochastic
value = node_function(model.evaluation_env, loop_vars)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand Down Expand Up @@ -461,11 +470,13 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect
evaluation_env = deepcopy(model.evaluation_env)
current_idx = 1
logp = 0.0
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
is_observed = model.eval_cache.is_observed_vals[i]
node_function = model.eval_cache.node_function_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
for (vn, is_stochastic, is_observed, node_function, loop_vars) in zip(
model.flattened_graph_node_data.sorted_nodes,
model.flattened_graph_node_data.is_stochastic_vals,
model.flattened_graph_node_data.is_observed_vals,
model.flattened_graph_node_data.node_function_vals,
model.flattened_graph_node_data.loop_vars_vals,
)
if !is_stochastic
value = node_function(evaluation_env, loop_vars)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
Expand Down

0 comments on commit 91d0f0c

Please sign in to comment.