diff --git a/src/model.jl b/src/model.jl index 9f634fd4e..223494007 100644 --- a/src/model.jl +++ b/src/model.jl @@ -4,39 +4,38 @@ abstract type AbstractBUGSModel end """ - EvalCache{TNF,TNA,TV} + FlattenedGraphNodeData{TNF,TNA,TV} Pre-compute the values of the nodes in the model to avoid lookups from MetaGraph. """ -struct EvalCache{TNF,TNA,TV} +struct FlattenedGraphNodeData{TNF,TV} sorted_nodes::Vector{<:VarName} is_stochastic_vals::Vector{Bool} is_observed_vals::Vector{Bool} node_function_vals::TNF - node_args_vals::TNA loop_vars_vals::TV end -function EvalCache(sorted_nodes::Vector{<:VarName}, g::BUGSGraph) +function FlattenedGraphNodeData( + g::BUGSGraph, + sorted_nodes::Vector{<:VarName}=[label_for(g, node) for node in topological_sort(g)], +) is_stochastic_vals = Array{Bool}(undef, length(sorted_nodes)) is_observed_vals = Array{Bool}(undef, length(sorted_nodes)) - node_function_vals = [] - node_args_vals = [] - loop_vars_vals = [] + node_function_vals = Array{Any}(undef, length(sorted_nodes)) + loop_vars_vals = Array{Any}(undef, length(sorted_nodes)) for (i, vn) in enumerate(sorted_nodes) - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] + (; is_stochastic, is_observed, node_function, loop_vars) = g[vn] is_stochastic_vals[i] = is_stochastic is_observed_vals[i] = is_observed - push!(node_function_vals, node_function) - push!(node_args_vals, Val(node_args)) - push!(loop_vars_vals, loop_vars) + node_function_vals[i] = node_function + loop_vars_vals[i] = loop_vars end - return EvalCache( + return FlattenedGraphNodeData( sorted_nodes, is_stochastic_vals, is_observed_vals, node_function_vals, - node_args_vals, loop_vars_vals, ) end @@ -47,9 +46,8 @@ end The `BUGSModel` object is used for inference and represents the output of compilation. It implements the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface. """ -struct BUGSModel{ - base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TNA,TV -} <: AbstractBUGSModel +struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <: + AbstractBUGSModel " Indicates whether the model parameters are in the transformed space. " transformed::Bool @@ -66,8 +64,8 @@ struct BUGSModel{ evaluation_env::T "A vector containing the names of the model parameters (unobserved stochastic variables)." parameters::Vector{<:VarName} - "An `EvalCache` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed." - eval_cache::EvalCache{TNF,TNA,TV} + "An `FlattenedGraphNodeData` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed." + flattened_graph_node_data::FlattenedGraphNodeData{TNF,TV} "An instance of `BUGSGraph`, representing the dependency graph of the model." g::BUGSGraph @@ -116,14 +114,19 @@ function BUGSModel( initial_params::NamedTuple=NamedTuple(); is_transformed::Bool=true, ) - sorted_nodes = VarName[label_for(g, node) for node in topological_sort(g)] + flattened_graph_node_data = FlattenedGraphNodeData(g) parameters = VarName[] untransformed_param_length, transformed_param_length = 0, 0 untransformed_var_lengths, transformed_var_lengths = Dict{VarName,Int}(), Dict{VarName,Int}() - for vn in sorted_nodes - (; is_stochastic, is_observed, node_function, loop_vars) = g[vn] + for (vn, is_stochastic, is_observed, node_function, loop_vars) in zip( + flattened_graph_node_data.sorted_nodes, + flattened_graph_node_data.is_stochastic_vals, + flattened_graph_node_data.is_observed_vals, + flattened_graph_node_data.node_function_vals, + flattened_graph_node_data.loop_vars_vals, + ) if !is_stochastic value = Base.invokelatest(node_function, evaluation_env, loop_vars) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) @@ -168,7 +171,7 @@ function BUGSModel( transformed_var_lengths, evaluation_env, parameters, - EvalCache(sorted_nodes, g), + flattened_graph_node_data, g, nothing, ) @@ -189,7 +192,7 @@ function BUGSModel( model.transformed_var_lengths, evaluation_env, parameters, - EvalCache(sorted_nodes, g), + FlattenedGraphNodeData(g, sorted_nodes), g, isnothing(model.base_model) ? model : model.base_model, ) @@ -202,11 +205,13 @@ Initialize the model with a NamedTuple of initial values, the values are expecte """ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - is_observed = model.eval_cache.is_observed_vals[i] - node_function = model.eval_cache.node_function_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] + for (vn, is_stochastic, is_observed, node_function, loop_vars) in zip( + model.flattened_graph_node_data.sorted_nodes, + model.flattened_graph_node_data.is_stochastic_vals, + model.flattened_graph_node_data.is_observed_vals, + model.flattened_graph_node_data.node_function_vals, + model.flattened_graph_node_data.loop_vars_vals, + ) if !is_stochastic value = Base.invokelatest(node_function, model.evaluation_env, loop_vars) BangBang.@set!! model.evaluation_env = setindex!!( @@ -343,11 +348,11 @@ function AbstractPPL.condition( new_parameters = setdiff(model.parameters, var_group) sorted_blanket_with_vars = if sorted_nodes isa Nothing - model.eval_cache.sorted_nodes + model.flattened_graph_node_data.sorted_nodes else filter( vn -> vn in union(markov_blanket(model.g, new_parameters), new_parameters), - model.eval_cache.sorted_nodes, + model.flattened_graph_node_data.sorted_nodes, ) end @@ -375,14 +380,14 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) new_parameters = [ v for - v in base_model.eval_cache.sorted_nodes if v in union(model.parameters, var_group) + v in base_model.flattened_graph_node_data.sorted_nodes if v in union(model.parameters, var_group) ] # keep the order markov_blanket_with_vars = union( markov_blanket(base_model.g, new_parameters), new_parameters ) sorted_blanket_with_vars = filter( - vn -> vn in markov_blanket_with_vars, base_model.eval_cache.sorted_nodes + vn -> vn in markov_blanket_with_vars, base_model.flattened_graph_node_data.sorted_nodes ) new_model = BUGSModel( @@ -405,10 +410,12 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) (; evaluation_env, g) = model vi = deepcopy(evaluation_env) logp = 0.0 - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - node_function = model.eval_cache.node_function_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] + for (vn, is_stochastic, node_function, loop_vars) in zip( + model.flattened_graph_node_data.sorted_nodes, + model.flattened_graph_node_data.is_stochastic_vals, + model.flattened_graph_node_data.node_function_vals, + model.flattened_graph_node_data.loop_vars_vals, + ) if !is_stochastic value = node_function(model.evaluation_env, loop_vars) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -425,10 +432,12 @@ end function AbstractPPL.evaluate!!(model::BUGSModel) logp = 0.0 evaluation_env = deepcopy(model.evaluation_env) - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - node_function = model.eval_cache.node_function_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] + for (vn, is_stochastic, node_function, loop_vars) in zip( + model.flattened_graph_node_data.sorted_nodes, + model.flattened_graph_node_data.is_stochastic_vals, + model.flattened_graph_node_data.node_function_vals, + model.flattened_graph_node_data.loop_vars_vals, + ) if !is_stochastic value = node_function(model.evaluation_env, loop_vars) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -461,11 +470,13 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect evaluation_env = deepcopy(model.evaluation_env) current_idx = 1 logp = 0.0 - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - is_observed = model.eval_cache.is_observed_vals[i] - node_function = model.eval_cache.node_function_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] + for (vn, is_stochastic, is_observed, node_function, loop_vars) in zip( + model.flattened_graph_node_data.sorted_nodes, + model.flattened_graph_node_data.is_stochastic_vals, + model.flattened_graph_node_data.is_observed_vals, + model.flattened_graph_node_data.node_function_vals, + model.flattened_graph_node_data.loop_vars_vals, + ) if !is_stochastic value = node_function(evaluation_env, loop_vars) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)