Skip to content

Commit

Permalink
finally no errors but the code is error
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Nov 28, 2024
1 parent 6237304 commit fcc4c58
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 87 deletions.
111 changes: 69 additions & 42 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,41 +4,33 @@
A structure representing a Bayesian Network.
"""

struct BayesianNetwork{V,T}
    # Directed graph over the vertex ids (ids have type `T`).
    graph::SimpleDiGraph{T}
    # Names of the variables in the network.
    names::Vector{V}
    # Mapping from variable names to vertex ids.
    names_to_ids::Dict{V,T}
    # Values of each variable in the network.
    # TODO: make it a NamedTuple for better performance in the future.
    values::Dict{V,Any}
    # Distributions of the stochastic variables. An entry is either
    #   - a fixed distribution (like Uniform(0, 1)), or
    #   - a function that takes parent values and returns a distribution.
    distributions::Vector{Union{Distribution,Function}}
    # Per-vertex flag: was the vertex added as observed?
    is_observed::BitVector
    # Per-vertex flag: is the vertex stochastic?
    is_stochastic::BitVector
    # Ids of the stochastic variables; typed with `T` so they agree with
    # `graph` and `names_to_ids`.
    stochastic_ids::Vector{T}
    # Ids of the deterministic variables.
    deterministic_ids::Vector{T}
    # Deterministic functions of the deterministic variables.
    deterministic_functions::Vector{Function}
end

# Then, modify the constructor to match exactly
"""
    BayesianNetwork{V}()

Create an empty Bayesian network whose variable names have type `V`.
Vertex ids are `Int` by default, and every container starts out empty.
"""
function BayesianNetwork{V}() where {V}
    # Positional arguments follow the field order of the struct definition.
    return BayesianNetwork{V,Int}(
        SimpleDiGraph{Int}(),             # graph
        V[],                              # names
        Dict{V,Int}(),                    # names_to_ids
        Dict{V,Any}(),                    # values
        Union{Distribution,Function}[],   # distributions
        BitVector(),                      # is_observed
        BitVector(),                      # is_stochastic
        Int[],                            # stochastic_ids
        Int[],                            # deterministic_ids
        Function[],                       # deterministic_functions
    )
end

Expand Down Expand Up @@ -127,8 +119,10 @@ function add_stochastic_vertex!(
id = nv(bn.graph)
push!(bn.distributions, dist)
push!(bn.is_observed, is_observed)
push!(bn.is_stochastic, true)
push!(bn.names, name)
bn.names_to_ids[name] = id
push!(bn.stochastic_ids, id)
return id
end

Expand Down Expand Up @@ -358,12 +352,25 @@ function marginal_distribution(bn::BayesianNetwork{V}, query_var::V) where {V}
# Start recursive elimination
return eliminate_variables(bn, ordered_vertices, query_id, Dict{V,Any}())
end
# Helper functions to evaluate distributions
"""
    evaluate_distribution(dist::Distribution, _)

Return `dist` unchanged. The second argument (the parent values) is ignored,
because a fixed distribution does not depend on them.
"""
evaluate_distribution(dist::Distribution, _) = dist

"""
eliminate_variables(bn, ordered_vertices, query_id, assignments)
"""
    evaluate_distribution(dist_func::Function, parent_values)

Call `dist_func` with the entries of `parent_values` as positional arguments
and return its result.

Return `nothing` without calling `dist_func` when any entry of
`parent_values` is `nothing`.
"""
function evaluate_distribution(dist_func::Function, parent_values)
    # A missing parent value means the conditional cannot be evaluated yet.
    any(isnothing, parent_values) && return nothing

    # Splatting a one-element collection is the same call as passing that
    # element directly, so no single-parent special case is needed.
    return dist_func(parent_values...)
end

Helper function for variable elimination algorithm.
"""
function eliminate_variables(
bn::BayesianNetwork{V},
ordered_vertices::Vector{Int},
Expand All @@ -373,34 +380,54 @@ function eliminate_variables(
# Base case: reached query variable
if isempty(ordered_vertices) || ordered_vertices[1] == query_id
dist_idx = findfirst(id -> id == query_id, bn.stochastic_ids)
return bn.distributions[dist_idx]
current_dist = bn.distributions[dist_idx]

# Get parent values if it's a conditional distribution
parent_ids = Graphs.inneighbors(bn.graph, query_id)
parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids]

result = evaluate_distribution(current_dist, parent_values)
return isnothing(result) ? current_dist : result
end

current_id = ordered_vertices[1]
remaining_vertices = ordered_vertices[2:end]

# For current variable, create mixture over its values
components = Distribution[]
# First, get the type of distribution we'll be dealing with
dist_idx = findfirst(id -> id == query_id, bn.stochastic_ids)
current_dist = bn.distributions[dist_idx]
parent_ids = Graphs.inneighbors(bn.graph, query_id)
parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids]
test_dist = evaluate_distribution(current_dist, parent_values)
test_dist = isnothing(test_dist) ? current_dist : test_dist

# Initialize components with the correct type
if test_dist isa ContinuousUnivariateDistribution
components = Vector{ContinuousUnivariateDistribution}()
else
components = Vector{DiscreteUnivariateDistribution}()
end
weights = Float64[]

# Try both values (0 and 1) # TODO: generalize for other values
# Try both values (0 and 1)
for value in [0, 1]
new_assignments = copy(assignments)
new_assignments[bn.names[current_id]] = value

# Get distribution for remaining variables
component = eliminate_variables(bn, remaining_vertices, query_id, new_assignments)
println("Components so far: ", components)
println("Current component: ", component)
push!(components, component)

# Get weight from current node's distribution
dist_idx = findfirst(id -> id == current_id, bn.stochastic_ids)
push!(weights, pdf(bn.distributions[dist_idx], value))
current_dist = bn.distributions[dist_idx]
parent_ids = Graphs.inneighbors(bn.graph, current_id)
parent_values = [get(assignments, bn.names[pid], nothing) for pid in parent_ids]

dist = evaluate_distribution(current_dist, parent_values)
dist = isnothing(dist) ? current_dist : dist
push!(weights, pdf(dist, value))
end

# Normalize weights
weights ./= sum(weights)

return MixtureModel(components, weights)
end
Original file line number Diff line number Diff line change
Expand Up @@ -6,77 +6,103 @@ Pkg.activate(".")
using Test
using Distributions
using Graphs
using BangBang # This is needed based on the imports
using BangBang

# 3. Load JuliaBUGS and its submodule
using JuliaBUGS
using JuliaBUGS.ProbabilisticGraphicalModels
using JuliaBUGS.ProbabilisticGraphicalModels:
BayesianNetwork,
add_stochastic_vertex!,
add_deterministic_vertex!,
add_edge!,
condition,
condition!,
decondition,
ancestral_sampling,
is_conditionally_independent,
marginal_distribution,
eliminate_variables
# 4. Run the specific test

@testset "Mixed Graph - Variable Elimination" begin
    bn = BayesianNetwork{Symbol}()

    # X1 ~ Uniform(0,1)
    add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)

    # X2 ~ Bernoulli(X1)
    function x2_distribution(x1)
        return Bernoulli(x1)
    end
    add_stochastic_vertex!(bn, :X2, x2_distribution, false)
    add_edge!(bn, :X1, :X2)

    # X3 ~ Normal(μ(X2), 1)
    function x3_distribution(x2)
        return Normal(x2 == 1 ? 10.0 : 2.0, 1.0)
    end
    add_stochastic_vertex!(bn, :X3, x3_distribution, false)
    add_edge!(bn, :X2, :X3)

    # Test graph structure
    @test has_edge(bn.graph, 1, 2) # X1 -> X2
    @test has_edge(bn.graph, 2, 3) # X2 -> X3

    # Test conditional distributions
    # Test X2's distribution given X1
    bn_cond_x1 = condition(bn, Dict(:X1 => 0.7))
    marginal_x2 = marginal_distribution(bn_cond_x1, :X2)
    @test marginal_x2 isa Bernoulli
    @test mean(marginal_x2) ≈ 0.7

    # Test X3's distribution given X2
    bn_cond_x2_0 = condition(bn, Dict(:X2 => 0))
    marginal_x3_0 = marginal_distribution(bn_cond_x2_0, :X3)
    @test marginal_x3_0 isa Normal
    @test mean(marginal_x3_0) ≈ 2.0
    @test std(marginal_x3_0) ≈ 1.0

    bn_cond_x2_1 = condition(bn, Dict(:X2 => 1))
    marginal_x3_1 = marginal_distribution(bn_cond_x2_1, :X3)
    @test marginal_x3_1 isa Normal
    @test mean(marginal_x3_1) ≈ 10.0
    @test std(marginal_x3_1) ≈ 1.0

    # Test full chain inference
    ordered_vertices = [1, 2] # Eliminate X1, then X2
    query_id = 3 # Query X3
    result = eliminate_variables(bn, ordered_vertices, query_id, Dict{Symbol,Any}())

    # The result should be a mixture of Normal distributions
    @test result isa MixtureModel
end

# @testset "Mixed Graph - Variable Elimination" begin
# bn = BayesianNetwork{Symbol}()

# # X1 ~ Uniform(0,1)
# add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)

# # X2 ~ Bernoulli(X1)
# # We need a function that creates a new Bernoulli distribution based on X1's value
# add_deterministic_vertex!(bn, :X2_dist, x1 -> Bernoulli(x1))
# add_stochastic_vertex!(bn, :X2, Bernoulli(0.5), false) # Initial dist doesn't matter
# add_edge!(bn, :X1, :X2_dist)
# add_edge!(bn, :X2_dist, :X2)

# # X3 ~ Normal(μ(X2), 1)
# # Function that creates a new Normal distribution based on X2's value
# add_deterministic_vertex!(bn, :X3_dist, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0))
# add_stochastic_vertex!(bn, :X3, Normal(0, 1), false) # Initial dist doesn't matter
# add_edge!(bn, :X2, :X3_dist)
# add_edge!(bn, :X3_dist, :X3)
# end

@testset "Marginal Distribution P(X3|X1)" begin
    bn = BayesianNetwork{Symbol}()

    # X1 ~ Uniform(0,1)
    add_stochastic_vertex!(bn, :X1, Uniform(0, 1), false)

    # X2 ~ Bernoulli(X1)
    # The function takes the parent value and returns the matching distribution.
    add_stochastic_vertex!(bn, :X2, x1 -> Bernoulli(x1), false)
    add_edge!(bn, :X1, :X2)

    # X3 ~ Normal(μ(X2), 1)
    add_stochastic_vertex!(bn, :X3, x2 -> Normal(x2 == 1 ? 10.0 : 2.0, 1.0), false)
    add_edge!(bn, :X2, :X3)

    # Test P(X3|X1=0.7)
    bn_cond = condition(bn, Dict(:X1 => 0.7))
    marginal_x3 = marginal_distribution(bn_cond, :X3)

    @test marginal_x3 isa MixtureModel
    @test length(marginal_x3.components) == 2
    @test marginal_x3.components[1] isa Normal
    @test marginal_x3.components[2] isa Normal

    # When X1 = 0.7:
    # P(X2=0) = 0.3, P(X2=1) = 0.7
    @test marginal_x3.prior.p ≈ [0.3, 0.7]

    # Component means should be 2 and 10
    @test mean(marginal_x3.components[1]) ≈ 2.0
    @test mean(marginal_x3.components[2]) ≈ 10.0

    # Overall mean should be weighted average
    @test mean(marginal_x3) ≈ 2.0 * 0.3 + 10.0 * 0.7
end

0 comments on commit fcc4c58

Please sign in to comment.