Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cumulative and density function #153

Merged
merged 29 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
15d0758
remove the `cumulative` etc handling in macro stage
sunxd3 Mar 10, 2024
9106c29
comment improvement
sunxd3 Mar 10, 2024
460d430
add `distributions` field to `BUGSModel`
sunxd3 Mar 10, 2024
8ebc608
add handling code for cumulative/density
sunxd3 Mar 10, 2024
85cadfe
fix features and add tests
sunxd3 Mar 10, 2024
d80c3ca
add warn message
sunxd3 Mar 10, 2024
4d0d686
use Dict instead of namedtuple for `distributions`
sunxd3 Mar 10, 2024
5b0a29a
Update src/model.jl
sunxd3 Mar 10, 2024
80fe929
fix typo
sunxd3 Mar 10, 2024
c683ac1
update tests
sunxd3 Mar 10, 2024
6ae5f0c
fix typo
sunxd3 Mar 10, 2024
ff24143
update condition function
sunxd3 Mar 10, 2024
a5fccf1
remove `info` call
sunxd3 Mar 11, 2024
da660e4
fix bug
sunxd3 Mar 11, 2024
d371c21
recover old imple
sunxd3 Mar 11, 2024
c2410e7
some minor improvement
sunxd3 Mar 12, 2024
62dff76
Apply suggestions from code review
sunxd3 Mar 12, 2024
4bf4627
minor improvement on program logic
sunxd3 Mar 12, 2024
4d401a3
add error handling
sunxd3 Mar 12, 2024
c156c53
put back some changes
sunxd3 Mar 12, 2024
afb5744
fix error
sunxd3 Mar 12, 2024
0ec7ebb
format
sunxd3 Mar 12, 2024
4adc8a2
use NamedTuple in place of Dict in some tests
sunxd3 Mar 12, 2024
31bd17c
fix error
sunxd3 Mar 12, 2024
93f5c1c
remove some changes
sunxd3 Mar 12, 2024
d95377f
remove unnecessary changes
sunxd3 Mar 12, 2024
a672a7a
formatting
sunxd3 Mar 12, 2024
876672b
more formatting
sunxd3 Mar 12, 2024
7a419e5
Merge branch 'master' into sunxd/fix_cumulative
sunxd3 Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,19 @@ Compile a BUGS model into a log density problem.
- A [`BUGSModel`](@ref) object representing the compiled model.
"""
function compile(model_def::Expr, data, inits; is_transformed=true)
if !(data isa NamedTuple)
if !(data isa NamedTuple) && !(data isa Dict{Symbol,<:Any})
error(
"Data must be a NamedTuple or a Dict{Symbol,<:Any}. Received: $(typeof(data))"
)
elseif data isa Dict{Symbol,<:Any}
data = NamedTuple{Tuple(keys(data))}(values(data))
end

if !(inits isa NamedTuple)
if !(inits isa NamedTuple) && !(inits isa Dict{Symbol,<:Any})
error(
"Initializations must be a NamedTuple or a Dict{Symbol,<:Any}. Received: $(typeof(inits))",
)
elseif inits isa Dict{Symbol,<:Any}
inits = NamedTuple{Tuple(keys(inits))}(values(inits))
end

Expand All @@ -256,6 +264,7 @@ function compile(model_def::Expr, data, inits; is_transformed=true)
PostChecking(data, transformed_variables), model_def, data
)
merged_data = merge_with_coalescence(deepcopy(data), transformed_variables)
model_def = concretize_colon_indexing(model_def, array_sizes, merged_data)
vars, array_sizes, array_bitmap, node_args, node_functions, dependencies = analyze_program(
NodeFunctions(array_sizes, array_bitmap), model_def, merged_data
)
Expand Down
181 changes: 129 additions & 52 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,27 +631,65 @@ function evaluate_and_track_dependencies(var::Expr, env)
push!(args, (var.args[1], ()))
push!(deps, (var.args[1], ()))
return Expr(var.head, var.args[1], idxs...), deps, args
else # function call
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
push!(fun_args, e)
elseif Meta.isexpr(var, :call)
if var.args[1] === :cumulative || var.args[1] === :density
arg1, arg2 = var.args[2:3]
arg1 = if arg1 isa Symbol
push!(deps, arg1)
# no need to add to arg, as the value doesn't matter
arg1
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
e, d, a = evaluate_and_track_dependencies(indices[i], env)
union!(deps, d)
union!(args, a)
indices[i] = e
end
if any(!is_resolved, indices)
error(
"For now, the indices of the first argument to `cumulative` and `density` must be resolved, got $indices",
)
end
push!(deps, (v, Tuple(indices)))
# no need to add to arg, as the value doesn't matter
Expr(:ref, v, indices...)
else
error(
"First argument to `cumulative` and `density` must be variable, got $(arg1)",
)
end
e, d, a = evaluate_and_track_dependencies(arg2, env)
union!(deps, d)
union!(args, a)
end
return Expr(:call, var.args[1], arg1, e), deps, args
else
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
push!(fun_args, e)
union!(deps, d)
union!(args, a)
end

for a in fun_args
a isa Symbol && a != :nothing && a != :(:) && (push!(deps, a); push!(args, a))
end
for a in fun_args
a isa Symbol &&
a != :nothing &&
a != :(:) &&
(push!(deps, a); push!(args, a))
end

if (
var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] ∈ (:+, :-, :*, :/, :^, :(:))
) && all(is_resolved, args)
return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args
else
return Expr(var.head, var.args[1], fun_args...), deps, args
if (
var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] ∈ (:+, :-, :*, :/, :^, :(:))
) && all(is_resolved, args)
return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args
else
return Expr(var.head, var.args[1], fun_args...), deps, args
end
end
else
error("Unexpected expression type: $var")
end
end

Expand Down Expand Up @@ -691,36 +729,63 @@ end

_replace_constants_in_expr(x::Number, env) = x
function _replace_constants_in_expr(x::Symbol, env)
if haskey(env, x)
if env[x] isa Number # only plug in scalar variables
return env[x]
else # if it's an array, raise error because array indexing should be explicit
error("$x")
end
if haskey(env, x) && env[x] isa Number
return env[x]
end
return x
end
function _replace_constants_in_expr(x::Expr, env)
if Meta.isexpr(x, :ref) && all(x -> x isa Number, x.args[2:end])
if haskey(env, x.args[1])
val = env[x.args[1]][try_cast_to_int.(x.args[2:end])...]
if Meta.isexpr(x, :ref)
v, indices... = x.args
if haskey(env, v) && all(x -> x isa Union{Int,Float64}, indices)
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved
val = env[v][map(Int, indices)...]
return ismissing(val) ? x : val
else
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(:ref, v, indices...)
end
else # don't try to eval the function, but try to simplify
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
try
x.args[i] = _replace_constants_in_expr(x.args[i], env)
catch e
rethrow(
ErrorException(
"Array indexing in BUGS must be explicit. However, `$(e.msg)` is accessed as a scalar.",
),
elseif Meta.isexpr(x, :call)
if x.args[1] === :cumulative || x.args[1] === :density
if length(x.args) != 3
error(
"`cumulative` and `density` are special functions in BUGS and takes two arguments, got $(length(x.args) - 1)",
)
end
f, arg1, arg2 = x.args
if arg1 isa Symbol
return Expr(:call, f, arg1, _replace_constants_in_expr(arg2, env))
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(
:call,
f,
Expr(:ref, v, indices...),
_replace_constants_in_expr(arg2, env),
)
else
error(
"First argument to `cumulative` and `density` must be variable, got $(x.args[2])",
)
end
elseif x.args[1] === :deviance
@warn(
"`deviance` function is not supported in JuliaBUGS, `deviance` will be treated as a general function."
)
else
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
x.args[i] = _replace_constants_in_expr(x.args[i], env)
end
return x
end
else
error("Unexpected expression type: $x")
end
return x
end

"""
Expand All @@ -742,7 +807,6 @@ function concretize_colon_indexing(expr, array_sizes, data)
if haskey(array_sizes, x)
idx[i] = Expr(:call, :(:), 1, array_sizes[x][i])
else
@assert haskey(data, x)
idx[i] = Expr(:call, :(:), 1, size(data[x])[i])
end
end
Expand Down Expand Up @@ -787,7 +851,7 @@ try_cast_to_int(x::Real) = Int(x) # will error if !isinteger(x)
try_cast_to_int(x) = x # catch other types, e.g. UnitRange, Colon

function analyze_assignment(pass::NodeFunctions, expr::Expr, env::NamedTuple)
@capture(expr, lhs_expr_ ~ rhs_expr_) || @capture(expr, lhs_expr_ = rhs_expr_)
lhs_expr, rhs_expr = Meta.isexpr(expr, :(=)) ? expr.args : expr.args[2:end]
var_type = Meta.isexpr(expr, :(=)) ? Logical : Stochastic

lhs_var = find_variables_on_lhs(
Expand All @@ -798,7 +862,6 @@ function analyze_assignment(pass::NodeFunctions, expr::Expr, env::NamedTuple)
return nothing

pass.vars[lhs_var] = var_type
rhs_expr = concretize_colon_indexing(rhs_expr, pass.array_sizes, env)
rhs = evaluate(rhs_expr, env)

if rhs isa Symbol
Expand Down Expand Up @@ -847,21 +910,35 @@ function analyze_assignment(pass::NodeFunctions, expr::Expr, env::NamedTuple)
# issue is that we need to do this in steps, const propagation need to a separate pass
# otherwise the variable in previous expressions will not be evaluated to the concrete value
else
dependencies, node_args = map(
x -> map(x) do x_elem
if x_elem isa Symbol
return Var(x_elem)
elseif x_elem isa Tuple && last(x_elem) == ()
return create_array_var(first(x_elem), pass.array_sizes, env)
else
return Var(first(x_elem), last(x_elem))
end
end,
map(collect, (dependencies, node_args)),
)
dependencies = collect(dependencies)
for i in eachindex(dependencies)
if dependencies[i] isa Symbol
dependencies[i] = Var(dependencies[i])
elseif dependencies[i] isa Tuple && last(dependencies[i]) == ()
dependencies[i] = create_array_var(
first(dependencies[i]), pass.array_sizes, env
)
else
dependencies[i] = Var(first(dependencies[i]), last(dependencies[i]))
end
end

node_args = collect(node_args)
for i in eachindex(node_args)
if node_args[i] isa Symbol
node_args[i] = Var(node_args[i])
elseif node_args[i] isa Tuple && last(node_args[i]) == ()
node_args[i] = create_array_var(
first(node_args[i]), pass.array_sizes, env
)
else
node_args[i] = Var(first(node_args[i]), last(node_args[i]))
end
end

rhs_expr = MacroTools.postwalk(rhs_expr) do sub_expr
if @capture(sub_expr, arr_[idxs__])
if Meta.isexpr(sub_expr, :ref)
arr, idxs... = sub_expr.args
new_idxs = [
idx isa Integer ? idx : :(JuliaBUGS.try_cast_to_int($(idx))) for
idx in idxs
Expand Down
Loading
Loading