Skip to content

Commit

Permalink
Cleanup some unused code; add docstrings for some functions and struc…
Browse files Browse the repository at this point in the history
…ts (#164)
  • Loading branch information
sunxd3 authored Mar 20, 2024
1 parent 85c4490 commit 4f4772a
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 77 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.5.0"
version = "0.5.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
66 changes: 41 additions & 25 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,42 @@ function evaluate_and_track_dependencies(var::Expr, env)
end
end

"""
AddVertices
This pass will add a vertex for every instance of LHS in the model.
The node functions are the same for all the nodes whose corresponding LHS are originated from the same statement.
The values of loop variables at the time LHS is evaluated will be saved.
`vertex_id_tracker` tracks the vertex ID of each variable in the model. This is used to efficiently decide target
vertices in pass `AddEdges`.
"""
mutable struct AddVertices <: CompilerPass
const env::NamedTuple
const g::MetaGraph
vertex_id_tracker::NamedTuple
const f_dict::Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}}
end

function AddVertices(model_def::Expr, eval_env::NamedTuple)
g = MetaGraph(DiGraph(); label_type=VarName, vertex_data_type=NodeInfo)
vertex_id_tracker = Dict{Symbol,Any}()
for (k, v) in pairs(eval_env)
if v isa AbstractArray
vertex_id_tracker[k] = zeros(Int, size(v))
else
vertex_id_tracker[k] = 0
end
end

f_dict = build_node_functions(
model_def, eval_env, Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}}(), ()
)

return AddVertices(eval_env, g, NamedTuple(vertex_id_tracker), f_dict)
end

function build_node_functions(
expr::Expr,
eval_env::NamedTuple,
Expand Down Expand Up @@ -671,31 +707,6 @@ function make_function_expr(expr, env::NamedTuple{vars}) where {vars}
end
end

mutable struct AddVertices <: CompilerPass
const env::NamedTuple
const g::MetaGraph
vertex_id_tracker::NamedTuple
const f_dict::Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}}
end

function AddVertices(model_def::Expr, eval_env::NamedTuple)
g = MetaGraph(DiGraph(); label_type=VarName, vertex_data_type=NodeInfo)
vertex_id_tracker = Dict{Symbol,Any}()
for (k, v) in pairs(eval_env)
if v isa AbstractArray
vertex_id_tracker[k] = zeros(Int, size(v))
else
vertex_id_tracker[k] = 0
end
end

f_dict = build_node_functions(
model_def, eval_env, Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}}(), ()
)

return AddVertices(eval_env, g, NamedTuple(vertex_id_tracker), f_dict)
end

function analyze_statement(pass::AddVertices, expr::Expr, loop_vars::NamedTuple)
lhs_expr = is_deterministic(expr) ? expr.args[1] : expr.args[2]
env = merge(pass.env, loop_vars)
Expand Down Expand Up @@ -750,6 +761,11 @@ function analyze_statement(pass::AddVertices, expr::Expr, loop_vars::NamedTuple)
end
end

"""
AddEdges
This pass will add edges to the graph constructed in pass `AddVertices`.
"""
struct AddEdges <: CompilerPass
env::NamedTuple
g::MetaGraph
Expand Down
76 changes: 25 additions & 51 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""
create_eval_env(non_data_scalars, non_data_array_sizes, data)
Constructs an `NamedTuple` containing all the variables defined or used in the program.
Arrays given by data will only be copied if they contain `missing` values. This copy behavior ensures
that the evaluation environment is a self-contained snapshot, avoiding unintended side effects on the input data.
Variables not given by data will be assigned `missing` values.
"""
function create_eval_env(
non_data_scalars::Tuple{Vararg{Symbol}},
non_data_array_sizes::NamedTuple{non_data_array_vars},
Expand Down Expand Up @@ -30,6 +40,20 @@ function create_eval_env(
return NamedTuple(eval_env)
end

"""
concretize_eval_env(eval_env::NamedTuple)
For arrays in `eval_env`, if its `eltype` is `Union{Missing, T}` where `T` is a concrete type, then
it tries to convert the array to `AbstractArray{T}`. If the conversion is not possible, it leaves
the array unchanged.
# Examples
```jldoctest; setup = :(using JuliaBUGS: concretize_eval_env)
julia> concretize_eval_env((a = Union{Missing,Int}[1, 2, 3],))
(a = [1, 2, 3],)
```
"""
function concretize_eval_env(eval_env::NamedTuple)
for k in keys(eval_env)
v = eval_env[k]
Expand Down Expand Up @@ -579,21 +603,7 @@ function bugs_eval(expr, env, dist_store)
return error("Unknown expression type: $expr of type $(typeof(expr))")
end

"""
evaluate(vn::VarName, env)
Retrieve the value of a possible variable identified by `vn` from `env`, return `nothing` if not found.
"""
function evaluate(vn::VarName, env)
sym = getsym(vn)
ret = nothing
try
ret = get(env[sym], getlens(vn))
catch _
end
return ismissing(ret) ? nothing : ret
end

# TODO: can't remove even with the `possible` fix in DynamicPPL, still seems to have eltype inference issue causing AD errors
# Resolves: setindex!!([1 2; 3 4], [2 3; 4 5], 1:2, 1:2) # returns 2×2 Matrix{Any}
# Alternatively, can overload BangBang.possible(
# ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg
Expand All @@ -608,39 +618,3 @@ function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...)
ys[I...] = v
return ys
end

# defines some default bijectors for link functions
# these are currently not in use, because we transform the expression by calling inverse functions
# on the RHS (in the case of logical assignment) or disallow the use of link functions (in the case of
# stochastic assignments)

struct LogisticBijector <: Bijectors.Bijector end

Bijectors.transform(::LogisticBijector, x::Real) = logistic(x)
Bijectors.transform(::Inverse{LogisticBijector}, x::Real) = logit(x)
Bijectors.logabsdet(::LogisticBijector, x::Real) = log(logistic(x)) + log(1 - logistic(x))

struct CExpExpBijector <: Bijectors.Bijector end

Bijectors.transform(::CExpExpBijector, x::Real) = icloglog(x)
Bijectors.transform(::Inverse{CExpExpBijector}, x::Real) = cloglog(x)
Bijectors.logabsdet(::CExpExpBijector, x::Real) = -log(cloglog(-x))

struct ExpBijector <: Bijectors.Bijector end

Bijectors.transform(::ExpBijector, x::Real) = exp(x)
Bijectors.transform(::Inverse{ExpBijector}, x::Real) = log(x)
Bijectors.logabsdet(::ExpBijector, x::Real) = x

struct PhiBijector <: Bijectors.Bijector end

Bijectors.transform(::PhiBijector, x::Real) = phi(x)
Bijectors.transform(::Inverse{PhiBijector}, x::Real) = probit(x)
Bijectors.logabsdet(::PhiBijector, x::Real) = -0.5 * (x^2 + log(2π))

link_function_to_bijector_mapping = Dict(
:logit => LogisticBijector(),
:cloglog => CExpExpBijector(),
:log => ExpBijector(),
:probit => PhiBijector(),
)

0 comments on commit 4f4772a

Please sign in to comment.