Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve evaluate function #148

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/BUGSPrimitives/BUGSPrimitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,49 @@ function _inv(m::AbstractMatrix)
return inv_L' * inv_L
end

const BUGS_FUNCTIONS = [
const BUGS_FUNCTIONS = (
:abs,
:cloglog,
:cexpexp,
:equals,
:exp,
:icloglog,
:ilogit,
:inprod,
:inverse,
:log,
:logdet,
:logfact,
:loggam,
:logit,
:logistic,
:mexp,
:max,
:mean,
:min,
:phi,
:probit,
:pow,
:sqrt,
:rank,
:ranked,
:round,
:sd,
:softplus,
:sort,
:_step,
:sum,
:trunc,
:sin,
:cos,
:tan,
:arcsin,
:arcsinh,
:arccos,
:arccosh,
:arctan,
:arctanh,
]
)

const BUGS_DISTRIBUTIONS = [
:dnorm,
Expand Down Expand Up @@ -162,4 +175,5 @@ export dnorm,
LeftTruncatedFlat,
RightTruncatedFlat,
TruncatedFlat

end
52 changes: 13 additions & 39 deletions src/BUGSPrimitives/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

Absolute value of `x`.
"""
function abs(x)
return Base.abs(x)
end
abs

"""
cloglog(x)
Expand Down Expand Up @@ -47,9 +45,7 @@ end

Exponential of ``x``.
"""
function exp(x)
return Base.Math.exp(x)
end
exp

"""
icloglog(x)
Expand Down Expand Up @@ -92,9 +88,7 @@ end

Natural logarithm of ``x``.
"""
function log(x)
return Base.Math.log(x)
end
log

"""
logdet(::AbstractMatrix)
Expand Down Expand Up @@ -163,9 +157,7 @@ end

Return the maximum value of the input arguments.
"""
function max(args...)
return Base.max(args...)
end
max

"""
mean(v::AbstractVector)
Expand All @@ -179,9 +171,7 @@ mean

Return the minimum value of the input arguments.
"""
function min(args...)
return Base.min(args...)
end
min

"""
phi(x)
Expand Down Expand Up @@ -215,9 +205,7 @@ end

Return the square root of ``x``.
"""
function sqrt(x)
return Base.Math.sqrt(x)
end
sqrt

"""
rank(v::AbstractVector, i::Integer)
Expand All @@ -242,9 +230,7 @@ end

Round ``x`` to the nearest Integereger.
"""
function round(x)
return Base.Math.round(x)
end
round

"""
sd(v::AbstractVector)
Expand All @@ -269,9 +255,7 @@ end

Return a sorted copy of the input vector `v`.
"""
function sort(v::AbstractVector)
return Base.sort(v)
end
sort

"""
_step(x)
Expand All @@ -287,27 +271,21 @@ end

Return the sum of the input arguments.
"""
function sum(args...)
return Base.sum(args...)
end
sum

"""
trunc(x)

Return the Integereger part of ``x``.
"""
function trunc(x)
return Base.Math.trunc(x)
end
trunc

"""
sin(x)

Return the sine of ``x``.
"""
function sin(x)
return Base.Math.sin(x)
end
sin

"""
arcsin(x)
Expand All @@ -332,9 +310,7 @@ end

Return the cosine of ``x``.
"""
function cos(x)
return Base.Math.cos(x)
end
cos

"""
arccos(x)
Expand All @@ -359,9 +335,7 @@ end

Return the tangent of ``x``.
"""
function tan(x)
return Base.Math.tan(x)
end
tan

"""
arctan(x)
Expand Down
113 changes: 62 additions & 51 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@

# Examples
```jldoctest
julia> find_variables_on_lhs(:(x[1, 2]), Dict())
julia> find_variables_on_lhs(:(x[1, 2]), NamedTuple())
x[1, 2]

julia> find_variables_on_lhs(:(x[1, 2:3]), Dict())
julia> find_variables_on_lhs(:(x[1, 2:3]), NamedTuple())
x[1, 2:3]
```
"""
Expand All @@ -142,85 +142,96 @@
end

"""
evaluate(var, env)
evaluate(expr, env)

Evaluate `var` in the environment `env`.
Evaluate `expr` in the environment `env`.

# Examples
```jldoctest
julia> evaluate(:(x[1]), Dict(:x => [1, 2, 3])) # array indexing is evaluated if possible
julia> evaluate(:(x[1]), (x = [1, 2, 3],)) # array indexing is evaluated if possible
1

julia> evaluate(:(x[1] + 1), Dict(:x => [1, 2, 3]))
julia> evaluate(:(x[1] + 1), (x = [1, 2, 3],))
2

julia> evaluate(:(x[1:2]), Dict()) |> Meta.show_sexpr # ranges are evaluated
julia> evaluate(:(x[1:2]), NamedTuple()) |> Meta.show_sexpr # ranges are evaluated
(:ref, :x, 1:2)

julia> evaluate(:(x[1:2]), Dict(:x => [1, 2, 3])) # ranges are evaluated
julia> evaluate(:(x[1:2]), (x = [1, 2, 3],))
2-element Vector{Int64}:
1
2

julia> evaluate(:(x[1:3]), Dict(:x => [1, 2, missing])) # when evaluate an array, if any element is missing, original expr is returned
julia> evaluate(:(x[1:3]), (x = [1, 2, missing],)) # when evaluate an array, if any element is missing, original expr is returned
:(x[1:3])

julia> evaluate(:(x[y[1] + 1] + 1), Dict()) # if a ref expr can't be evaluated, it's returned as is
julia> evaluate(:(x[y[1] + 1] + 1), NamedTuple()) # if a ref expr can't be evaluated, it's returned as is
:(x[y[1] + 1] + 1)

julia> evaluate(:(sum(x[:])), Dict(:x => [1, 2, 3])) # function calls are evaluated if possible
julia> evaluate(:(sum(x[:])), (x = [1, 2, 3],)) # function calls are evaluated if possible
6

julia> evaluate(:(f(1)), Dict()) # if a function call can't be evaluated, it's returned as is
julia> evaluate(:(f(1)), NamedTuple()) # if a function call can't be evaluated, it's returned as is
:(f(1))
"""
evaluate(var::Number, env) = var
evaluate(var::UnitRange, env) = var
evaluate(::Colon, env) = Colon()
function evaluate(var::Symbol, env)
var == :(:) && return Colon()
if haskey(env, var)
value = env[var]
if value === missing
return var
evaluate(expr::Number, env) = expr
evaluate(expr::UnitRange, env) = expr
evaluate(expr::Colon, env) = expr

Check warning on line 179 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L178-L179

Added lines #L178 - L179 were not covered by tests
function evaluate(expr::Symbol, env::NamedTuple{variable_names}) where {variable_names}
if expr == :(:)
return Colon()
else
if expr in variable_names
return env[expr] === missing ? expr : env[expr]
else
return value
return expr
end
else
return var
end
end
function evaluate(var::Expr, env)
if Meta.isexpr(var, :ref)
idxs = (ex -> evaluate(ex, env)).(var.args[2:end])
!isa(idxs, Array) && (idxs = [idxs])
if all(x -> x isa Number, idxs) && haskey(env, var.args[1])
for i in eachindex(idxs)
if !isa(idxs[i], Integer) && !isinteger(idxs[i])
error("Array indices must be integers or UnitRanges.")
function evaluate(expr::Expr, env::NamedTuple{variable_names}) where {variable_names}
if Meta.isexpr(expr, :ref)
var, indices... = expr.args
all_resolved = true
for i in eachindex(indices)
indices[i] = evaluate(indices[i], env)
if indices[i] isa Float64
indices[i] = Int(indices[i])
end
all_resolved = all_resolved && indices[i] isa Union{Int,UnitRange{Int},Colon}
end
if var in variable_names
if all_resolved
value = env[var][indices...]
if is_resolved(value)
return value
else
return Expr(:ref, var, indices...)
end
end
value = env[var.args[1]][Int.(idxs)...]
return ismissing(value) ? Expr(var.head, var.args[1], idxs...) : value
elseif all(x -> x isa Union{Number,UnitRange,Colon,Array}, idxs) &&
haskey(env, var.args[1])
value = getindex(env[var.args[1]], idxs...) # can use `view` here
!any(ismissing, value) && return value
end
return Expr(var.head, var.args[1], idxs...)
elseif var.args[1] ∈ BUGSPrimitives.BUGS_FUNCTIONS ||
var.args[1] ∈ (:+, :-, :*, :/, :^, :(:)) # function call
# elseif isdefined(JuliaBUGS, var.args[1])
f = var.args[1]
args = map(ex -> evaluate(ex, env), var.args[2:end])
if all(is_resolved, args)
return getfield(JuliaBUGS, f)(args...)
else
return Expr(var.head, f, args...)
return Expr(:ref, var, indices...)
end
else # don't try to eval the function, but try to simplify
args = map(ex -> evaluate(ex, env), var.args[2:end])
return Expr(var.head, var.args[1], args...)
elseif Meta.isexpr(expr, :call)
f, args... = expr.args
all_resolved = true
for i in eachindex(args)
args[i] = evaluate(args[i], env)
all_resolved = all_resolved && is_resolved(args[i])
end
if all_resolved
if f === :(:)
return UnitRange(Int(args[1]), Int(args[2]))
elseif f ∈ BUGSPrimitives.BUGS_FUNCTIONS ∪ (:+, :-, :*, :/, :^)
_f = getfield(BUGSPrimitives, f)
return _f(args...)
else
return Expr(:call, f, args...)
end
else
return Expr(:call, f, args...)
end
else
error("Unsupported expression: $var")

Check warning on line 234 in src/compiler_pass.jl

View check run for this annotation

Codecov / codecov/patch

src/compiler_pass.jl#L234

Added line #L234 was not covered by tests
end
end

Expand Down
Loading