From 7d07e627d14e28fa5a73c28b45caa69eb4c945fe Mon Sep 17 00:00:00 2001
From: Xianda Sun
Date: Sun, 3 Nov 2024 12:46:55 +0000
Subject: [PATCH] add a BayesNet implemented without MetaGraphsNext

---
 src/JuliaBUGS.jl                              |   2 +
 .../ProbabilisticGraphicalModels.jl           |   9 +
 .../ProbabilisticGraphicalModels/bayesnet.jl  | 224 ++++++++++++++++++
 .../ProbabilisticGraphicalModels/bayesnet.jl  | 111 +++++++++
 test/runtests.jl                              |   6 +-
 5 files changed, 351 insertions(+), 1 deletion(-)
 create mode 100644 src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
 create mode 100644 src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
 create mode 100644 test/experimental/ProbabilisticGraphicalModels/bayesnet.jl

diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl
index 8071b5362..db888c352 100644
--- a/src/JuliaBUGS.jl
+++ b/src/JuliaBUGS.jl
@@ -251,4 +251,6 @@ Only defined with `MCMCChains` extension.
 """
 function gen_chains end
 
+include("experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl")
+
 end
diff --git a/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl b/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
new file mode 100644
index 000000000..4785bf151
--- /dev/null
+++ b/src/experimental/ProbabilisticGraphicalModels/ProbabilisticGraphicalModels.jl
@@ -0,0 +1,9 @@
+module ProbabilisticGraphicalModels
+
+using BangBang
+using Graphs
+using Distributions
+
+include("bayesnet.jl")
+
+end
diff --git a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
new file mode 100644
index 000000000..063a17a77
--- /dev/null
+++ b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
@@ -0,0 +1,224 @@
+"""
+    BayesianNetwork
+
+A structure representing a Bayesian Network. Variables are the vertices of a directed graph;
+an edge `from -> to` makes `from` a parent of `to`.
+"""
+struct BayesianNetwork{V,T,F}
+    "directed graph over the variable ids, with edges pointing from parent to child"
+    graph::SimpleDiGraph{T}
+    "names of the variables in the network"
+    names::Vector{V}
+    "mapping from variable names to ids"
+    names_to_ids::Dict{V,T}
+    "values of each variable in the network"
+    values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future
+    "distributions of the stochastic variables"
+    distributions::Vector{Distribution}
+    "deterministic functions of the deterministic variables"
+    deterministic_functions::Vector{F}
+    "ids of the stochastic variables"
+    stochastic_ids::Vector{T}
+    "ids of the deterministic variables"
+    deterministic_ids::Vector{T}
+    "`is_stochastic[id]` tells whether the variable with this id is stochastic"
+    is_stochastic::BitVector
+    "`is_observed[id]` tells whether the variable with this id is observed"
+    is_observed::BitVector
+end
+
+"""
+    BayesianNetwork{V}()
+
+Create an empty Bayesian Network whose variable names are of type `V`.
+"""
+function BayesianNetwork{V}() where {V}
+    return BayesianNetwork(
+        SimpleDiGraph{Int}(), # by default, vertex ids are integers
+        V[],
+        Dict{V,Int}(),
+        Dict{V,Any}(),
+        Distribution[],
+        Any[],
+        Int[],
+        Int[],
+        BitVector(),
+        BitVector(),
+    )
+end
+
+"""
+    condition(bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any}) where {V}
+
+Condition the Bayesian Network on the values of some variables. Return a new Bayesian Network
+with the given variables marked as observed; `bn` itself is left untouched.
+"""
+function condition(
+    bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any}
+) where {V}
+    is_observed = copy(bn.is_observed)
+    values = copy(bn.values)
+    bn_new = BangBang.setproperties!!(bn; is_observed=is_observed, values=values)
+    return condition!(bn_new, conditioning_variables_and_values)
+end
+
+"""
+    condition!(bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any}) where {V}
+
+Condition the Bayesian Network on the values of some variables. Mutating version of [`condition`](@ref).
+
+Throws an `ArgumentError` if one of the variables is not stochastic. Conditioning on a variable
+that is already observed overwrites its value and emits a warning.
+"""
+function condition!(
+    bn::BayesianNetwork{V}, conditioning_variables_and_values::Dict{V,<:Any}
+) where {V}
+    for (name, value) in conditioning_variables_and_values
+        id = bn.names_to_ids[name]
+        if !bn.is_stochastic[id]
+            throw(ArgumentError("Variable $name is not stochastic, cannot condition on it"))
+        elseif bn.is_observed[id]
+            @warn "Variable $name is already observed, overwriting its value"
+        else
+            bn.is_observed[id] = true
+        end
+        bn.values[name] = value
+    end
+    return bn
+end
+
+"""
+    decondition(bn::BayesianNetwork{V}[, variables::Vector{V}]) where {V}
+
+Remove the observations of `variables`, or of every observed variable when `variables` is
+omitted. Return a new Bayesian Network; `bn` itself is left untouched.
+"""
+function decondition(bn::BayesianNetwork{V}) where {V}
+    conditioned_variables_ids = findall(bn.is_observed)
+    return decondition(bn, bn.names[conditioned_variables_ids])
+end
+
+"""
+    decondition!(bn::BayesianNetwork{V}[, deconditioning_variables::Vector{V}]) where {V}
+
+Mutating version of [`decondition`](@ref). Throws an `ArgumentError` if one of the variables is
+not stochastic or not observed.
+"""
+function decondition!(bn::BayesianNetwork{V}) where {V}
+    conditioned_variables_ids = findall(bn.is_observed)
+    return decondition!(bn, bn.names[conditioned_variables_ids])
+end
+
+function decondition(bn::BayesianNetwork{V}, variables::Vector{V}) where {V}
+    is_observed = copy(bn.is_observed)
+    values = copy(bn.values)
+    bn_new = BangBang.setproperties!!(bn; is_observed=is_observed, values=values)
+    return decondition!(bn_new, variables)
+end
+
+function decondition!(bn::BayesianNetwork{V}, deconditioning_variables::Vector{V}) where {V}
+    for name in deconditioning_variables
+        id = bn.names_to_ids[name]
+        if !bn.is_stochastic[id]
+            throw(
+                ArgumentError("Variable $name is not stochastic, cannot decondition on it")
+            )
+        elseif !bn.is_observed[id]
+            throw(ArgumentError("Variable $name is not observed, cannot decondition on it"))
+        end
+        bn.is_observed[id] = false
+        delete!(bn.values, name)
+    end
+    return bn
+end
+
+"""
+    add_stochastic_vertex!(bn::BayesianNetwork{V}, name::V, dist::Distribution, is_observed::Bool) where {V}
+
+Adds a stochastic vertex with name `name` and distribution `dist` to the Bayesian Network. Returns the id of the added vertex
+if successful, 0 otherwise (in particular when a vertex named `name` already exists).
+"""
+function add_stochastic_vertex!(
+    bn::BayesianNetwork{V,T}, name::V, dist::Distribution, is_observed::Bool
+)::T where {V,T}
+    # a duplicate name would re-point `names_to_ids` and orphan the existing vertex
+    haskey(bn.names_to_ids, name) && return 0
+    Graphs.add_vertex!(bn.graph) || return 0
+    id = nv(bn.graph)
+    push!(bn.distributions, dist)
+    push!(bn.is_stochastic, true)
+    push!(bn.is_observed, is_observed)
+    push!(bn.names, name)
+    bn.names_to_ids[name] = id
+    push!(bn.stochastic_ids, id)
+    return id
+end
+
+"""
+    add_deterministic_vertex!(bn::BayesianNetwork{V}, name::V, f::F) where {V,F}
+
+Adds a deterministic vertex with name `name` and deterministic function `f` to the Bayesian Network. Returns the id of the added vertex
+if successful, 0 otherwise (in particular when a vertex named `name` already exists).
+"""
+function add_deterministic_vertex!(bn::BayesianNetwork{V,T}, name::V, f::F)::T where {T,V,F}
+    # a duplicate name would re-point `names_to_ids` and orphan the existing vertex
+    haskey(bn.names_to_ids, name) && return 0
+    Graphs.add_vertex!(bn.graph) || return 0
+    id = nv(bn.graph)
+    push!(bn.deterministic_functions, f)
+    push!(bn.is_stochastic, false)
+    push!(bn.is_observed, false)
+    push!(bn.names, name)
+    bn.names_to_ids[name] = id
+    push!(bn.deterministic_ids, id)
+    return id
+end
+
+"""
+    add_edge!(bn::BayesianNetwork{V}, from::V, to::V) where {V}
+
+Adds a directed edge from `from` (the parent) to `to` (the child) in the Bayesian Network.
+Returns true if successful, false otherwise.
+"""
+function add_edge!(bn::BayesianNetwork{V,T}, from::V, to::V)::Bool where {T,V}
+    from_id = bn.names_to_ids[from]
+    to_id = bn.names_to_ids[to]
+    return Graphs.add_edge!(bn.graph, from_id, to_id)
+end
+
+"""
+    ancestral_sampling(bn::BayesianNetwork{V}) where {V}
+
+Perform ancestral sampling on a Bayesian network to generate one sample from the joint distribution.
+
+Ancestral sampling works by:
+1. Finding a topological ordering of the nodes
+2. Sampling from each node in order, using the already-sampled parent values for conditional distributions
+"""
+function ancestral_sampling(bn::BayesianNetwork{V}) where {V}
+    ordered_vertices = Graphs.topological_sort(bn.graph)
+
+    samples = Dict{V,Any}()
+
+    # TODO: Implement sampling logic
+
+    return samples
+end
+
+"""
+    is_conditionally_independent(bn::BayesianNetwork, X::V, Y::V[, Z::Vector{V}]) where {V}
+
+Determines if two variables X and Y are conditionally independent given the conditioning information already known.
+If Z is provided, the conditioning information in `bn` will be ignored.
+"""
+function is_conditionally_independent end
+
+function is_conditionally_independent(bn::BayesianNetwork{V}, X::V, Y::V) where {V}
+    # TODO: Implement
+end
+
+function is_conditionally_independent(
+    bn::BayesianNetwork{V}, X::V, Y::V, Z::Vector{V}
+) where {V}
+    # TODO: Implement
+end
diff --git a/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
new file mode 100644
index 000000000..a4c9ed3aa
--- /dev/null
+++ b/test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
@@ -0,0 +1,111 @@
+using Test
+using Distributions
+using Graphs
+using JuliaBUGS.ProbabilisticGraphicalModels:
+    BayesianNetwork, add_stochastic_vertex!, add_deterministic_vertex!, add_edge!, condition, decondition
+
+@testset "BayesianNetwork" begin
+    @testset "Adding vertices" begin
+        bn = BayesianNetwork{Symbol}()
+
+        # Test adding stochastic vertex
+        id1 = add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
+        @test id1 == 1
+        @test length(bn.names) == 1
+        @test bn.names[1] == :A
+        @test bn.names_to_ids[:A] == 1
+        @test bn.is_stochastic[1] == true
+        @test bn.is_observed[1] == false
+        @test length(bn.stochastic_ids) == 1
+
+        # Test adding deterministic vertex
+        f(x) = x^2
+        id2 = add_deterministic_vertex!(bn, :B, f)
+        @test id2 == 2
+        @test length(bn.names) == 2
+        @test bn.names[2] == :B
+        @test bn.names_to_ids[:B] == 2
+        @test bn.is_stochastic[2] == false
+        @test bn.is_observed[2] == false
+        @test length(bn.deterministic_ids) == 1
+
+        # Test that an already used name is rejected and leaves the network untouched
+        @test add_stochastic_vertex!(bn, :A, Normal(0, 1), false) == 0
+        @test add_deterministic_vertex!(bn, :B, f) == 0
+        @test nv(bn.graph) == 2
+        @test bn.names == [:A, :B]
+        @test bn.names_to_ids == Dict(:A => 1, :B => 2)
+    end
+
+    @testset "Adding edges" begin
+        bn = BayesianNetwork{Symbol}()
+        add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
+        add_stochastic_vertex!(bn, :B, Normal(0, 1), false)
+
+        add_edge!(bn, :A, :B)
+        @test has_edge(bn.graph, 1, 2)
+        # edges are directed: A is a parent of B, not the other way around
+        @test !has_edge(bn.graph, 2, 1)
+    end
+
+    @testset "conditioning and deconditioning" begin
+        bn = BayesianNetwork{Symbol}()
+
+        # Add some vertices
+        add_stochastic_vertex!(bn, :A, Normal(0, 1), false)
+        add_stochastic_vertex!(bn, :B, Normal(0, 1), false)
+        add_stochastic_vertex!(bn, :C, Normal(0, 1), false)
+
+        # Test conditioning
+        bn_cond = condition(bn, Dict(:A => 1.0))
+        @test bn_cond.is_observed[1] == true
+        @test bn_cond.values[:A] == 1.0
+        @test bn_cond.is_observed[2] == false
+        @test bn_cond.is_observed[3] == false
+
+        # Ensure original bn is not mutated
+        @test bn.is_observed[1] == false
+        @test !haskey(bn.values, :A)
+
+        # Test conditioning multiple variables
+        bn_cond2 = condition(bn_cond, Dict(:B => 2.0))
+        @test bn_cond2.is_observed[1] == true
+        @test bn_cond2.is_observed[2] == true
+        @test bn_cond2.values[:A] == 1.0
+        @test bn_cond2.values[:B] == 2.0
+
+        # Ensure bn_cond is not mutated
+        @test bn_cond.is_observed[2] == false
+        @test !haskey(bn_cond.values, :B)
+
+        # Test deconditioning
+        bn_decond = decondition(bn_cond2, [:A])
+        @test bn_decond.is_observed[1] == false
+        @test bn_decond.is_observed[2] == true
+        @test !haskey(bn_decond.values, :A)
+        @test bn_decond.values[:B] == 2.0
+
+        # Ensure bn_cond2 is not mutated
+        @test bn_cond2.is_observed[1] == true
+        @test bn_cond2.values[:A] == 1.0
+
+        # Test deconditioning all
+        bn_decond_all = decondition(bn_cond2)
+        @test all(.!bn_decond_all.is_observed)
+        @test isempty(bn_decond_all.values)
+
+        # Ensure bn_cond2 is still not mutated
+        @test bn_cond2.is_observed[1] == true
+        @test bn_cond2.is_observed[2] == true
+        @test bn_cond2.values[:A] == 1.0
+        @test bn_cond2.values[:B] == 2.0
+    end
+
+    @testset "Simple ancestral sampling" begin
+
+    end
+
+    @testset "Bayes Ball" begin
+
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d7fb7e0ce..d94768240 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -24,7 +24,7 @@ using ReverseDiff
 
 AbstractMCMC.setprogress!(false)
 
-const Tests = ("elementary", "compilation", "log_density", "gibbs", "mcmchains", "all")
+const Tests = ("elementary", "compilation", "log_density", "gibbs", "mcmchains", "experimental", "all")
 const test_group = get(ENV, "TEST_GROUP", "all")
 
 if test_group ∉ Tests
@@ -67,3 +67,7 @@ end
 if test_group == "mcmchains" || test_group == "all"
     include("ext/mcmchains.jl")
 end
+
+if test_group == "experimental" || test_group == "all"
+    include("experimental/ProbabilisticGraphicalModels/bayesnet.jl")
+end