Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cumulative and density function #153

Merged
merged 29 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
15d0758
remove the `cumulative` etc handling in macro stage
sunxd3 Mar 10, 2024
9106c29
comment improvement
sunxd3 Mar 10, 2024
460d430
add `distributions` field to `BUGSModel`
sunxd3 Mar 10, 2024
8ebc608
add handling code for cumulative/density
sunxd3 Mar 10, 2024
85cadfe
fix features and add tests
sunxd3 Mar 10, 2024
d80c3ca
add warn message
sunxd3 Mar 10, 2024
4d0d686
use Dict instead of namedtuple for `distributions`
sunxd3 Mar 10, 2024
5b0a29a
Update src/model.jl
sunxd3 Mar 10, 2024
80fe929
fix typo
sunxd3 Mar 10, 2024
c683ac1
update tests
sunxd3 Mar 10, 2024
6ae5f0c
fix typo
sunxd3 Mar 10, 2024
ff24143
update condition function
sunxd3 Mar 10, 2024
a5fccf1
remove `info` call
sunxd3 Mar 11, 2024
da660e4
fix bug
sunxd3 Mar 11, 2024
d371c21
recover old imple
sunxd3 Mar 11, 2024
c2410e7
some minor improvement
sunxd3 Mar 12, 2024
62dff76
Apply suggestions from code review
sunxd3 Mar 12, 2024
4bf4627
minor improvement on program logic
sunxd3 Mar 12, 2024
4d401a3
add error handling
sunxd3 Mar 12, 2024
c156c53
put back some changes
sunxd3 Mar 12, 2024
afb5744
fix error
sunxd3 Mar 12, 2024
0ec7ebb
format
sunxd3 Mar 12, 2024
4adc8a2
use NamedTuple in place of Dict in some tests
sunxd3 Mar 12, 2024
31bd17c
fix error
sunxd3 Mar 12, 2024
93f5c1c
remove some changes
sunxd3 Mar 12, 2024
d95377f
remove unnecessary changes
sunxd3 Mar 12, 2024
a672a7a
formatting
sunxd3 Mar 12, 2024
876672b
more formatting
sunxd3 Mar 12, 2024
7a419e5
Merge branch 'master' into sunxd/fix_cumulative
sunxd3 Mar 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 101 additions & 36 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,27 +631,65 @@
push!(args, (var.args[1], ()))
push!(deps, (var.args[1], ()))
return Expr(var.head, var.args[1], idxs...), deps, args
else # function call
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
push!(fun_args, e)
elseif Meta.isexpr(var, :call)
if var.args[1] === :cumulative || var.args[1] === :density
arg1, arg2 = var.args[2:3]
arg1 = if arg1 isa Symbol
push!(deps, arg1)
# no need to add to arg, as the value doesn't matter
arg1
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
e, d, a = evaluate_and_track_dependencies(indices[i], env)
union!(deps, d)
union!(args, a)
indices[i] = e
end
if any(!is_resolved, indices)
error(

Check warning on line 650 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L650

Added line #L650 was not covered by tests
"For now, the indices of the first argument to `cumulative` and `density` must be resolved, got $indices",
)
end
push!(deps, (v, Tuple(indices)))
# no need to add to arg, as the value doesn't matter
Expr(:ref, v, indices...)
else
error(
"First argument to `cumulative` and `density` must be variable, got $(arg1)",
)
end
e, d, a = evaluate_and_track_dependencies(arg2, env)
union!(deps, d)
union!(args, a)
end
return Expr(:call, var.args[1], arg1, e), deps, args
else
fun_args = []
for i in 2:length(var.args)
e, d, a = evaluate_and_track_dependencies(var.args[i], env)
push!(fun_args, e)
union!(deps, d)
union!(args, a)
end

for a in fun_args
a isa Symbol && a != :nothing && a != :(:) && (push!(deps, a); push!(args, a))
end
for a in fun_args
a isa Symbol &&
a != :nothing &&
a != :(:) &&
(push!(deps, a); push!(args, a))
end

if (
var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] ∈ (:+, :-, :*, :/, :^, :(:))
) && all(is_resolved, args)
return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args
else
return Expr(var.head, var.args[1], fun_args...), deps, args
if (
var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] ∈ (:+, :-, :*, :/, :^, :(:))
) && all(is_resolved, args)
return getfield(JuliaBUGS, var.args[1])(fun_args...), deps, args
else
return Expr(var.head, var.args[1], fun_args...), deps, args
end
end
else
error("Unexpected expression type: $var")

Check warning on line 692 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L692

Added line #L692 was not covered by tests
end
end

Expand Down Expand Up @@ -691,36 +729,63 @@

_replace_constants_in_expr(x::Number, env) = x
function _replace_constants_in_expr(x::Symbol, env)
if haskey(env, x)
if env[x] isa Number # only plug in scalar variables
return env[x]
else # if it's an array, raise error because array indexing should be explicit
error("$x")
end
if haskey(env, x) && env[x] isa Number
return env[x]
end
return x
end
function _replace_constants_in_expr(x::Expr, env)
if Meta.isexpr(x, :ref) && all(x -> x isa Number, x.args[2:end])
if haskey(env, x.args[1])
val = env[x.args[1]][try_cast_to_int.(x.args[2:end])...]
if Meta.isexpr(x, :ref)
v, indices... = x.args
if haskey(env, v) && all(x -> x isa Union{Int,Float64}, indices)
val = env[v][map(Int, indices)...]
return ismissing(val) ? x : val
else
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(:ref, v, indices...)
end
else # don't try to eval the function, but try to simplify
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
try
x.args[i] = _replace_constants_in_expr(x.args[i], env)
catch e
rethrow(
ErrorException(
"Array indexing in BUGS must be explicit. However, `$(e.msg)` is accessed as a scalar.",
),
elseif Meta.isexpr(x, :call)
if x.args[1] === :cumulative || x.args[1] === :density
if length(x.args) != 3
error(

Check warning on line 752 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L752

Added line #L752 was not covered by tests
"`cumulative` and `density` are special functions in BUGS and takes two arguments, got $(length(x.args) - 1)",
)
end
f, arg1, arg2 = x.args
if arg1 isa Symbol
return Expr(:call, f, arg1, _replace_constants_in_expr(arg2, env))
elseif Meta.isexpr(arg1, :ref)
v, indices... = arg1.args
for i in eachindex(indices)
indices[i] = _replace_constants_in_expr(indices[i], env)
end
return Expr(
:call,
f,
Expr(:ref, v, indices...),
_replace_constants_in_expr(arg2, env),
)
else
error(

Check warning on line 771 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L771

Added line #L771 was not covered by tests
"First argument to `cumulative` and `density` must be variable, got $(x.args[2])",
)
end
elseif x.args[1] === :deviance
@warn(

Check warning on line 776 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L776

Added line #L776 was not covered by tests
"`deviance` function is not supported in JuliaBUGS, `deviance` will be treated as a general function."
)
else
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
x.args[i] = _replace_constants_in_expr(x.args[i], env)
end
return x
end
else
error("Unexpected expression type: $x")

Check warning on line 787 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L787

Added line #L787 was not covered by tests
end
return x
end

"""
Expand Down
30 changes: 18 additions & 12 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct BUGSModel <: AbstractBUGSModel
transformed_var_lengths::Dict{VarName,Int} # TODO: store this as a delta from `untransformed_var_lengths`?

varinfo::SimpleVarInfo
distributions::Dict{VarName,Distribution}
parameters::Vector{VarName}
sorted_nodes::Vector{VarName}

Expand Down Expand Up @@ -81,6 +82,7 @@ function BUGSModel(
)
vs = initialize_var_store(data, vars, array_sizes)
vi = SimpleVarInfo(vs, 0.0)
dist_store = Dict{VarName,Distribution}()
parameters = VarName[]
untransformed_param_length = 0
transformed_param_length = 0
Expand All @@ -91,11 +93,11 @@ function BUGSModel(

ni = g[vn]
@unpack node_type, node_function_expr, node_args = ni
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
args = Dict(getsym(arg) => vi[arg] for arg in node_args) # TODO: get rid of this
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = try
_eval(expr, args)
_eval(expr, args, dist_store)
catch e
rethrow(
# UninitializedVariableError(
Expand All @@ -108,14 +110,15 @@ function BUGSModel(
vi = setindex!!(vi, value, vn)
else
dist = try
_eval(expr, args)
_eval(expr, args, dist_store)
catch _
rethrow(
UninitializedVariableError(
"Encounter support error when evaluating the distribution of $vn. Try to initialize variables $(join(collect(keys(args)), ", ")) first if not yet.",
),
)
end
dist_store[vn] = dist
value = evaluate(vn, data) # `evaluate(::VarName, env)` is defined in `src/utils.jl`
if value isa Nothing # not observed
push!(parameters, vn)
Expand Down Expand Up @@ -155,6 +158,7 @@ function BUGSModel(
untransformed_var_lengths,
transformed_var_lengths,
vi,
dist_store,
parameters,
sorted_nodes,
g,
Expand Down Expand Up @@ -212,7 +216,7 @@ function get_params_varinfo(m::BUGSModel, vi::SimpleVarInfo)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if vn in m.parameters
dist = _eval(expr, args)
dist = _eval(expr, args, m.distributions)
linked_val = DynamicPPL.link(dist, vi[vn])
d[vn] = linked_val
end
Expand Down Expand Up @@ -252,7 +256,7 @@ function getparams(m::BUGSModel, vi::SimpleVarInfo; transformed::Bool=false)
for v in m.parameters
ni = m.g[v]
args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...)
dist = _eval(ni.node_function_expr.args[2], args)
dist = _eval(ni.node_function_expr.args[2], args, m.distributions)

link_vals = Bijectors.link(dist, vi[v])
len = m.transformed_var_lengths[v]
Expand Down Expand Up @@ -288,7 +292,7 @@ function setparams!!(
for v in m.parameters
ni = m.g[v]
args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...)
dist = _eval(ni.node_function_expr.args[2], args)
dist = _eval(ni.node_function_expr.args[2], args, m.distributions)

len = if transformed
m.transformed_var_lengths[v]
Expand Down Expand Up @@ -353,6 +357,7 @@ function AbstractPPL.condition(
model.untransformed_var_lengths,
model.transformed_var_lengths,
varinfo,
model.distributions,
new_parameters,
sorted_blanket_with_vars,
model.g,
Expand All @@ -377,6 +382,7 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
model.untransformed_var_lengths,
model.transformed_var_lengths,
model.varinfo,
model.distributions,
new_parameters,
sorted_blanket_with_vars,
model.g,
Expand Down Expand Up @@ -438,10 +444,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
value = _eval(expr, args, model.distributions)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, model.distributions)
value = rand(ctx.rng, dist) # just sample from the prior
logp += logpdf(dist, value)
vi = setindex!!(vi, value, vn)
Expand All @@ -464,10 +470,10 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
args = Dict(getsym(arg) => vi[arg] for arg in node_args)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical # be conservative -- always propagate values of logical nodes
value = _eval(expr, args)
value = _eval(expr, args, model.distributions)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, model.distributions)
value = vi[vn]
if model.transformed
# although the values stored in `vi` are in their original space,
Expand Down Expand Up @@ -509,10 +515,10 @@ function AbstractPPL.evaluate!!(
args = (; map(arg -> getsym(arg) => vi[arg], node_args)...)
expr = node_function_expr.args[2]
if node_type == JuliaBUGS.Logical
value = _eval(expr, args)
value = _eval(expr, args, model.distributions)
vi = setindex!!(vi, value, vn)
else
dist = _eval(expr, args)
dist = _eval(expr, args, model.distributions)
if vn in model.parameters
l = var_lengths[vn]
if model.transformed
Expand Down
72 changes: 3 additions & 69 deletions src/parser/bugs_macro.jl
Original file line number Diff line number Diff line change
@@ -1,71 +1,5 @@
# handle `cumulative`, `density` and `deviance` functions
# these are incorrect implementations, as it can only handle the case where the first argument is exactly the same as the LHS of the stochastic assignment
# but can't handle cases like `cumulative(y[1], x)` where `y[i]` is defined in loops
# other than that, `density` and `deviance` also require that the variable in place of first argument is observed
# TODO: fix this
function cumulative(expr::Expr)
return MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ = cumulative(s1_, s2_))
dist = find_tilde_rhs(expr, s1)
sub_expr.args[2].args[1] = :cdf
sub_expr.args[2].args[2] = dist
return sub_expr
else
return sub_expr
end
end
end

function density(expr::Expr)
return MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ = density(s1_, s2_))
dist = find_tilde_rhs(expr, s1)
sub_expr.args[2].args[1] = :pdf
sub_expr.args[2].args[2] = dist
return sub_expr
else
return sub_expr
end
end
end

function deviance(expr::Expr)
return MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ = deviance(s1_, s2_))
dist = find_tilde_rhs(expr, s1)
sub_expr.args[2].args[1] = :logpdf
sub_expr.args[2].args[2] = dist
sub_expr.args[2] = Expr(:call, :*, -2, sub_expr.args[2])
return sub_expr
else
return sub_expr
end
end
end

function find_tilde_rhs(expr::Expr, target::Union{Expr,Symbol})
dist = nothing
MacroTools.postwalk(expr) do sub_expr
if @capture(sub_expr, lhs_ ~ rhs_)
if lhs == target
isnothing(dist) || error("Exist two assignments to the same variable.")
dist = rhs
end
end
return sub_expr
end
isnothing(dist) && error(
"Error handling cumulative expression: can't find a stochastic assignment for $target.",
)
return dist
end

function handle_special_functions(expr::Expr)
return cumulative(density(deviance(expr)))
end

macro bugs(expr)
return Meta.quot(handle_special_functions(bugs_top(expr, __source__)))
macro bugs(expr::Expr)
return Meta.quot(bugs_top(expr, __source__))
end

function bugs_top(@nospecialize(expr), __source__)
Expand Down Expand Up @@ -252,5 +186,5 @@ macro bugs(prog::String, replace_period=true, no_enclosure=false)
if !isempty(error_container) # otherwise errors thrown in macro will be LoadError
return :(throw(ErrorException(join($error_container, "\n"))))
end
return Meta.quot(handle_special_functions(expr))
return Meta.quot(expr)
end
Loading
Loading