diff --git a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl index e09a23a0..93df0c62 100644 --- a/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl +++ b/src/experimental/ProbabilisticGraphicalModels/bayesnet.jl @@ -4,41 +4,33 @@ A structure representing a Bayesian Network. """ -struct BayesianNetwork{V,T,F} +# First, modify the BayesianNetwork struct definition +struct BayesianNetwork{V,T} graph::SimpleDiGraph{T} - "names of the variables in the network" names::Vector{V} - "mapping from variable names to ids" names_to_ids::Dict{V,T} - "values of each variable in the network" - values::Dict{V,Any} # TODO: make it a NamedTuple for better performance in the future - "distributions of the stochastic variables" - # A distribution can be either: - # - A fixed distribution (like Uniform(0,1)) - # - A function that takes parent values and returns a distribution - distributions::Vector{Union{Distribution,Function}} - "deterministic functions of the deterministic variables" - deterministic_functions::Vector{F} - "ids of the stochastic variables" - stochastic_ids::Vector{T} - "ids of the deterministic variables" - deterministic_ids::Vector{T} - is_stochastic::BitVector + values::Dict{V,Any} + distributions::Vector{Union{Distribution,Function}} is_observed::BitVector + is_stochastic::BitVector + stochastic_ids::Vector{Int} + deterministic_ids::Vector{Int} + deterministic_functions::Vector{Function} end +# Then, modify the constructor to match exactly function BayesianNetwork{V}() where {V} - return BayesianNetwork( - SimpleDiGraph{Int}(), # by default, vertex ids are integers - V[], - Dict{V,Int}(), - Dict{V,Any}(), - Distribution[], - Any[], - Int[], - Int[], - BitVector(), - BitVector(), + return BayesianNetwork{V,Int}( + SimpleDiGraph{Int}(), # graph + V[], # names + Dict{V,Int}(), # names_to_ids + Dict{V,Any}(), # values + Union{Distribution,Function}[], # distributions + BitVector(), # is_observed + BitVector(), # is_stochastic + Int[], # stochastic_ids + Int[], # deterministic_ids + Function[] # deterministic_functions - Added this ) end @@ -127,8 +119,10 @@ function add_stochastic_vertex!( id = nv(bn.graph) push!(bn.distributions, dist) push!(bn.is_observed, is_observed) + push!(bn.is_stochastic, true) push!(bn.names, name) bn.names_to_ids[name] = id + push!(bn.stochastic_ids, id) return id end @@ -358,12 +352,25 @@ function marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V} # Start recursive elimination return eliminate_variables(bn, ordered_vertices, query_id, Dict{V,Any}()) end +# Helper functions to evaluate distributions +function evaluate_distribution(dist::Distribution, _) + return dist +end -""" - eliminate_variables(bn, ordered_vertices, query_id, assignments) +function evaluate_distribution(dist_func::Function, parent_values) + # Skip evaluation if any parent value is nothing + if any(isnothing, parent_values) + return nothing + end + + # If there's only one parent value, pass it directly instead of splatting + if length(parent_values) == 1 + return dist_func(parent_values[1]) + else + return dist_func(parent_values...) + end +end -Helper function for variable elimination algorithm. -""" function eliminate_variables( bn::BayesianNetwork{V}, ordered_vertices::Vector{Int}, @@ -373,34 +380,54 @@ function eliminate_variables( # Base case: reached query variable if isempty(ordered_vertices) || ordered_vertices[1] == query_id dist_idx = findfirst(id -> id == query_id, bn.stochastic_ids) - return bn.distributions[dist_idx] + current_dist = bn.distributions[dist_idx] + + # Get parent values if it's a conditional distribution + parent_ids = Graphs.inneighbors(bn.graph, query_id) + parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids] + + result = evaluate_distribution(current_dist, parent_values) + return isnothing(result) ? current_dist : result end current_id = ordered_vertices[1] remaining_vertices = ordered_vertices[2:end] - # For current variable, create mixture over its values - components = Distribution[] + # First, get the type of distribution we'll be dealing with + dist_idx = findfirst(id -> id == query_id, bn.stochastic_ids) + current_dist = bn.distributions[dist_idx] + parent_ids = Graphs.inneighbors(bn.graph, query_id) + parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids] + test_dist = evaluate_distribution(current_dist, parent_values) + test_dist = isnothing(test_dist) ? current_dist : test_dist + + # Initialize components with the correct type + if test_dist isa ContinuousUnivariateDistribution + components = Vector{ContinuousUnivariateDistribution}() + else + components = Vector{DiscreteUnivariateDistribution}() + end weights = Float64[] - # Try both values (0 and 1) # TODO: generalize for other values + # Try both values (0 and 1) for value in [0, 1] new_assignments = copy(assignments) new_assignments[bn.names[current_id]] = value - # Get distribution for remaining variables component = eliminate_variables(bn, remaining_vertices, query_id, new_assignments) - println("Components so far: ", components) - println("Current component: ", component) push!(components, component) # Get weight from current node's distribution dist_idx = findfirst(id -> id == current_id, bn.stochastic_ids) - push!(weights, pdf(bn.distributions[dist_idx], value)) + current_dist = bn.distributions[dist_idx] + parent_ids = Graphs.inneighbors(bn.graph, current_id) + parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids] + + dist = evaluate_distribution(current_dist, parent_values) + dist = isnothing(dist) ? current_dist : dist + push!(weights, pdf(dist, value)) end - # Normalize weights weights ./= sum(weights) - return MixtureModel(components, weights) end \ No newline at end of file diff --git a/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl b/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl index 77ab7338..e37c1497 100644 --- a/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl +++ b/test/experimental/ProbabilisticGraphicalModels/test_variable_elimination.jl @@ -6,7 +6,7 @@ Pkg.activate(".") using Test using Distributions using Graphs -using BangBang # This is needed based on the imports +using BangBang # 3. Load JuliaBUGS and its submodule using JuliaBUGS @@ -14,69 +14,95 @@ using JuliaBUGS.ProbabilisticGraphicalModels using JuliaBUGS.ProbabilisticGraphicalModels: BayesianNetwork, add_stochastic_vertex!, - add_deterministic_vertex!, add_edge!, condition, - condition!, - decondition, - ancestral_sampling, - is_conditionally_independent, marginal_distribution, eliminate_variables -# 4. Run the specific test -@testset "Simple Discrete Chain" begin +@testset "Mixed Graph - Variable Elimination" begin bn = BayesianNetwork{Symbol}() - # Simple chain A -> B -> C - add_stochastic_vertex!(bn, :A, Bernoulli(0.7), false) - add_stochastic_vertex!(bn, :B, Bernoulli(0.8), false) - add_stochastic_vertex!(bn, :C, Bernoulli(0.9), false) + # X1 ~ Uniform(0,1) + add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false) + + # X2 ~ Bernoulli(X1) + function x2_distribution(x1) + return Bernoulli(x1) + end + add_stochastic_vertex!(bn, :X2, x2_distribution, false) + add_edge!(bn, :X1, :X2) + + # X3 ~ Normal(μ(X2), 1) + function x3_distribution(x2) + return Normal(x2 == 1 ? 10.0 : 2.0, 1.0) + end + add_stochastic_vertex!(bn, :X3, x3_distribution, false) + add_edge!(bn, :X2, :X3) + + # Test graph structure + @test has_edge(bn.graph, 1, 2) # X1 -> X2 + @test has_edge(bn.graph, 2, 3) # X2 -> X3 - add_edge!(bn, :A, :B) - add_edge!(bn, :B, :C) + # Test conditional distributions + # Test X2's distribution given X1 + bn_cond_x1 = condition(bn, Dict(:X1 => 0.7)) + marginal_x2 = marginal_distribution(bn_cond_x1, :X2) + @test marginal_x2 isa Bernoulli + @test mean(marginal_x2) ≈ 0.7 - ordered_vertices = topological_sort_by_dfs(bn.graph) - println(ordered_vertices) - marginal_C = marginal_distribution(bn, :C) - println(marginal_C) + # Test X3's distribution given X2 + bn_cond_x2_0 = condition(bn, Dict(:X2 => 0)) + marginal_x3_0 = marginal_distribution(bn_cond_x2_0, :X3) + @test marginal_x3_0 isa Normal + @test mean(marginal_x3_0) ≈ 2.0 + @test std(marginal_x3_0) ≈ 1.0 + + bn_cond_x2_1 = condition(bn, Dict(:X2 => 1)) + marginal_x3_1 = marginal_distribution(bn_cond_x2_1, :X3) + @test marginal_x3_1 isa Normal + @test mean(marginal_x3_1) ≈ 10.0 + @test std(marginal_x3_1) ≈ 1.0 + + # Test full chain inference + ordered_vertices = [1, 2] # Eliminate X1, then X2 + query_id = 3 # Query X3 + result = eliminate_variables(bn, ordered_vertices, query_id, Dict{Symbol,Any}()) + + # The result should be a mixture of Normal distributions + @test result isa MixtureModel end -# @testset "Mixed Graph - Variable Elimination" begin -# bn = BayesianNetwork{Symbol}() - -# # X1 ~ Uniform(0,1) -# add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false) - -# # X2 ~ Bernoulli(X1) -# # We need a function that creates a new Bernoulli distribution based on X1's value -# add_deterministic_vertex!(bn, :X2_dist, x1 -> Bernoulli(x1)) -# add_stochastic_vertex!(bn, :X2, Bernoulli(0.5), false) # Initial dist doesn't matter -# add_edge!(bn, :X1, :X2_dist) -# add_edge!(bn, :X2_dist, :X2) - -# # X3 ~ Normal(μ(X2), 1) -# # Function that creates a new Normal distribution based on X2's value -# add_deterministic_vertex!(bn, :X3_dist, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0)) -# add_stochastic_vertex!(bn, :X3, Normal(0, 1), false) # Initial dist doesn't matter -# add_edge!(bn, :X2, :X3_dist) -# add_edge!(bn, :X3_dist, :X3) -# end - -@testset "Mixed Graph - Variable Elimination" begin +@testset "Marginal Distribution P(X3|X1)" begin bn = BayesianNetwork{Symbol}() # X1 ~ Uniform(0,1) add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false) # X2 ~ Bernoulli(X1) - # The distribution constructor takes the parent value and returns the appropriate distribution - conditional_dist_X2 = x1 -> Bernoulli(x1) - add_stochastic_vertex!(bn, :X2, conditional_dist_X2, false) + add_stochastic_vertex!(bn, :X2, x1 -> Bernoulli(x1), false) add_edge!(bn, :X1, :X2) # X3 ~ Normal(μ(X2), 1) - conditional_dist_X3 = x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0) - add_stochastic_vertex!(bn, :X3, conditional_dist_X3, false) + add_stochastic_vertex!(bn, :X3, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0), false) add_edge!(bn, :X2, :X3) + + # Test P(X3|X1=0.7) + bn_cond = condition(bn, Dict(:X1 => 0.7)) + marginal_x3 = marginal_distribution(bn_cond, :X3) + + @test marginal_x3 isa MixtureModel + @test length(marginal_x3.components) == 2 + @test marginal_x3.components[1] isa Normal + @test marginal_x3.components[2] isa Normal + + # When X1 = 0.7: + # P(X2=0) = 0.3, P(X2=1) = 0.7 + @test marginal_x3.prior.p ≈ [0.3, 0.7] + + # Component means should be 2 and 10 + @test mean(marginal_x3.components[1]) ≈ 2.0 + @test mean(marginal_x3.components[2]) ≈ 10.0 + + # Overall mean should be weighted average + @test mean(marginal_x3) ≈ 2.0 * 0.3 + 10.0 * 0.7 end \ No newline at end of file