Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support trans-dimensional bijectors, flatten MarkovBlanketCoveredModel #88

Merged
merged 18 commits into from
Sep 13, 2023
Merged
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -44,8 +45,8 @@ AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
Distributions = "0.23.8, 0.24, 0.25"
DynamicPPL = "0.22, 0.23"
Documenter = "0.27"
DynamicPPL = "0.22, 0.23"
Graphs = "1.4.1"
InverseFunctions = "0.1"
JuliaSyntax = "0.4"
Expand All @@ -68,4 +69,4 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AbstractMCMC", "AdvancedHMC", "MCMCChains", "LogDensityProblemsAD", "ReverseDiff", "Test"]
test = ["AbstractMCMC", "AdvancedHMC", "MCMCChains", "LogDensityProblemsAD", "ReverseDiff", "Test"]
1 change: 1 addition & 0 deletions src/BUGSPrimitives/BUGSPrimitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module BUGSPrimitives
using Distributions
using LinearAlgebra
using LogExpFunctions
using PDMats
using Random
using SpecialFunctions
using Statistics
Expand Down
16 changes: 12 additions & 4 deletions src/BUGSPrimitives/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ end
"""
dmt(μ::Vector, T::Matrix, k)

Return a [Multivariate T](https://juliastats.org/Distributions.jl/latest/multivariate/#Distributions.MvTDist)
Return a [Multivariate T](https://juliastats.org/Distributions.jl/latest/matrix/#Distributions.MatrixTDist)
distribution object with mean vector `μ`, precision matrix `T`, and `k` degrees of freedom.

The mathematical form of the PDF for a Multivariate T distribution in the BUGS family of softwares is given by:
Expand All @@ -418,17 +418,25 @@ end
"""
dwish(R::Matrix, k)

Return a [Wishart](https://juliastats.org/Distributions.jl/latest/multivariate/#Distributions.Wishart)
Return a [Wishart](https://juliastats.org/Distributions.jl/latest/matrix/#Distributions.Wishart)
distribution object with `k` degrees of freedom and scale matrix `R^(-1)`.

The mathematical form of the PDF for a Wishart distribution in the BUGS family of softwares is given by:

```math
p(X|R,k) = |X|^{(k-p-1)/2} e^{-1/2 tr(RX)} / (2^{kp/2} |R|^{k/2} Γ_p(k/2))
p(X|R,k) = |X|^{(k-p-1)/2} e^{-(1/2) tr(RX)} / (2^{kp/2} |R|^{k/2} Γ_p(k/2))
```
where `p` is the dimension of `X`, and `p` should be less or equal to `k`.

This is the definition as in `The BUGS Book` (Lunn, D. J., Jackson, C., Best, N.,
Thomas, A., & Spiegelhalter, D. (2013). The BUGS Book: A Practical Introduction to Bayesian
Analysis. CRC Press.), which is different from OpenBUGS' definition of the pdf of the Wishart distribution:
```math
|R|^{k/2} |x|^{(k-p-1)/2} \\exp\\left(-\\frac{1}{2} \\text{Tr}(Rx)\\right)
```
"""
function dwish(R::Matrix, k)
return Wishart(k, inv(R))
return Wishart(k, PDMat(inv(R), cholesky(R)))
end

"""
Expand Down
4 changes: 2 additions & 2 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using AbstractPPL
using BangBang
using Bijectors
using Distributions
using DynamicPPL
using Graphs
using LogDensityProblems, LogDensityProblemsAD
using MacroTools
Expand All @@ -13,10 +12,11 @@ using Random
using Setfield
using UnPack

using DynamicPPL: DynamicPPL, SimpleVarInfo

import Base: ==, hash, Symbol, size
import Distributions: truncated
import AbstractPPL: AbstractContext, evaluate!!
import DynamicPPL: settrans!!

export @bugs
export compile
Expand Down
34 changes: 28 additions & 6 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,13 +505,18 @@ julia> evaluate_and_track_dependencies(:(getindex(x[1:2, 1:3], a, b)), Dict(:x =

julia> evaluate_and_track_dependencies(:(getindex(x[1:2, 1:3], a, b)), Dict(:x => [1 2 missing; 4 5 6]))
(:(getindex(Union{Missing, Int64}[1 2 missing; 4 5 6], a, b)), Set(Any[:a, :b, (:x, (1, 3))]), Set(Any[:a, :b, (:x, ())]))

julia> evaluate_and_track_dependencies(:x, Dict(:x => [1 2])) # array variables must be explicitly indexed
ERROR: AssertionError: Array indexing in BUGS must be explicit. However, `x` is accessed as a scalar.
[...]
```
"""
evaluate_and_track_dependencies(var::Number, env) = var, Set(), Set()
evaluate_and_track_dependencies(var::UnitRange, env) = var, Set(), Set()
function evaluate_and_track_dependencies(var::Symbol, env)
value = haskey(env, var) ? env[var] : var
@assert !ismissing(value) "Scalar variables in data can't be missing, but $var given as missing"
@assert value isa Union{Real,Symbol} "Array indexing in BUGS must be explicit. However, `$var` is accessed as a scalar."
return value, Set(), Set()
end
function evaluate_and_track_dependencies(var::Expr, env)
Expand Down Expand Up @@ -622,17 +627,34 @@ function replace_constants_in_expr(x, env)
end

_replace_constants_in_expr(x::Number, env) = x
_replace_constants_in_expr(x::Symbol, env) = get(env, x, x)
function _replace_constants_in_expr(x, env)
function _replace_constants_in_expr(x::Symbol, env)
if haskey(env, x)
if env[x] isa Number # only plug in scalar variables
return env[x]
else # if it's an array, raise error because array indexing should be explicit
error("$x")
end
end
return x
end
function _replace_constants_in_expr(x::Expr, env)
if Meta.isexpr(x, :ref) && all(x -> x isa Number, x.args[2:end])
if haskey(env, x.args[1])
val = env[x.args[1]][try_cast_to_int.(x.args[2:end])...]
x = ismissing(val) ? x : val
return ismissing(val) ? x : val
end
elseif !isa(x, Symbol) && !isa(x, Number)
x = deepcopy(x)
else # don't try to eval the function, but try to simplify
x = deepcopy(x) # because we are mutating the args
for i in 2:length(x.args)
x.args[i] = _replace_constants_in_expr(x.args[i], env)
try
x.args[i] = _replace_constants_in_expr(x.args[i], env)
catch e
rethrow(
ErrorException(
"Array indexing in BUGS must be explicit. However, `$(e.msg)` is accessed as a scalar.",
),
)
end
end
end
return x
Expand Down
6 changes: 3 additions & 3 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ end
"""
markov_blanket(g::BUGSModel, v)

Find the Markov blanket of `v` in `g`. `v` can be a single `VarName` or a vector of `VarName`.
Find the Markov blanket of variable(s) `v` in graph `g`. `v` can be a single `VarName` or a vector/tuple of `VarName`.
The Markov Blanket of a variable is the set of variables that shield the variable from the rest of the
network. Effectively, the Markov blanket of a variable is the set of its parents, its children, and
its children's other parents (reference: https://en.wikipedia.org/wiki/Markov_blanket).
Expand All @@ -213,7 +213,7 @@ In the case of vector, the Markov Blanket is the union of the Markov Blankets of
minus the variables themselves (reference: Liu, X.-Q., & Liu, X.-S. (2018). Markov Blanket and Markov
Boundary of Multiple Variables. Journal of Machine Learning Research, 19(43), 1–50.)
"""
function markov_blanket(g, v::VarName)
function markov_blanket(g::BUGSGraph, v::VarName)
parents = stochastic_inneighbors(g, v)
children = stochastic_outneighbors(g, v)
co_parents = VarName[]
Expand All @@ -224,7 +224,7 @@ function markov_blanket(g, v::VarName)
return [x for x in blanket if x != v]
end

function markov_blanket(g, v)
function markov_blanket(g::BUGSGraph, v)
blanket = VarName[]
for vn in v
blanket = vcat(blanket, markov_blanket(g, vn))
Expand Down
10 changes: 7 additions & 3 deletions src/logdensityproblems.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
function LogDensityProblems.logdensity(model::AbstractBUGSModel, x::AbstractArray)
vi = evaluate!!(model, LogDensityContext(), x)
return DynamicPPL.getlogp(vi)
vi, logp = evaluate!!(model, LogDensityContext(), x)
return logp
end

function LogDensityProblems.dimension(model::AbstractBUGSModel)
return model.param_length
return if model.if_transform
model.param_length[2]
else
model.param_length[1]
end
end

function LogDensityProblems.capabilities(::AbstractBUGSModel)
Expand Down
Loading
Loading