Skip to content

Commit

Permalink
improve function for finding generated quantities
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Sep 30, 2024
1 parent 1fc7be9 commit 2e17ddf
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 36 deletions.
67 changes: 34 additions & 33 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,52 @@ end
"""
BUGSGraph
The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for
`MetaGraphsNext.MetaGraph`.
The `BUGSGraph` object represents the graph structure for a BUGS model.
"""
const BUGSGraph = MetaGraph
const BUGSGraph = MetaGraph{
Int,Graphs.SimpleDiGraph{Int},<:VarName,<:NodeInfo,Nothing,Nothing,<:Any,Float64
}

"""
find_generated_vars(g::BUGSGraph)
is_model_parameter(g::BUGSGraph, v::VarName) = g[v].is_stochastic && !g[v].is_observed
is_observation(g::BUGSGraph, v::VarName) = g[v].is_stochastic && g[v].is_observed
is_deterministic(g::BUGSGraph, v::VarName) = !g[v].is_stochastic

function find_generated_quantities_variables(
g::MetaGraph{Int,<:SimpleDiGraph,Label,VertexData}
) where {Label,VertexData}
generated_quantities_variables = Set{Label}()
can_reach_observations = Dict{Label,Bool}()

Return all the logical variables without stochastic descendants. The values of these variables
do not affect sampling process. These variables are called "generated quantities" traditionally.
"""
function find_generated_vars(g)
graph_roots = VarName[] # root nodes of the graph
for n in labels(g)
if isempty(outneighbor_labels(g, n))
push!(graph_roots, n)
if !is_observation(g, n)
if !dfs_can_reach_observations(g, n, can_reach_observations)
push!(generated_quantities_variables, n)
end
end
end
return generated_quantities_variables
end

generated_vars = VarName[]
for n in graph_roots
if !g[n].is_stochastic
push!(generated_vars, n) # graph roots that are Logical nodes are generated variables
find_generated_vars_recursive_helper(g, n, generated_vars)
end
function dfs_can_reach_observations(g, n, can_reach_observations)
if haskey(can_reach_observations, n)
return can_reach_observations[n]
end
return generated_vars
end

function find_generated_vars_recursive_helper(g, n, generated_vars)
if n in generated_vars # already visited
return nothing
if is_observation(g, n)
can_reach_observations[n] = true
return true
end
for p in inneighbor_labels(g, n) # parents
if p in generated_vars # already visited
continue
end
if g[p].node_type == Stochastic
continue
end # p is a Logical Node
if !any(x -> g[x].node_type == Stochastic, outneighbor_labels(g, p)) # if the node has stochastic children, it is not a root
push!(generated_vars, p)

can_reach = false
for child in MetaGraphsNext.outneighbor_labels(g, n)
if dfs_can_reach_observations(g, child, can_reach_observations)
can_reach = true
break
end
find_generated_vars_recursive_helper(g, p, generated_vars)
end

can_reach_observations[n] = can_reach
return can_reach
end

"""
Expand Down
79 changes: 79 additions & 0 deletions src/three_color_graph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Pkg
Pkg.activate(; temp=true)
Pkg.add(["MetaGraphsNext", "GraphMakie", "Graphs", "GLMakie"])

using MetaGraphsNext
using Graphs
using GLMakie, GraphMakie

##

struct Node
color::Int
end

function generate_three_color_metagraph(num_nodes::Int, p::Float64)
g = MetaGraph(SimpleDiGraph(); label_type = Int, vertex_data_type = Node)

for i in 1:num_nodes
color = rand(1:3)
add_vertex!(g, i, Node(color))
end

for i in 1:num_nodes
for j in 1:num_nodes
if i != j && rand() < p
add_edge!(g, i, j)
end
end
end

return g
end

colors = [:red, :green, :blue]

g = generate_three_color_metagraph(10, 0.3)
while is_cyclic(g.graph)
g = generate_three_color_metagraph(10, 0.3)
end

graphplot(g.graph; node_color = [color_map[g[label_for(g, v)]].color for v in vertices(g.graph)], ilabels = [label_for(g, v) for v in vertices(g.graph)])

t_closure = Graphs.transitiveclosure(g.graph)

graphplot(t_closure; node_color = [colors[g[v].color] for v in MetaGraphsNext.labels(g)])

function get_boundary_nodes(g::MetaGraph)
t_closure = Graphs.transitiveclosure(g.graph)
n = nv(g.graph)
type_vector = zeros(Int, n) # Stores types of vertices

# Precompute types for all vertices
for v_id in vertices(g.graph)
_node = g[v_id] # Assuming labels are the same as vertex IDs
type_vector[v_id] = _node.color
end

vertices_of_type_2 = Int[]
vertices_of_type_3 = Int[]

for v_id in vertices(g.graph)
_type = type_vector[v_id]
if _type == 2 || _type == 3
# Check if any descendants have type 1
has_type1_descendant = any(type_vector[w] == 1 for w in outneighbors(t_closure, v_id))
if !has_type1_descendant
if _type == 2
push!(vertices_of_type_2, v_id)
else
push!(vertices_of_type_3, v_id)
end
end
end
end

return vertices_of_type_2, vertices_of_type_3
end

get_boundary_nodes(g)
56 changes: 55 additions & 1 deletion test/graphs.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,59 @@
using JuliaBUGS:
stochastic_inneighbors, stochastic_neighbors, stochastic_outneighbors, markov_blanket
stochastic_inneighbors, stochastic_neighbors, stochastic_outneighbors, markov_blanket, find_generated_quantities_variables

@testset "find_generated_quantities_variables" begin
struct TestNode
id::Int
end

JuliaBUGS.is_model_parameter(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) =
g[v].id == 1
JuliaBUGS.is_observation(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) =
g[v].id == 2
JuliaBUGS.is_deterministic(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) =
g[v].id == 3

function generate_random_dag(num_nodes::Int, p::Float64=0.3)
graph = SimpleGraph(num_nodes)
for i in 1:num_nodes
for j in 1:num_nodes
if i != j && rand() < p
add_edge!(graph, i, j)
end
end
end

graph = Graphs.random_orientation_dag(graph) # ensure the random graph is a DAG
vertices_description = [i => TestNode(rand(1:3)) for i in 1:nv(graph)]
edges_description = [Tuple(e) => nothing for e in Graphs.edges(graph)]
return MetaGraph(graph, vertices_description, edges_description)
end

# `transitiveclosure` has time complexity O(|E|⋅|V|), not fit for large graphs
# but easy to implement and understand, here we use it for reference
function find_generated_quantities_variables_with_transitive_closure(
g::MetaGraph{Int,<:SimpleDiGraph,Label,VertexData}
) where {Label,VertexData}
_transitive_closure = Graphs.transitiveclosure(g.graph)
generated_quantities_variables = Set{Label}()
for v_id in vertices(g.graph)
if !JuliaBUGS.is_observation(g, v_id)
if all(
!Base.Fix1(JuliaBUGS.is_observation, g), outneighbors(_transitive_closure, v_id)
)
push!(generated_quantities_variables, MetaGraphsNext.label_for(g, v_id))
end
end
end

return generated_quantities_variables
end

@testset "random DAG with $num_nodes nodes and $p probability of edge" for num_nodes in [10, 20, 100, 500, 1000], p in [0.1, 0.3, 0.5]
g = generate_random_dag(num_nodes, p)
@test find_generated_quantities_variables(g) == find_generated_quantities_variables_with_transitive_closure(g)
end
end

test_model = @bugs begin
a ~ dnorm(f, c)
Expand Down
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ using AdvancedHMC
using AdvancedMH
using Bijectors
using Distributions
using Graphs
using MetaGraphsNext
using Graphs, MetaGraphsNext
using LinearAlgebra
using LogDensityProblems
using LogDensityProblemsAD
Expand Down

0 comments on commit 2e17ddf

Please sign in to comment.