Skip to content

Commit

Permalink
Fix cumulative and density function (#153)
Browse files Browse the repository at this point in the history
* `deviance` function is not supported and user will be warned when
using it.
* Compare to `multibugs`: multibugs doesn't seem to support `cumulative`
and `density`, but provide individual cdf and pdf functions. We can also
stop supporting these two functions, and direct user to use `cdf` and
`pdf`.
* For now, a new member `distributions` is added to `BUGSModel` to track
all the distributions: using Dict right now because non-scalar rv is
problematic if we store distributions in the same way we store values.
`cumulative` and `density` are handled in special way in `_eval`
function to allow access to distribution info at run time. Maybe this
can be generalized, the distribution can be helpful introspect info

Fix: #137

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sunxd3 and github-actions[bot] authored Mar 12, 2024
1 parent fe710e6 commit 556c5be
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 130 deletions.
137 changes: 101 additions & 36 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,27 +631,65 @@ function evaluate_and_track_dependencies(var::Expr, env)
push!(args, (var.args[1], ()))
push!(deps, (var.args[1], ()))
return Expr(var.head, var.args[1], idxs...), deps, args
else # function call
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
push!(fun_args, e)
elseif Meta.isexpr(var, :call)
if var.args[1] === :cumulative || var.args[1] === :density
arg1, arg2 = var.args[2:3]
arg1 = if arg1 isa Symbol
push!(deps, arg1)
# no need to add to arg, as the value doesn't matter
arg1
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
e, d, a = evaluate_and_track_dependencies(indices[i], env)
union!(deps, d)
union!(args, a)
indices[i] = e
end
if any(!is_resolved, indices)
error(
"For now, the indices of the first argument to `cumulative` and `density` must be resolved, got $indices",
)
end
push!(deps, (v, Tuple(indices)))
# no need to add to arg, as the value doesn't matter
Expr(:ref, v, indices...)
else
error(
"First argument to `cumulative` and `density` must be variable, got $(arg1)",
)
end
e, d, a = evaluate_and_track_dependencies(arg2, env)
union!(deps, d)
union!(args, a)
end
return Expr(:call, var.args[1], arg1, e), deps, args
else
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
push!(fun_args, e)
union!(deps, d)
union!(args, a)
end

for a in fun_args
a isa Symbol && a != :nothing && a != :(:) && (push!(deps, a); push!(args, a))
end
for a in fun_args
a isa Symbol &&
a != :nothing &&
a != :(:) &&
(push!(deps, a); push!(args, a))
end

if (
var.args[1] BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] (:+, :-, :*, :/, :^, :(:))
) && all(is_resolved, args)
return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args
else
return Expr(var.head, var.args[1], fun_args...), deps, args
if (
var.args[1] BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] (:+, :-, :*, :/, :^, :(:))
) && all(is_resolved, args)
return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args
else
return Expr(var.head, var.args[1], fun_args...), deps, args
end
end
else
error("Unexpected expression type: $var")
end
end

Expand Down Expand Up @@ -691,36 +729,63 @@ end

_replace_constants_in_expr(x::Number, env) = x
function _replace_constants_in_expr(x::Symbol, env)
if haskey(env, x)
if env[x] isa Number # only plug in scalar variables
return env[x]
else # if it's an array, raise error because array indexing should be explicit
error("$x")
end
if haskey(env, x) && env[x] isa Number
return env[x]
end
return x
end
function _replace_constants_in_expr(x::Expr, env)
if Meta.isexpr(x, :ref) && all(x -> x isa Number, x.args[2:end])
if haskey(env, x.args[1])
val = env[x.args[1]][try_cast_to_int.(x.args[2:end])...]
if Meta.isexpr(x, :ref)
v, indices... = x.args
if haskey(env, v) && all(x -> x isa Union{Int,Float64}, indices)
val = env[v][map(Int, indices)...]
return ismissing(val) ? x : val
else
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(:ref, v, indices...)
end
else # don't try to eval the function, but try to simplify
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
try
x.args[i] = _replace_constants_in_expr(x.args[i], env)
catch e
rethrow(
ErrorException(
"Array indexing in BUGS must be explicit. However, `$(e.msg)` is accessed as a scalar.",
),
elseif Meta.isexpr(x, :call)
if x.args[1] === :cumulative || x.args[1] === :density
if length(x.args) != 3
error(
"`cumulative` and `density` are special functions in BUGS and takes two arguments, got $(length(x.args) - 1)",
)
end
f, arg1, arg2 = x.args
if arg1 isa Symbol
return Expr(:call, f, arg1, _replace_constants_in_expr(arg2, env))
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(
:call,
f,
Expr(:ref, v, indices...),
_replace_constants_in_expr(arg2, env),
)
else
error(
"First argument to `cumulative` and `density` must be variable, got $(x.args[2])",
)
end
elseif x.args[1] === :deviance
@warn(
"`deviance` function is not supported in JuliaBUGS, `deviance` will be treated as a general function."
)
else
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
x.args[i] = _replace_constants_in_expr(x.args[i], env)
end
return x
end
else
error("Unexpected expression type: $x")
end
return x
end

"""
Expand Down
30 changes: 18 additions & 12 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct BUGSModel <: AbstractBUGSModel
transformed_var_lengths::Dict{VarName,Int} # TODO: store this as a delta from `untransformed_var_lengths`?

varinfo::SimpleVarInfo
distributions::Dict{VarName,Distribution}
parameters::Vector{VarName}
sorted_nodes::Vector{VarName}

Expand Down Expand Up @@ -81,6 +82,7 @@ function BUGSModel(
)
vs = initialize_var_store(data, vars, array_sizes)
vi = SimpleVarInfo(vs, 0.0)
dist_store = Dict{VarName,Distribution}()
parameters = VarName[]
untransformed_param_length = 0
transformed_param_length = 0
Expand All @@ -91,11 +93,11 @@ function BUGSModel(

ni = g[vn]
@unpack node_type, node_function_expr, node_args = ni
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
args = Dict(getsym(arg) => vi[arg] for arg in node_args) # TODO: get rid of this
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = try
_eval(expr, args)
_eval(expr, args, dist_store)
catch e
rethrow(
# UninitializedVariableError(
Expand All @@ -108,14 +110,15 @@ function BUGSModel(
vi = setindex!!(vi, value, vn)
else
dist = try
_eval(expr, args)
_eval(expr, args, dist_store)
catch _
rethrow(
UninitializedVariableError(
"Encounter support error when evaluating the distribution of $vn. Try to initialize variables $(join(collect(keys(args)), ", ")) first if not yet.",
),
)
end
dist_store[vn] = dist
value = evaluate(vn, data) # `evaluate(::VarName, env)` is defined in `src/utils.jl`
if value isa Nothing # not observed
push!(parameters, vn)
Expand Down Expand Up @@ -155,6 +158,7 @@ function BUGSModel(
untransformed_var_lengths,
transformed_var_lengths,
vi,
dist_store,
parameters,
sorted_nodes,
g,
Expand Down Expand Up @@ -212,7 +216,7 @@ function get_params_varinfo(m::BUGSModel, vi::SimpleVarInfo)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if vn in m.parameters
dist = _eval(expr, args)
dist = _eval(expr, args, m.distributions)
linked_val = DynamicPPL.link(dist, vi[vn])
d[vn] = linked_val
end
Expand Down Expand Up @@ -252,7 +256,7 @@ function getparams(m::BUGSModel, vi::SimpleVarInfo; transformed::Bool=false)
for v in m.parameters
ni = m.g[v]
args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...)
dist = _eval(ni.node_function_expr.args[2], args)
dist = _eval(ni.node_function_expr.args[2], args, m.distributions)

link_vals = Bijectors.link(dist, vi[v])
len = m.transformed_var_lengths[v]
Expand Down Expand Up @@ -288,7 +292,7 @@ function setparams!!(
for v in m.parameters
ni = m.g[v]
args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...)
dist = _eval(ni.node_function_expr.args[2], args)
dist = _eval(ni.node_function_expr.args[2], args, m.distributions)

len = if transformed
m.transformed_var_lengths[v]
Expand Down Expand Up @@ -353,6 +357,7 @@ function AbstractPPL.condition(
model.untransformed_var_lengths,
model.transformed_var_lengths,
varinfo,
model.distributions,
new_parameters,
sorted_blanket_with_vars,
model.g,
Expand All @@ -377,6 +382,7 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
model.untransformed_var_lengths,
model.transformed_var_lengths,
model.varinfo,
model.distributions,
new_parameters,
sorted_blanket_with_vars,
model.g,
Expand Down Expand Up @@ -438,10 +444,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
value = _eval(expr, args, model.distributions)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, model.distributions)
value = rand(ctx.rng, dist) # just sample from the prior
logp += logpdf(dist, value)
vi = setindex!!(vi, value, vn)
Expand All @@ -464,10 +470,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical # be conservative -- always propagate values of logical nodes
value = _eval(expr, args)
value = _eval(expr, args, model.distributions)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, model.distributions)
value = vi[vn]
if model.transformed
# although the values stored in `vi` are in their original space,
Expand Down Expand Up @@ -509,10 +515,10 @@ function AbstractPPL.evaluate!!(
args = (; map(arg -> getsym(arg) => vi[arg], node_args)...)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
value = _eval(expr, args, model.distributions)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, model.distributions)
if vn in model.parameters
l = var_lengths[vn]
if model.transformed
Expand Down
72 changes: 3 additions & 69 deletions src/parser/bugs_macro.jl
Original file line number Diff line number Diff line change
@@ -1,71 +1,5 @@
# handle `cumulative`, `density` and `deviance` functions
# these are incorrect implementations, as it can only handle the case where the first argument is exactly the same as the LHS of the stochastic assignment
# but can't handle cases like `cumulative(y[1], x)` where `y[i]` is defined in loops
# other than that, `density` and `deviance` also require that the variable in place of first argument is observed
# TODO: fix this
function cumulative(expr::Expr)
return MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ = cumulative(s1_, s2_))
dist = find_tilde_rhs(expr, s1)
sub_expr.args[2].args[1] = :cdf
sub_expr.args[2].args[2] = dist
return sub_expr
else
return sub_expr
end
end
end

function density(expr::Expr)
return MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ = density(s1_, s2_))
dist = find_tilde_rhs(expr, s1)
sub_expr.args[2].args[1] = :pdf
sub_expr.args[2].args[2] = dist
return sub_expr
else
return sub_expr
end
end
end

function deviance(expr::Expr)
return MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ = deviance(s1_, s2_))
dist = find_tilde_rhs(expr, s1)
sub_expr.args[2].args[1] = :logpdf
sub_expr.args[2].args[2] = dist
sub_expr.args[2] = Expr(:call, :*, -2, sub_expr.args[2])
return sub_expr
else
return sub_expr
end
end
end

function find_tilde_rhs(expr::Expr, target::Union{Expr,Symbol})
dist = nothing
MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ ~ rhs_)
if lhs == target
isnothing(dist) || error("Exist two assignments to the same variable.")
dist = rhs
end
end
return sub_expr
end
isnothing(dist) && error(
"Error handling cumulative expression: can't find a stochastic assignment for $target.",
)
return dist
end

function handle_special_functions(expr::Expr)
return cumulative(density(deviance(expr)))
end

macro bugs(expr)
return Meta.quot(handle_special_functions(bugs_top(expr, __source__)))
macro bugs(expr::Expr)
return Meta.quot(bugs_top(expr, __source__))
end

function bugs_top(@nospecialize(expr), __source__)
Expand Down Expand Up @@ -252,5 +186,5 @@ macro bugs(prog::String, replace_period=true, no_enclosure=false)
if !isempty(error_container) # otherwise errors thrown in macro will be LoadError
return :(throw(ErrorException(join($error_container, "\n"))))
end
return Meta.quot(handle_special_functions(expr))
return Meta.quot(expr)
end
Loading

0 comments on commit 556c5be

Please sign in to comment.