-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add a BayesNet implemented without MetaGraphsNext
- Loading branch information
Showing
5 changed files
with
312 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module ProbabilisticGraphicalModels | ||
|
||
using BangBang | ||
using Graphs | ||
using Distributions | ||
|
||
include("bayesnet.jl") | ||
|
||
end |
194 changes: 194 additions & 0 deletions
194
src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
""" | ||
BayesianNetwork | ||
A structure representing a Bayesian Network. | ||
""" | ||
struct BayesianNetwork{V,T,F} | ||
graph::SimpleGraph{T} | ||
"names of the variables in the network" | ||
names::Vector{V} | ||
"mapping from variable names to ids" | ||
names_to_ids::Dict{V,T} | ||
"values of each variable in the network" | ||
values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future | ||
"distributions of the stochastic variables" | ||
distributions::Vector{Distribution} | ||
"deterministic functions of the deterministic variables" | ||
deterministic_functions::Vector{F} | ||
"ids of the stochastic variables" | ||
stochastic_ids::Vector{T} | ||
"ids of the deterministic variables" | ||
deterministic_ids::Vector{T} | ||
is_stochastic::BitVector | ||
is_observed::BitVector | ||
end | ||
|
||
function BayesianNetwork{V}() where {V} | ||
return BayesianNetwork( | ||
SimpleGraph{Int}(), # by default, vertex ids are integers | ||
V[], | ||
Dict{V,Int}(), | ||
Dict{V,Any}(), | ||
Distribution[], | ||
Any[], | ||
Int[], | ||
Int[], | ||
BitVector(), | ||
BitVector(), | ||
) | ||
end | ||
|
||
""" | ||
condition(bn::BayesianNetwork{V}, values::Dict{V,Any}) where {V} | ||
Condition the Bayesian Network on the values of some variables. Return a new Bayesian Network with the conditioned graph. | ||
""" | ||
function condition( | ||
bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any} | ||
) where {V} | ||
is_observed = copy(bn.is_observed) | ||
values = copy(bn.values) | ||
bn_new = BangBang.setproperties!!(bn; is_observed=is_observed, values=values) | ||
return condition!(bn_new, conditioning_variables_and_values) | ||
end | ||
|
||
""" | ||
condition!(bn::BayesianNetwork{V}, values::Dict{V,Any}) where {V} | ||
Condition the Bayesian Network on the values of some variables. Mutating version of [`condition`](@ref). | ||
""" | ||
function condition!( | ||
bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any} | ||
) where {V} | ||
for (name, value) in conditioning_variables_and_values | ||
id = bn.names_to_ids[name] | ||
if !bn.is_stochastic[id] | ||
throw(ArgumentError("Variable $name is not stochastic, cannot condition on it")) | ||
elseif bn.is_observed[id] | ||
@warn "Variable $name is already observed, overwriting its value" | ||
else | ||
bn.is_observed[id] = true | ||
end | ||
bn.values[name] = value | ||
end | ||
return bn | ||
end | ||
|
||
function decondition(bn::BayesianNetwork{V}) where {V} | ||
conditioned_variables_ids = findall(bn.is_observed) | ||
return decondition(bn, bn.names[conditioned_variables_ids]) | ||
end | ||
|
||
function decondition!(bn::BayesianNetwork{V}) where {V} | ||
conditioned_variables_ids = findall(bn.is_observed) | ||
return decondition!(bn, bn.names[conditioned_variables_ids]) | ||
end | ||
|
||
function decondition(bn::BayesianNetwork{V}, variables::Vector{V}) where {V} | ||
is_observed = copy(bn.is_observed) | ||
values = copy(bn.values) | ||
bn_new = BangBang.setproperties!!(bn; is_observed=is_observed, values=values) | ||
return decondition!(bn_new, variables) | ||
end | ||
|
||
function decondition!(bn::BayesianNetwork{V}, deconditioning_variables::Vector{V}) where {V} | ||
for name in deconditioning_variables | ||
id = bn.names_to_ids[name] | ||
if !bn.is_stochastic[id] | ||
throw( | ||
ArgumentError("Variable $name is not stochastic, cannot decondition on it") | ||
) | ||
elseif !bn.is_observed[id] | ||
throw(ArgumentError("Variable $name is not observed, cannot decondition on it")) | ||
end | ||
bn.is_observed[id] = false | ||
delete!(bn.values, name) | ||
end | ||
return bn | ||
end | ||
|
||
""" | ||
add_stochastic_vertex!(bn::BayesianNetwork{V}, name::V, dist::Distribution, is_observed::Bool) where {V} | ||
Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesian Network. Returns the id of the added vertex | ||
if successful, 0 otherwise. | ||
""" | ||
function add_stochastic_vertex!( | ||
bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool | ||
)::T where {V,T} | ||
Graphs.add_vertex!(bn.graph) || return 0 | ||
id = nv(bn.graph) | ||
push!(bn.distributions, dist) | ||
push!(bn.is_stochastic, true) | ||
push!(bn.is_observed, is_observed) | ||
push!(bn.names, name) | ||
bn.names_to_ids[name] = id | ||
push!(bn.stochastic_ids, id) | ||
return id | ||
end | ||
|
||
""" | ||
add_deterministic_vertex!(bn::BayesianNetwork{V}, name::V, f::F) where {V,F} | ||
Adds a deterministic vertex with name `name` and deterministic function `f` to the Bayesian Network. Returns the id of the added vertex | ||
if successful, 0 otherwise. | ||
""" | ||
function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T where {T,V,F} | ||
Graphs.add_vertex!(bn.graph) || return 0 | ||
id = nv(bn.graph) | ||
push!(bn.deterministic_functions, f) | ||
push!(bn.is_stochastic, false) | ||
push!(bn.is_observed, false) | ||
push!(bn.names, name) | ||
bn.names_to_ids[name] = id | ||
push!(bn.deterministic_ids, id) | ||
return id | ||
end | ||
|
||
""" | ||
add_edge!(bn::BayesianNetwork{V}, from::V, to::V) where {V} | ||
Adds an edge between two vertices in the Bayesian Network. Returns true if successful, false otherwise. | ||
""" | ||
function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V} | ||
from_id = bn.names_to_ids[from] | ||
to_id = bn.names_to_ids[to] | ||
return Graphs.add_edge!(bn.graph, from_id, to_id) | ||
end | ||
|
||
""" | ||
ancestral_sampling(bn::BayesianNetwork{V}) where {V} | ||
Perform ancestral sampling on a Bayesian network to generate one sample from the joint distribution. | ||
Ancestral sampling works by: | ||
1. Finding a topological ordering of the nodes | ||
2. Sampling from each node in order, using the already-sampled parent values for conditional distributions | ||
""" | ||
function ancestral_sampling(bn::BayesianNetwork{V}) where {V} | ||
ordered_vertices = Graphs.topological_sort(bn.graph) | ||
|
||
samples = Dict{V,Any}() | ||
|
||
# TODO: Implement sampling logic | ||
|
||
return samples | ||
end | ||
|
||
""" | ||
is_conditionally_independent(bn::BayesianNetwork, X::V, Y::V[, Z::Vector{V}]) where {V} | ||
Determines if two variables X and Y are conditionally independent given the conditioning information already known. | ||
If Z is provided, the conditioning information in `bn` will be ignored. | ||
""" | ||
function is_conditionally_independent end | ||
|
||
function is_conditionally_independent(bn::BayesianNetwork{V}, X::V, Y::V) where {V} | ||
# TODO: Implement | ||
end | ||
|
||
function is_conditionally_independent( | ||
bn::BayesianNetwork{V}, X::V, Y::V, Z::Vector{V} | ||
) where {V} | ||
# TODO: Implement | ||
end |
102 changes: 102 additions & 0 deletions
102
test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
using Test | ||
using Distributions | ||
using Graphs | ||
using JuliaBUGS.ProbabilisticGraphicalModels: | ||
BayesianNetwork, add_stochastic_vertex!, add_deterministic_vertex!, add_edge!, condition, decondition | ||
|
||
@testset "BayesianNetwork" begin | ||
@testset "Adding vertices" begin | ||
bn = BayesianNetwork{Symbol}() | ||
|
||
# Test adding stochastic vertex | ||
id1 = add_stochastic_vertex!(bn, :A, Normal(0, 1), false) | ||
@test id1 == 1 | ||
@test length(bn.names) == 1 | ||
@test bn.names[1] == :A | ||
@test bn.names_to_ids[:A] == 1 | ||
@test bn.is_stochastic[1] == true | ||
@test bn.is_observed[1] == false | ||
@test length(bn.stochastic_ids) == 1 | ||
|
||
# Test adding deterministic vertex | ||
f(x) = x^2 | ||
id2 = add_deterministic_vertex!(bn, :B, f) | ||
@test id2 == 2 | ||
@test length(bn.names) == 2 | ||
@test bn.names[2] == :B | ||
@test bn.names_to_ids[:B] == 2 | ||
@test bn.is_stochastic[2] == false | ||
@test bn.is_observed[2] == false | ||
@test length(bn.deterministic_ids) == 1 | ||
end | ||
|
||
@testset "Adding edges" begin | ||
bn = BayesianNetwork{Symbol}() | ||
add_stochastic_vertex!(bn, :A, Normal(0, 1), false) | ||
add_stochastic_vertex!(bn, :B, Normal(0, 1), false) | ||
|
||
add_edge!(bn, :A, :B) | ||
@test has_edge(bn.graph, 1, 2) | ||
end | ||
|
||
@testset "conditioning and deconditioning" begin | ||
bn = BayesianNetwork{Symbol}() | ||
|
||
# Add some vertices | ||
add_stochastic_vertex!(bn, :A, Normal(0, 1), false) | ||
add_stochastic_vertex!(bn, :B, Normal(0, 1), false) | ||
add_stochastic_vertex!(bn, :C, Normal(0, 1), false) | ||
|
||
# Test conditioning | ||
bn_cond = condition(bn, Dict(:A => 1.0)) | ||
@test bn_cond.is_observed[1] == true | ||
@test bn_cond.values[:A] == 1.0 | ||
@test bn_cond.is_observed[2] == false | ||
@test bn_cond.is_observed[3] == false | ||
|
||
# Ensure original bn is not mutated | ||
@test bn.is_observed[1] == false | ||
@test !haskey(bn.values, :A) | ||
|
||
# Test conditioning multiple variables | ||
bn_cond2 = condition(bn_cond, Dict(:B => 2.0)) | ||
@test bn_cond2.is_observed[1] == true | ||
@test bn_cond2.is_observed[2] == true | ||
@test bn_cond2.values[:A] == 1.0 | ||
@test bn_cond2.values[:B] == 2.0 | ||
|
||
# Ensure bn_cond is not mutated | ||
@test bn_cond.is_observed[2] == false | ||
@test !haskey(bn_cond.values, :B) | ||
|
||
# Test deconditioning | ||
bn_decond = decondition(bn_cond2, [:A]) | ||
@test bn_decond.is_observed[1] == false | ||
@test bn_decond.is_observed[2] == true | ||
@test !haskey(bn_decond.values, :A) | ||
@test bn_decond.values[:B] == 2.0 | ||
|
||
# Ensure bn_cond2 is not mutated | ||
@test bn_cond2.is_observed[1] == true | ||
@test bn_cond2.values[:A] == 1.0 | ||
|
||
# Test deconditioning all | ||
bn_decond_all = decondition(bn_cond2) | ||
@test all(.!bn_decond_all.is_observed) | ||
@test all(values(bn_decond_all.values) .=== nothing) | ||
|
||
# Ensure bn_cond2 is still not mutated | ||
@test bn_cond2.is_observed[1] == true | ||
@test bn_cond2.is_observed[2] == true | ||
@test bn_cond2.values[:A] == 1.0 | ||
@test bn_cond2.values[:B] == 2.0 | ||
end | ||
|
||
@testset "Simple ancestral sampling" begin | ||
|
||
end | ||
|
||
@testset "Bayes Ball" begin | ||
|
||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters