Skip to content

Commit

Permalink
fix features and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Mar 10, 2024
1 parent 8ebc608 commit 85cadfe
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 64 deletions.
1 change: 1 addition & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ function compile(model_def::Expr, data, inits; is_transformed=true)
PostChecking(data, transformed_variables), model_def, data
)
merged_data = merge_with_coalescence(deepcopy(data), transformed_variables)
model_def = concretize_colon_indexing(model_def, array_sizes, merged_data)
vars, array_sizes, array_bitmap, node_args, node_functions, dependencies = analyze_program(
NodeFunctions(array_sizes, array_bitmap), model_def, merged_data
)
Expand Down
131 changes: 85 additions & 46 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,37 @@ function evaluate_and_track_dependencies(var::Expr, env)
push!(args, (var.args[1], ()))
push!(deps, (var.args[1], ()))
return Expr(var.head, var.args[1], idxs...), deps, args
else # function call
elseif Meta.isexpr(var, :call) && var.args[1] === :cumulative ||
var.args[1] === :density
arg1, arg2 = var.args[2:3]
arg1 = if arg1 isa Symbol
push!(deps, arg1)
push!(args, arg1)
arg1
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
e, d, a = evaluate_and_track_dependencies(indices[i], env)
union!(deps, d)
union!(args, a)
indices[i] = e
end
if any(!is_resolved, indices)
error(
"For now, the indices of the first argument to `cumulative` and `density` must be resolved, got $indices",
)
end
push!(deps, (v, Tuple(indices)))
# no need to add to arg, as the value doesn't matter
Expr(:ref, v, indices...)
else
error("First argument to `cumulative` and `density` must be variable, got $(arg1)")
end
e, d, a = evaluate_and_track_dependencies(arg2, env)
union!(deps, d)
union!(args, a)
return Expr(:call, var.args[1], arg1, e), deps, args
else
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
Expand Down Expand Up @@ -706,6 +736,32 @@ function _replace_constants_in_expr(x::Expr, env)
val = env[x.args[1]][try_cast_to_int.(x.args[2:end])...]
return ismissing(val) ? x : val
end
elseif Meta.isexpr(x, :call) && x.args[1] === :cumulative || x.args[1] === :density
if length(x.args) != 3
error(
"`cumulative` and `density` takes two arguments, got $(length(x.args) - 1)"
)
end
if x.args[2] isa Symbol
return Expr(
:call, x.args[1], x.args[2], _replace_constants_in_expr(x.args[3], env)
)
elseif Meta.isexpr(x.args[2], :ref)
v, indices... = x.args[2].args
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(
:call,
x.args[1],
Expr(:ref, v, indices...),
_replace_constants_in_expr(x.args[3], env),
)
else
error(
"First argument to `cumulative` and `density` must be variable, got $(x.args[2])",
)
end
else # don't try to eval the function, but try to simplify
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
Expand All @@ -723,36 +779,6 @@ function _replace_constants_in_expr(x::Expr, env)
return x
end

"""
concretize_colon_indexing(expr, array_sizes, data)
Replace all `Colon()`s in `expr` with the corresponding array size, using either the `array_sizes` or the `data` dictionaries.
# Examples
```jldoctest
julia> concretize_colon_indexing(:(f(x[1, :])), Dict(:x => (3, 4)), Dict(:x => [1 2 3 4; 5 6 7 8; 9 10 11 12]))
:(f(x[1, 1:4]))
```
"""
function concretize_colon_indexing(expr, array_sizes, data)
return MacroTools.postwalk(expr) do sub_expr
if MacroTools.@capture(sub_expr, x_[idx__])
for i in 1:length(idx)
if idx[i] == :(:)
if haskey(array_sizes, x)
idx[i] = Expr(:call, :(:), 1, array_sizes[x][i])
else
@assert haskey(data, x)
idx[i] = Expr(:call, :(:), 1, size(data[x])[i])
end
end
end
return Expr(:ref, x, idx...)
end
return sub_expr
end
end

"""
create_array_var(n, array_sizes, env)
Expand Down Expand Up @@ -787,7 +813,7 @@ try_cast_to_int(x::Real) = Int(x) # will error if !isinteger(x)
try_cast_to_int(x) = x # catch other types, e.g. UnitRange, Colon

function analyze_assignment(pass::NodeFunctions, expr::Expr, env::NamedTuple)
@capture(expr, lhs_expr_ ~ rhs_expr_) || @capture(expr, lhs_expr_ = rhs_expr_)
lhs_expr, rhs_expr = Meta.isexpr(expr, :(=)) ? expr.args : expr.args[2:end]
var_type = Meta.isexpr(expr, :(=)) ? Logical : Stochastic

lhs_var = find_variables_on_lhs(
Expand All @@ -798,7 +824,6 @@ function analyze_assignment(pass::NodeFunctions, expr::Expr, env::NamedTuple)
return nothing

pass.vars[lhs_var] = var_type
rhs_expr = concretize_colon_indexing(rhs_expr, pass.array_sizes, env)
rhs = evaluate(rhs_expr, env)

if rhs isa Symbol
Expand Down Expand Up @@ -847,21 +872,35 @@ function analyze_assignment(pass::NodeFunctions, expr::Expr, env::NamedTuple)
# issue is that we need to do this in steps, const propagation need to a separate pass
# otherwise the variable in previous expressions will not be evaluated to the concrete value
else
dependencies, node_args = map(
x -> map(x) do x_elem
if x_elem isa Symbol
return Var(x_elem)
elseif x_elem isa Tuple && last(x_elem) == ()
return create_array_var(first(x_elem), pass.array_sizes, env)
else
return Var(first(x_elem), last(x_elem))
end
end,
map(collect, (dependencies, node_args)),
)
dependencies = collect(dependencies)
for i in eachindex(dependencies)
if dependencies[i] isa Symbol
dependencies[i] = Var(dependencies[i])
elseif dependencies[i] isa Tuple && last(dependencies[i]) == ()
dependencies[i] = create_array_var(
first(dependencies[i]), pass.array_sizes, env
)
else
dependencies[i] = Var(first(dependencies[i]), last(dependencies[i]))
end
end

node_args = collect(node_args)
for i in eachindex(node_args)
if node_args[i] isa Symbol
node_args[i] = Var(node_args[i])
elseif node_args[i] isa Tuple && last(node_args[i]) == ()
node_args[i] = create_array_var(
first(node_args[i]), pass.array_sizes, env
)
else
node_args[i] = Var(first(node_args[i]), last(node_args[i]))
end
end

rhs_expr = MacroTools.postwalk(rhs_expr) do sub_expr
if @capture(sub_expr, arr_[idxs__])
if Meta.isexpr(sub_expr, :ref)
arr, idxs... = sub_expr.args
new_idxs = [
idx isa Integer ? idx : :(JuliaBUGS.try_cast_to_int($(idx))) for
idx in idxs
Expand Down
29 changes: 15 additions & 14 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,13 @@ function BUGSModel(

ni = g[vn]
@unpack node_type, node_function_expr, node_args = ni
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
args = Dict(getsym(arg) => vi[arg] for arg in node_args) # TODO: get rid of this
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = try
_eval(expr, args)
_eval(expr, args, dist_store)
catch e
@info args expr node_args
rethrow(
# UninitializedVariableError(
# "Encounter error when evaluating the RHS of $vn. Try to initialize variables $(join(collect(keys(args)), ", ")) directly first if not yet.",
Expand All @@ -110,7 +111,7 @@ function BUGSModel(
vi = setindex!!(vi, value, vn)
else
dist = try
_eval(expr, args)
_eval(expr, args, dist_store)
catch _
rethrow(
UninitializedVariableError(
Expand Down Expand Up @@ -192,7 +193,7 @@ function initialize_var_store(data, vars, array_sizes)
end

function initialize_distribution_store(var_store::Dict)
dist_store = Dict{Symbol, Any}()
dist_store = Dict{Symbol,Any}()
for (k, v) in var_store
if v isa AbstractArray
dist_store[AbstractPPL.getsym(k)] = Array{Distribution}(undef, size(v)...)
Expand All @@ -203,7 +204,7 @@ function initialize_distribution_store(var_store::Dict)
return NamedTuple(dist_store)
end

function get_distribution(model::BUGSModel, vn::VarName)
function get_distribution(model::BUGSModel, vn::VarName)
return AbstractPPL.Setfield.get(model.distributions, vn)
end

Expand Down Expand Up @@ -232,7 +233,7 @@ function get_params_varinfo(m::BUGSModel, vi::SimpleVarInfo)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if vn in m.parameters
dist = _eval(expr, args)
dist = _eval(expr, args, dist_store)
linked_val = DynamicPPL.link(dist, vi[vn])
d[vn] = linked_val
end
Expand Down Expand Up @@ -272,7 +273,7 @@ function getparams(m::BUGSModel, vi::SimpleVarInfo; transformed::Bool=false)
for v in m.parameters
ni = m.g[v]
args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...)
dist = _eval(ni.node_function_expr.args[2], args)
dist = _eval(ni.node_function_expr.args[2], args, dist_store)

link_vals = Bijectors.link(dist, vi[v])
len = m.transformed_var_lengths[v]
Expand Down Expand Up @@ -308,7 +309,7 @@ function setparams!!(
for v in m.parameters
ni = m.g[v]
args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...)
dist = _eval(ni.node_function_expr.args[2], args)
dist = _eval(ni.node_function_expr.args[2], args, dist_store)

len = if transformed
m.transformed_var_lengths[v]
Expand Down Expand Up @@ -458,10 +459,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
value = _eval(expr, args, dist_store)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, dist_store)
value = rand(ctx.rng, dist) # just sample from the prior
logp += logpdf(dist, value)
vi = setindex!!(vi, value, vn)
Expand All @@ -484,10 +485,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical # be conservative -- always propagate values of logical nodes
value = _eval(expr, args)
value = _eval(expr, args, dist_store)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, dist_store)
value = vi[vn]
if model.transformed
# although the values stored in `vi` are in their original space,
Expand Down Expand Up @@ -529,10 +530,10 @@ function AbstractPPL.evaluate!!(
args = (; map(arg -> getsym(arg) => vi[arg], node_args)...)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
value = _eval(expr, args, dist_store)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, dist_store)
if vn in model.parameters
l = var_lengths[vn]
if model.transformed
Expand Down
43 changes: 39 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,32 @@
"""
concretize_colon_indexing(expr, array_sizes, data)
Replace all `Colon()`s in `expr` with the corresponding array size, using either the `array_sizes` or the `data` dictionaries.
# Examples
```jldoctest
julia> concretize_colon_indexing(:(f(x[1, :])), Dict(:x => (3, 4)), Dict(:x => [1 2 3 4; 5 6 7 8; 9 10 11 12]))
:(f(x[1, 1:4]))
```
"""
function concretize_colon_indexing(expr, array_sizes, data)
return MacroTools.postwalk(expr) do sub_expr
if MacroTools.@capture(sub_expr, x_[idx__])
for i in 1:length(idx)
if idx[i] == :(:)
if haskey(array_sizes, x)
idx[i] = Expr(:call, :(:), 1, array_sizes[x][i])
else
idx[i] = Expr(:call, :(:), 1, size(data[x])[i])
end
end
end
return Expr(:ref, x, idx...)
end
return sub_expr
end
end

"""
decompose_for_expr(expr::Expr)
Expand Down Expand Up @@ -357,25 +386,31 @@ function _eval(expr::AbstractRange, env, dist_store)
return expr
end
function _eval(expr::Expr, env, dist_store)
if Meta.isexpr(expr, :call)
if Meta.isexpr(expr, :call)
f = expr.args[1]
if f === :cumulative || f === :density
if length(expr.args) != 3
error("density function should have 3 arguments, but get $(length(expr.args)).")
error(
"density function should have 3 arguments, but get $(length(expr.args)).",
)
end
rv1, rv2 = expr.args[2:3]
dist = if Meta.isexpr(rv1, :ref)
var, indices... = rv1.args
for i in eachindex(indices)
indices[i] = _eval(indices[i], env, dist_store)
end
vn = AbstractPPL.VarName{var}(AbstractPPL.Setfield.IndexLens(Tuple(indices)))
vn = AbstractPPL.VarName{var}(
AbstractPPL.Setfield.IndexLens(Tuple(indices))
)
AbstractPPL.Setfield.get(dist_store, vn)
elseif rv1 isa Symbol
vn = AbstractPPL.VarName{rv1}()
AbstractPPL.Setfield.get(dist_store, vn)
else
error("the first argument of density function should be a variable, but got $(rv1).")
error(
"the first argument of density function should be a variable, but got $(rv1).",
)
end
rv2 = _eval(rv2, env, dist_store)
if f === :cumulative
Expand Down
Loading

0 comments on commit 85cadfe

Please sign in to comment.