diff --git a/src/graphs.jl b/src/graphs.jl
index 8064642b4..568d4cc2b 100644
--- a/src/graphs.jl
+++ b/src/graphs.jl
@@ -10,51 +10,49 @@ end
 """
     BUGSGraph
 
-The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for
-`MetaGraphsNext.MetaGraph`.
+The `BUGSGraph` object represents the graph structure for a BUGS model.
 """
-const BUGSGraph = MetaGraph
+const BUGSGraph = MetaGraph{
+    Int,Graphs.SimpleDiGraph{Int},<:VarName,<:NodeInfo,Nothing,Nothing,<:Any,Float64
+}
 
-"""
-    find_generated_vars(g::BUGSGraph)
-
-Return all the logical variables without stochastic descendants. The values of these variables
-do not affect sampling process. These variables are called "generated quantities" traditionally.
-"""
-function find_generated_vars(g)
-    graph_roots = VarName[] # root nodes of the graph
-    for n in labels(g)
-        if isempty(outneighbor_labels(g, n))
-            push!(graph_roots, n)
-        end
-    end
-
-    generated_vars = VarName[]
-    for n in graph_roots
-        if !g[n].is_stochastic
-            push!(generated_vars, n) # graph roots that are Logical nodes are generated variables
-            find_generated_vars_recursive_helper(g, n, generated_vars)
-        end
-    end
-    return generated_vars
-end
-
-function find_generated_vars_recursive_helper(g, n, generated_vars)
-    if n in generated_vars # already visited
-        return nothing
-    end
-    for p in inneighbor_labels(g, n) # parents
-        if p in generated_vars # already visited
-            continue
-        end
-        if g[p].node_type == Stochastic
-            continue
-        end # p is a Logical Node
-        if !any(x -> g[x].node_type == Stochastic, outneighbor_labels(g, p)) # if the node has stochastic children, it is not a root
-            push!(generated_vars, p)
-        end
-        find_generated_vars_recursive_helper(g, p, generated_vars)
-    end
-end
+# Node-kind predicates, decided from the `is_stochastic` / `is_observed` fields of `g[v]`.
+is_deterministic(g::BUGSGraph, v::VarName) = !g[v].is_stochastic
+is_observation(g::BUGSGraph, v::VarName) = g[v].is_stochastic && g[v].is_observed
+is_model_parameter(g::BUGSGraph, v::VarName) = g[v].is_stochastic && !g[v].is_observed
+
+"""
+    find_generated_quantities_variables(g)
+
+Return the set of labels of `g` that are not observations and from which no observation
+can be reached by following outgoing edges.
+"""
+function find_generated_quantities_variables(
+    g::MetaGraph{Int,<:SimpleDiGraph,Label,VertexData}
+) where {Label,VertexData}
+    reachability_cache = Dict{Label,Bool}() # memo shared by every search started below
+    generated = Set{Label}()
+    for label in labels(g)
+        is_observation(g, label) && continue
+        dfs_can_reach_observations(g, label, reachability_cache) && continue
+        push!(generated, label)
+    end
+    return generated
+end
+
+# Return whether `n` is an observation or an observation is reachable from `n` through
+# outgoing edges. Answers are memoized in `can_reach_observations` (label => Bool).
+function dfs_can_reach_observations(g, n, can_reach_observations)
+    haskey(can_reach_observations, n) && return can_reach_observations[n]
+
+    # `any` stops at the first child that reaches an observation.
+    reachable =
+        is_observation(g, n) || any(
+            child -> dfs_can_reach_observations(g, child, can_reach_observations),
+            MetaGraphsNext.outneighbor_labels(g, n),
+        )
+    can_reach_observations[n] = reachable
+    return reachable
+end
 
 """
diff --git a/src/three_color_graph.jl b/src/three_color_graph.jl
new file mode 100644
index 000000000..6f71b9662
--- /dev/null
+++ b/src/three_color_graph.jl
@@ -0,0 +1,79 @@
+using Pkg
+Pkg.activate(; temp=true)
+Pkg.add(["MetaGraphsNext", "GraphMakie", "Graphs", "GLMakie"])
+
+using MetaGraphsNext
+using Graphs
+using GLMakie, GraphMakie
+
+##
+
+struct Node
+    color::Int
+end
+
+# Build a random directed MetaGraph whose vertices are labeled 1:num_nodes and carry a
+# random color in 1:3; each ordered pair (i, j) with i != j gets an edge with probability p.
+function generate_three_color_metagraph(num_nodes::Int, p::Float64)
+    g = MetaGraph(SimpleDiGraph(); label_type = Int, vertex_data_type = Node)
+
+    for i in 1:num_nodes
+        color = rand(1:3)
+        add_vertex!(g, i, Node(color))
+    end
+
+    for i in 1:num_nodes
+        for j in 1:num_nodes
+            if i != j && rand() < p
+                add_edge!(g, i, j)
+            end
+        end
+    end
+
+    return g
+end
+
+colors = [:red, :green, :blue]
+
+g = generate_three_color_metagraph(10, 0.3)
+while is_cyclic(g.graph)
+    g = generate_three_color_metagraph(10, 0.3)
+end
+
+graphplot(g.graph; node_color = [colors[g[label_for(g, v)].color] for v in vertices(g.graph)], ilabels = [label_for(g, v) for v in vertices(g.graph)])
+
+t_closure = Graphs.transitiveclosure(g.graph)
+
+graphplot(t_closure; node_color = [colors[g[v].color] for v in MetaGraphsNext.labels(g)])
+
+# Return the vertex codes of the color-2 and color-3 vertices (as two vectors) that have
+# no color-1 vertex among their out-neighbors in the transitive closure of the graph.
+function get_boundary_nodes(g::MetaGraph)
+    t_closure = Graphs.transitiveclosure(g.graph)
+    n = nv(g.graph)
+    type_vector = zeros(Int, n) # Stores types of vertices
+
+    # Precompute types for all vertices
+    for v_id in vertices(g.graph)
+        _node = g[label_for(g, v_id)] # vertex data is keyed by label, not by vertex code
+        type_vector[v_id] = _node.color
+    end
+
+    vertices_of_type_2 = Int[]
+    vertices_of_type_3 = Int[]
+
+    for v_id in vertices(g.graph)
+        _type = type_vector[v_id]
+        if _type == 2 || _type == 3
+            # Check if any descendants have type 1
+            has_type1_descendant = any(type_vector[w] == 1 for w in outneighbors(t_closure, v_id))
+            if !has_type1_descendant
+                push!(_type == 2 ? vertices_of_type_2 : vertices_of_type_3, v_id)
+            end
+        end
+    end
+
+    return vertices_of_type_2, vertices_of_type_3
+end
+
+get_boundary_nodes(g)
diff --git a/test/graphs.jl b/test/graphs.jl
index e99d77081..3b6a55cf0 100644
--- a/test/graphs.jl
+++ b/test/graphs.jl
@@ -1,5 +1,66 @@
 using JuliaBUGS:
-    stochastic_inneighbors, stochastic_neighbors, stochastic_outneighbors, markov_blanket
+    stochastic_inneighbors, stochastic_neighbors, stochastic_outneighbors, markov_blanket, find_generated_quantities_variables
+
+# Minimal vertex payload for the random test graphs. `id` encodes the node kind:
+# 1 = model parameter, 2 = observation, 3 = deterministic. Struct definitions are only
+# allowed at top level, so this cannot live inside the `@testset` block below.
+struct TestNode
+    id::Int
+end
+
+@testset "find_generated_quantities_variables" begin
+    JuliaBUGS.is_model_parameter(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) =
+        g[v].id == 1
+    JuliaBUGS.is_observation(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) =
+        g[v].id == 2
+    JuliaBUGS.is_deterministic(g::MetaGraph{Int,<:SimpleDiGraph,Int,TestNode}, v::Int) =
+        g[v].id == 3
+
+    # Random DAG on `num_nodes` vertices (labels 1:num_nodes) with a random kind per vertex.
+    function generate_random_dag(num_nodes::Int, p::Float64=0.3)
+        graph = SimpleGraph(num_nodes)
+        for i in 1:num_nodes
+            for j in 1:num_nodes
+                if i != j && rand() < p
+                    add_edge!(graph, i, j)
+                end
+            end
+        end
+
+        graph = Graphs.random_orientation_dag(graph) # ensure the random graph is a DAG
+        vertices_description = [i => TestNode(rand(1:3)) for i in 1:nv(graph)]
+        edges_description = [Tuple(e) => nothing for e in Graphs.edges(graph)]
+        return MetaGraph(graph, vertices_description, edges_description)
+    end
+
+    # `transitiveclosure` has time complexity O(|E|⋅|V|), not fit for large graphs
+    # but easy to implement and understand, here we use it for reference
+    function find_generated_quantities_variables_with_transitive_closure(
+        g::MetaGraph{Int,<:SimpleDiGraph,Label,VertexData}
+    ) where {Label,VertexData}
+        _transitive_closure = Graphs.transitiveclosure(g.graph)
+        generated_quantities_variables = Set{Label}()
+        for v_id in vertices(g.graph)
+            # the node-kind predicates are keyed by label, so translate vertex codes first
+            label = MetaGraphsNext.label_for(g, v_id)
+            if !JuliaBUGS.is_observation(g, label)
+                if all(
+                    w -> !JuliaBUGS.is_observation(g, MetaGraphsNext.label_for(g, w)),
+                    outneighbors(_transitive_closure, v_id),
+                )
+                    push!(generated_quantities_variables, label)
+                end
+            end
+        end
+
+        return generated_quantities_variables
+    end
+
+    @testset "random DAG with $num_nodes nodes and $p probability of edge" for num_nodes in [10, 20, 100, 500, 1000], p in [0.1, 0.3, 0.5]
+        g = generate_random_dag(num_nodes, p)
+        @test find_generated_quantities_variables(g) == find_generated_quantities_variables_with_transitive_closure(g)
+    end
+end
 
 test_model = @bugs begin
     a ~ dnorm(f, c)
diff --git a/test/runtests.jl b/test/runtests.jl
index 0d12f58ab..4e9505c53 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,8 +11,7 @@ using AdvancedHMC
 using AdvancedMH
 using Bijectors
 using Distributions
-using Graphs
-using MetaGraphsNext
+using Graphs, MetaGraphsNext
 using LinearAlgebra
 using LogDensityProblems
 using LogDensityProblemsAD