From 556c5be223745c7d3784bd94ae18f15cfffa0d29 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Tue, 12 Mar 2024 08:52:07 +0000 Subject: [PATCH] Fix `cumulative` and `density` function (#153) * `deviance` function is not supported and user will be warned when using it. * Compare to `multibugs`: multibugs doesn't seem to support `cumulative` and `density`, but provide individual cdf and pdf functions. We can also stop supporting these two functions, and direct user to use `cdf` and `pdf`. * For now, a new member `distributions` is added to `BUGSModel` to track all the distributions: using Dict right now because non-scalar rv is problematic if we store distributions in the same way we store values. `cumulative` and `density` are handled in special way in `_eval` function to allow access to distribution info at run time. Maybe this can be generalized, the distribution can be helpful introspect info Fix: https://github.com/TuringLang/JuliaBUGS.jl/issues/137 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler_pass.jl | 137 +++++++++++++++++++++++++++---------- src/model.jl | 30 ++++---- src/parser/bugs_macro.jl | 72 +------------------ src/utils.jl | 58 ++++++++++++---- test/cumulative_density.jl | 39 +++++++++++ test/runtests.jl | 2 + 6 files changed, 208 insertions(+), 130 deletions(-) create mode 100644 test/cumulative_density.jl diff --git a/src/compiler_pass.jl b/src/compiler_pass.jl index dace5b474..a6a3ac3c6 100644 --- a/src/compiler_pass.jl +++ b/src/compiler_pass.jl @@ -631,27 +631,65 @@ function evaluate_and_track_dependencies(var::Expr, env) push!(args, (var.args[1], ())) push!(deps, (var.args[1], ())) return Expr(var.head, var.args[1], idxs...), deps, args - else # function call - fun_args = [] - for i in 2:length(var.args) - e, d, a = evaluate_and_track_dependencies(var.args[i], env) - push!(fun_args, e) + elseif Meta.isexpr(var, :call) + if var.args[1] === :cumulative || var.args[1] === :density + arg1, arg2 = var.args[2:3] + arg1 = if arg1 isa Symbol + push!(deps, arg1) + # no need to add to arg, as the value doesn't matter + arg1 + elseif Meta.isexpr(arg1, :ref) + v, indices... = arg1.args + for i in eachindex(indices) + e, d, a = evaluate_and_track_dependencies(indices[i], env) + union!(deps, d) + union!(args, a) + indices[i] = e + end + if any(!is_resolved, indices) + error( + "For now, the indices of the first argument to `cumulative` and `density` must be resolved, got $indices", + ) + end + push!(deps, (v, Tuple(indices))) + # no need to add to arg, as the value doesn't matter + Expr(:ref, v, indices...) + else + error( + "First argument to `cumulative` and `density` must be variable, got $(arg1)", + ) + end + e, d, a = evaluate_and_track_dependencies(arg2, env) union!(deps, d) union!(args, a) - end + return Expr(:call, var.args[1], arg1, e), deps, args + else + fun_args = [] + for i in 2:length(var.args) + e, d, a = evaluate_and_track_dependencies(var.args[i], env) + push!(fun_args, e) + union!(deps, d) + union!(args, a) + end - for a in fun_args - a isa Symbol && a != :nothing && a != :(:) && (push!(deps, a); push!(args, a)) - end + for a in fun_args + a isa Symbol && + a != :nothing && + a != :(:) && + (push!(deps, a); push!(args, a)) + end - if ( - var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS || - var.args[1] ∈ (:+, :-, :*, :/, :^, :(:)) - ) && all(is_resolved, args) - return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args - else - return Expr(var.head, var.args[1], fun_args...), deps, args + if ( + var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS || + var.args[1] ∈ (:+, :-, :*, :/, :^, :(:)) + ) && all(is_resolved, args) + return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args + else + return Expr(var.head, var.args[1], fun_args...), deps, args + end end + else + error("Unexpected expression type: $var") end end @@ -691,36 +729,63 @@ end _replace_constants_in_expr(x::Number, env) = x function _replace_constants_in_expr(x::Symbol, env) - if haskey(env, x) - if env[x] isa Number # only plug in scalar variables - return env[x] - else # if it's an array, raise error because array indexing should be explicit - error("$x") - end + if haskey(env, x) && env[x] isa Number + return env[x] end return x end function _replace_constants_in_expr(x::Expr, env) - if Meta.isexpr(x, :ref) && all(x -> x isa Number, x.args[2:end]) - if haskey(env, x.args[1]) - val = env[x.args[1]][try_cast_to_int.(x.args[2:end])...] + if Meta.isexpr(x, :ref) + v, indices... = x.args + if haskey(env, v) && all(x -> x isa Union{Int,Float64}, indices) + val = env[v][map(Int, indices)...] return ismissing(val) ? x : val + else + for i in eachindex(indices) + indices[i] = _replace_constants_in_expr(indices[i], env) + end + return Expr(:ref, v, indices...) end - else # don't try to eval the function, but try to simplify - x = deepcopy(x) # because we are mutating the args - for i in 2:length(x.args) - try - x.args[i] = _replace_constants_in_expr(x.args[i], env) - catch e - rethrow( - ErrorException( - "Array indexing in BUGS must be explicit. However, `$(e.msg)` is accessed as a scalar.", - ), + elseif Meta.isexpr(x, :call) + if x.args[1] === :cumulative || x.args[1] === :density + if length(x.args) != 3 + error( + "`cumulative` and `density` are special functions in BUGS and takes two arguments, got $(length(x.args) - 1)", ) end + f, arg1, arg2 = x.args + if arg1 isa Symbol + return Expr(:call, f, arg1, _replace_constants_in_expr(arg2, env)) + elseif Meta.isexpr(arg1, :ref) + v, indices... = arg1.args + for i in eachindex(indices) + indices[i] = _replace_constants_in_expr(indices[i], env) + end + return Expr( + :call, + f, + Expr(:ref, v, indices...), + _replace_constants_in_expr(arg2, env), + ) + else + error( + "First argument to `cumulative` and `density` must be variable, got $(x.args[2])", + ) + end + elseif x.args[1] === :deviance + @warn( + "`deviance` function is not supported in JuliaBUGS, `deviance` will be treated as a general function." + ) + else + x = deepcopy(x) # because we are mutating the args + for i in 2:length(x.args) + x.args[i] = _replace_constants_in_expr(x.args[i], env) + end + return x end + else + error("Unexpected expression type: $x") end - return x end """ diff --git a/src/model.jl b/src/model.jl index 65010ae34..bd64ecec4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -36,6 +36,7 @@ struct BUGSModel <: AbstractBUGSModel transformed_var_lengths::Dict{VarName,Int} # TODO: store this as a delta from `untransformed_var_lengths`? varinfo::SimpleVarInfo + distributions::Dict{VarName,Distribution} parameters::Vector{VarName} sorted_nodes::Vector{VarName} @@ -81,6 +82,7 @@ function BUGSModel( ) vs = initialize_var_store(data, vars, array_sizes) vi = SimpleVarInfo(vs, 0.0) + dist_store = Dict{VarName,Distribution}() parameters = VarName[] untransformed_param_length = 0 transformed_param_length = 0 @@ -91,11 +93,11 @@ function BUGSModel( ni = g[vn] @unpack node_type, node_function_expr, node_args = ni - args = Dict(getsym(arg) => vi[arg] for arg in node_args) + args = Dict(getsym(arg) => vi[arg] for arg in node_args) # TODO: get rid of this expr = node_function_expr.args[2] if node_type == JuliaBUGS.Logical value = try - _eval(expr, args) + _eval(expr, args, dist_store) catch e rethrow( # UninitializedVariableError( @@ -108,7 +110,7 @@ function BUGSModel( vi = setindex!!(vi, value, vn) else dist = try - _eval(expr, args) + _eval(expr, args, dist_store) catch _ rethrow( UninitializedVariableError( @@ -116,6 +118,7 @@ function BUGSModel( ), ) end + dist_store[vn] = dist value = evaluate(vn, data) # `evaluate(::VarName, env)` is defined in `src/utils.jl` if value isa Nothing # not observed push!(parameters, vn) @@ -155,6 +158,7 @@ function BUGSModel( untransformed_var_lengths, transformed_var_lengths, vi, + dist_store, parameters, sorted_nodes, g, @@ -212,7 +216,7 @@ function get_params_varinfo(m::BUGSModel, vi::SimpleVarInfo) args = Dict(getsym(arg) => vi[arg] for arg in node_args) expr = node_function_expr.args[2] if vn in m.parameters - dist = _eval(expr, args) + dist = _eval(expr, args, m.distributions) linked_val = DynamicPPL.link(dist, vi[vn]) d[vn] = linked_val end @@ -252,7 +256,7 @@ function getparams(m::BUGSModel, vi::SimpleVarInfo; transformed::Bool=false) for v in m.parameters ni = m.g[v] args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...) - dist = _eval(ni.node_function_expr.args[2], args) + dist = _eval(ni.node_function_expr.args[2], args, m.distributions) link_vals = Bijectors.link(dist, vi[v]) len = m.transformed_var_lengths[v] @@ -288,7 +292,7 @@ function setparams!!( for v in m.parameters ni = m.g[v] args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...) - dist = _eval(ni.node_function_expr.args[2], args) + dist = _eval(ni.node_function_expr.args[2], args, m.distributions) len = if transformed m.transformed_var_lengths[v] @@ -353,6 +357,7 @@ function AbstractPPL.condition( model.untransformed_var_lengths, model.transformed_var_lengths, varinfo, + model.distributions, new_parameters, sorted_blanket_with_vars, model.g, @@ -377,6 +382,7 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) model.untransformed_var_lengths, model.transformed_var_lengths, model.varinfo, + model.distributions, new_parameters, sorted_blanket_with_vars, model.g, @@ -438,10 +444,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) args = Dict(getsym(arg) => vi[arg] for arg in node_args) expr = node_function_expr.args[2] if node_type == JuliaBUGS.Logical - value = _eval(expr, args) + value = _eval(expr, args, model.distributions) vi = setindex!!(vi, value, vn) else - dist = _eval(expr, args) + dist = _eval(expr, args, model.distributions) value = rand(ctx.rng, dist) # just sample from the prior logp += logpdf(dist, value) vi = setindex!!(vi, value, vn) @@ -464,10 +470,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) args = Dict(getsym(arg) => vi[arg] for arg in node_args) expr = node_function_expr.args[2] if node_type == JuliaBUGS.Logical # be conservative -- always propagate values of logical nodes - value = _eval(expr, args) + value = _eval(expr, args, model.distributions) vi = setindex!!(vi, value, vn) else - dist = _eval(expr, args) + dist = _eval(expr, args, model.distributions) value = vi[vn] if model.transformed # although the values stored in `vi` are in their original space, @@ -509,10 +515,10 @@ function AbstractPPL.evaluate!!( args = (; map(arg -> getsym(arg) => vi[arg], node_args)...) expr = node_function_expr.args[2] if node_type == JuliaBUGS.Logical - value = _eval(expr, args) + value = _eval(expr, args, model.distributions) vi = setindex!!(vi, value, vn) else - dist = _eval(expr, args) + dist = _eval(expr, args, model.distributions) if vn in model.parameters l = var_lengths[vn] if model.transformed diff --git a/src/parser/bugs_macro.jl b/src/parser/bugs_macro.jl index 860e95259..6c963f74b 100644 --- a/src/parser/bugs_macro.jl +++ b/src/parser/bugs_macro.jl @@ -1,71 +1,5 @@ -# handle `cumulative`, `density` and `deviance` functions -# these are incorrect implementations, as it can only handle the case where the first argument is exactly the same as the LHS of the stochastic assignment -# but can't handle cases like `cumulative(y[1], x)` where `y[i]` is defined in loops -# other than that, `density` and `deviance` also require that the variable in place of first argument is observed -# TODO: fix this -function cumulative(expr::Expr) - return MacroTools.postwalk(expr) do sub_expr - if @capture(sub_expr, lhs_ = cumulative(s1_, s2_)) - dist = find_tilde_rhs(expr, s1) - sub_expr.args[2].args[1] = :cdf - sub_expr.args[2].args[2] = dist - return sub_expr - else - return sub_expr - end - end -end - -function density(expr::Expr) - return MacroTools.postwalk(expr) do sub_expr - if @capture(sub_expr, lhs_ = density(s1_, s2_)) - dist = find_tilde_rhs(expr, s1) - sub_expr.args[2].args[1] = :pdf - sub_expr.args[2].args[2] = dist - return sub_expr - else - return sub_expr - end - end -end - -function deviance(expr::Expr) - return MacroTools.postwalk(expr) do sub_expr - if @capture(sub_expr, lhs_ = deviance(s1_, s2_)) - dist = find_tilde_rhs(expr, s1) - sub_expr.args[2].args[1] = :logpdf - sub_expr.args[2].args[2] = dist - sub_expr.args[2] = Expr(:call, :*, -2, sub_expr.args[2]) - return sub_expr - else - return sub_expr - end - end -end - -function find_tilde_rhs(expr::Expr, target::Union{Expr,Symbol}) - dist = nothing - MacroTools.postwalk(expr) do sub_expr - if @capture(sub_expr, lhs_ ~ rhs_) - if lhs == target - isnothing(dist) || error("Exist two assignments to the same variable.") - dist = rhs - end - end - return sub_expr - end - isnothing(dist) && error( - "Error handling cumulative expression: can't find a stochastic assignment for $target.", - ) - return dist -end - -function handle_special_functions(expr::Expr) - return cumulative(density(deviance(expr))) -end - -macro bugs(expr) - return Meta.quot(handle_special_functions(bugs_top(expr, __source__))) +macro bugs(expr::Expr) + return Meta.quot(bugs_top(expr, __source__)) end function bugs_top(@nospecialize(expr), __source__) @@ -252,5 +186,5 @@ macro bugs(prog::String, replace_period=true, no_enclosure=false) if !isempty(error_container) # otherwise errors thrown in macro will be LoadError return :(throw(ErrorException(join($error_container, "\n")))) end - return Meta.quot(handle_special_functions(expr)) + return Meta.quot(expr) end diff --git a/src/utils.jl b/src/utils.jl index 494c79684..b0dbc6d0f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -335,16 +335,16 @@ function simple_arithmetic_eval(data::NamedTuple, expr::Expr) end """ - _eval(expr, env) + _eval(expr, env, dist_store) `_eval` mimics `Base.eval`, but uses precompiled functions. This is possible because the expressions we want to evaluate only have two kinds of expressions: function calls and indexing. `env` is a data structure mapping symbols in `expr` to values, values can be arrays or scalars. """ -function _eval(expr::Number, env) +function _eval(expr::Number, env, dist_store) return expr end -function _eval(expr::Symbol, env) +function _eval(expr::Symbol, env, dist_store) if expr == :nothing return nothing elseif expr == :(:) @@ -353,28 +353,60 @@ function _eval(expr::Symbol, env) return env[expr] end end -function _eval(expr::AbstractRange, env) +function _eval(expr::AbstractRange, env, dist_store) return expr end -function _eval(expr::Expr, env) +function _eval(expr::Expr, env, dist_store) if Meta.isexpr(expr, :call) f = expr.args[1] - args = [_eval(arg, env) for arg in expr.args[2:end]] - if f isa Expr # `JuliaBUGS.some_function` like - f = f.args[2].value + if f === :cumulative || f === :density + if length(expr.args) != 3 + error( + "density function should have 3 arguments, but get $(length(expr.args)).", + ) + end + rv1, rv2 = expr.args[2:3] + dist = if Meta.isexpr(rv1, :ref) + var, indices... = rv1.args + for i in eachindex(indices) + indices[i] = _eval(indices[i], env, dist_store) + end + vn = AbstractPPL.VarName{var}( + AbstractPPL.Setfield.IndexLens(Tuple(indices)) + ) + dist_store[vn] + elseif rv1 isa Symbol + vn = AbstractPPL.VarName{rv1}() + dist_store[vn] + else + error( + "the first argument of density function should be a variable, but got $(rv1).", + ) + end + rv2 = _eval(rv2, env, dist_store) + if f === :cumulative + return cdf(dist, rv2) + else + return pdf(dist, rv2) + end + else + args = [_eval(arg, env, dist_store) for arg in expr.args[2:end]] + if f isa Expr # `JuliaBUGS.some_function` like + f = f.args[2].value + end + return getfield(JuliaBUGS, f)(args...) # assume all functions used are available under `JuliaBUGS` end - return getfield(JuliaBUGS, f)(args...) # assume all functions used are available under `JuliaBUGS` elseif Meta.isexpr(expr, :ref) - array = _eval(expr.args[1], env) - indices = [_eval(arg, env) for arg in expr.args[2:end]] + array = _eval(expr.args[1], env, dist_store) + indices = [_eval(arg, env, dist_store) for arg in expr.args[2:end]] return array[indices...] elseif Meta.isexpr(expr, :block) - return _eval(expr.args[end], env) + return _eval(expr.args[end], env, dist_store) else error("Unknown expression type: $expr") end end -function _eval(expr, env) +function _eval(expr, env, dist_store) return error("Unknown expression type: $expr of type $(typeof(expr))") end diff --git a/test/cumulative_density.jl b/test/cumulative_density.jl new file mode 100644 index 000000000..ed9174c50 --- /dev/null +++ b/test/cumulative_density.jl @@ -0,0 +1,39 @@ +@testset "cumulative" begin + model_def = @bugs begin + a ~ Normal(0, 1) + b = cumulative(a, 2) + + c[1] ~ Normal(0, 1) + d[1] = cumulative(c[1], 2) + end + + data, inits = (;), (;) + + model = compile(model_def, data, inits) + + @test model.distributions[@varname(a)] == Normal(0, 1) + @test model.distributions[@varname(c[1])] == Normal(0, 1) + + @test model.varinfo[@varname(b)] == cdf(Normal(0, 1), 2) + @test model.varinfo[@varname(d[1])] == cdf(Normal(0, 1), 2) +end + +@testset "density" begin + model_def = @bugs begin + a ~ Normal(0, 1) + b = density(a, 2) + + c[1] ~ Normal(0, 1) + d[1] = density(c[1], 2) + end + + data, inits = (;), (;) + + model = compile(model_def, data, inits) + + @test model.distributions[@varname(a)] == Normal(0, 1) + @test model.distributions[@varname(c[1])] == Normal(0, 1) + + @test model.varinfo[@varname(b)] == pdf(Normal(0, 1), 2) + @test model.varinfo[@varname(d[1])] == pdf(Normal(0, 1), 2) +end diff --git a/test/runtests.jl b/test/runtests.jl index 33a90e6fd..e16e438dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -84,6 +84,8 @@ else include("compile.jl") + include("cumulative_density.jl") + @testset "Compile WinBUGS Vol I examples: $m" for m in [ :blockers, :bones,