Skip to content

Commit

Permalink
add a BayesNet implemented without MetaGraphsNext
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 3, 2024
1 parent b466e8e commit 7d07e62
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,6 @@ Only defined with `MCMCChains` extension.
"""
function gen_chains end

include("experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module ProbabilisticGraphicalModels

using BangBang
using Graphs
using Distributions

include("bayesnet.jl")

end
194 changes: 194 additions & 0 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
BayesianNetwork
A structure representing a Bayesian Network.
"""
struct BayesianNetwork{V,T,F}
graph::SimpleGraph{T}
"names of the variables in the network"
names::Vector{V}
"mapping from variable names to ids"
names_to_ids::Dict{V,T}
"values of each variable in the network"
values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future
"distributions of the stochastic variables"
distributions::Vector{Distribution}
"deterministic functions of the deterministic variables"
deterministic_functions::Vector{F}
"ids of the stochastic variables"
stochastic_ids::Vector{T}
"ids of the deterministic variables"
deterministic_ids::Vector{T}
is_stochastic::BitVector
is_observed::BitVector
end

function BayesianNetwork{V}() where {V}
return BayesianNetwork(
SimpleGraph{Int}(), # by default, vertex ids are integers
V[],
Dict{V,Int}(),
Dict{V,Any}(),
Distribution[],
Any[],
Int[],
Int[],
BitVector(),
BitVector(),
)
end

"""
condition(bn::BayesianNetwork{V}, values::Dict{V,Any}) where {V}
Condition the Bayesian Network on the values of some variables. Return a new Bayesian Network with the conditioned graph.
"""
function condition(
bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any}
) where {V}
is_observed = copy(bn.is_observed)
values = copy(bn.values)
bn_new = BangBang.setproperties!!(bn; is_observed=is_observed, values=values)
return condition!(bn_new, conditioning_variables_and_values)
end

"""
condition!(bn::BayesianNetwork{V}, values::Dict{V,Any}) where {V}
Condition the Bayesian Network on the values of some variables. Mutating version of [`condition`](@ref).
"""
function condition!(
bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any}
) where {V}
for (name, value) in conditioning_variables_and_values
id = bn.names_to_ids[name]
if !bn.is_stochastic[id]
throw(ArgumentError("Variable $name is not stochastic, cannot condition on it"))
elseif bn.is_observed[id]
@warn "Variable $name is already observed, overwriting its value"
else
bn.is_observed[id] = true
end
bn.values[name] = value
end
return bn
end

function decondition(bn::BayesianNetwork{V}) where {V}
conditioned_variables_ids = findall(bn.is_observed)
return decondition(bn, bn.names[conditioned_variables_ids])
end

function decondition!(bn::BayesianNetwork{V}) where {V}
conditioned_variables_ids = findall(bn.is_observed)
return decondition!(bn, bn.names[conditioned_variables_ids])
end

function decondition(bn::BayesianNetwork{V}, variables::Vector{V}) where {V}
is_observed = copy(bn.is_observed)
values = copy(bn.values)
bn_new = BangBang.setproperties!!(bn; is_observed=is_observed, values=values)
return decondition!(bn_new, variables)
end

function decondition!(bn::BayesianNetwork{V}, deconditioning_variables::Vector{V}) where {V}
for name in deconditioning_variables
id = bn.names_to_ids[name]
if !bn.is_stochastic[id]
throw(
ArgumentError("Variable $name is not stochastic, cannot decondition on it")
)
elseif !bn.is_observed[id]
throw(ArgumentError("Variable $name is not observed, cannot decondition on it"))
end
bn.is_observed[id] = false
delete!(bn.values, name)
end
return bn
end

"""
add_stochastic_vertex!(bn::BayesianNetwork{V}, name::V, dist::Distribution, is_observed::Bool) where {V}
Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesian Network. Returns the id of the added vertex
if successful, 0 otherwise.
"""
function add_stochastic_vertex!(
bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool
)::T where {V,T}
Graphs.add_vertex!(bn.graph) || return 0
id = nv(bn.graph)
push!(bn.distributions, dist)
push!(bn.is_stochastic, true)
push!(bn.is_observed, is_observed)
push!(bn.names, name)
bn.names_to_ids[name] = id
push!(bn.stochastic_ids, id)
return id
end

"""
add_deterministic_vertex!(bn::BayesianNetwork{V}, name::V, f::F) where {V,F}
Adds a deterministic vertex with name `name` and deterministic function `f` to the Bayesian Network. Returns the id of the added vertex
if successful, 0 otherwise.
"""
function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T where {T,V,F}
Graphs.add_vertex!(bn.graph) || return 0
id = nv(bn.graph)
push!(bn.deterministic_functions, f)
push!(bn.is_stochastic, false)
push!(bn.is_observed, false)
push!(bn.names, name)
bn.names_to_ids[name] = id
push!(bn.deterministic_ids, id)
return id
end

"""
add_edge!(bn::BayesianNetwork{V}, from::V, to::V) where {V}
Adds an edge between two vertices in the Bayesian Network. Returns true if successful, false otherwise.
"""
function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V}
from_id = bn.names_to_ids[from]
to_id = bn.names_to_ids[to]
return Graphs.add_edge!(bn.graph, from_id, to_id)
end

"""
ancestral_sampling(bn::BayesianNetwork{V}) where {V}
Perform ancestral sampling on a Bayesian network to generate one sample from the joint distribution.
Ancestral sampling works by:
1. Finding a topological ordering of the nodes
2. Sampling from each node in order, using the already-sampled parent values for conditional distributions
"""
function ancestral_sampling(bn::BayesianNetwork{V}) where {V}
ordered_vertices = Graphs.topological_sort(bn.graph)

samples = Dict{V,Any}()

# TODO: Implement sampling logic

return samples
end

"""
is_conditionally_independent(bn::BayesianNetwork, X::V, Y::V[, Z::Vector{V}]) where {V}
Determines if two variables X and Y are conditionally independent given the conditioning information already known.
If Z is provided, the conditioning information in `bn` will be ignored.
"""
function is_conditionally_independent end

function is_conditionally_independent(bn::BayesianNetwork{V}, X::V, Y::V) where {V}
# TODO: Implement
end

function is_conditionally_independent(
bn::BayesianNetwork{V}, X::V, Y::V, Z::Vector{V}
) where {V}
# TODO: Implement
end
102 changes: 102 additions & 0 deletions test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using Test
using Distributions
using Graphs
using JuliaBUGS.ProbabilisticGraphicalModels:
BayesianNetwork, add_stochastic_vertex!, add_deterministic_vertex!, add_edge!, condition, decondition

@testset "BayesianNetwork" begin
@testset "Adding vertices" begin
bn = BayesianNetwork{Symbol}()

# Test adding stochastic vertex
id1 = add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
@test id1 == 1
@test length(bn.names) == 1
@test bn.names[1] == :A
@test bn.names_to_ids[:A] == 1
@test bn.is_stochastic[1] == true
@test bn.is_observed[1] == false
@test length(bn.stochastic_ids) == 1

# Test adding deterministic vertex
f(x) = x^2
id2 = add_deterministic_vertex!(bn, :B, f)
@test id2 == 2
@test length(bn.names) == 2
@test bn.names[2] == :B
@test bn.names_to_ids[:B] == 2
@test bn.is_stochastic[2] == false
@test bn.is_observed[2] == false
@test length(bn.deterministic_ids) == 1
end

@testset "Adding edges" begin
bn = BayesianNetwork{Symbol}()
add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
add_stochastic_vertex!(bn, :B, Normal(0, 1), false)

add_edge!(bn, :A, :B)
@test has_edge(bn.graph, 1, 2)
end

@testset "conditioning and deconditioning" begin
bn = BayesianNetwork{Symbol}()

# Add some vertices
add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
add_stochastic_vertex!(bn, :B, Normal(0, 1), false)
add_stochastic_vertex!(bn, :C, Normal(0, 1), false)

# Test conditioning
bn_cond = condition(bn, Dict(:A => 1.0))
@test bn_cond.is_observed[1] == true
@test bn_cond.values[:A] == 1.0
@test bn_cond.is_observed[2] == false
@test bn_cond.is_observed[3] == false

# Ensure original bn is not mutated
@test bn.is_observed[1] == false
@test !haskey(bn.values, :A)

# Test conditioning multiple variables
bn_cond2 = condition(bn_cond, Dict(:B => 2.0))
@test bn_cond2.is_observed[1] == true
@test bn_cond2.is_observed[2] == true
@test bn_cond2.values[:A] == 1.0
@test bn_cond2.values[:B] == 2.0

# Ensure bn_cond is not mutated
@test bn_cond.is_observed[2] == false
@test !haskey(bn_cond.values, :B)

# Test deconditioning
bn_decond = decondition(bn_cond2, [:A])
@test bn_decond.is_observed[1] == false
@test bn_decond.is_observed[2] == true
@test !haskey(bn_decond.values, :A)
@test bn_decond.values[:B] == 2.0

# Ensure bn_cond2 is not mutated
@test bn_cond2.is_observed[1] == true
@test bn_cond2.values[:A] == 1.0

# Test deconditioning all
bn_decond_all = decondition(bn_cond2)
@test all(.!bn_decond_all.is_observed)
@test all(values(bn_decond_all.values) .=== nothing)

# Ensure bn_cond2 is still not mutated
@test bn_cond2.is_observed[1] == true
@test bn_cond2.is_observed[2] == true
@test bn_cond2.values[:A] == 1.0
@test bn_cond2.values[:B] == 2.0
end

@testset "Simple ancestral sampling" begin

end

@testset "Bayes Ball" begin

end
end
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using ReverseDiff

AbstractMCMC.setprogress!(false)

const Tests = ("elementary", "compilation", "log_density", "gibbs", "mcmchains", "all")
const Tests = ("elementary", "compilation", "log_density", "gibbs", "mcmchains", "experimental", "all")

const test_group = get(ENV, "TEST_GROUP", "all")
if test_group Tests
Expand Down Expand Up @@ -67,3 +67,7 @@ end
if test_group == "mcmchains" || test_group == "all"
include("ext/mcmchains.jl")
end

if test_group == "experimental" || test_group == "all"
include("experimental/ProbabilisticGraphicalModels/bayesnet.jl")
end

0 comments on commit 7d07e62

Please sign in to comment.