Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ale/3.0 proof #223

Open
wants to merge 7 commits into
base: ale/3.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ end
Base.hash(a::IdKey, h::UInt) = xor(a.val, h)
Base.:(==)(a::IdKey, b::IdKey) = a.val == b.val

include("proof.jl")

"""
EGraph{ExpressionType,Analysis}

Expand All @@ -114,12 +116,24 @@ See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)
for implementation details.
"""
mutable struct EGraph{ExpressionType,Analysis}
# TODO use Base.@kwdef without clashing methods below
"stores the equality relations over e-class ids"
uf::UnionFind
"map from eclass id to eclasses"
classes::Dict{IdKey,EClass{Analysis}}
"hashcons mapping e-nodes to their e-class id"
"hashcons mapping e-node hashes to their e-class id"
memo::Dict{VecExpr,Id}
"""
Stores the original e-nodes at the index of their uncanonical id.
The uncanonical id of an e-node, is the id of the e-class it was originally added to.
Since new eclasses are created with a fresh id every time a node is added to the e-graph,
that id will be the uncanonical id for a given e-node. The e-class id will instead change
to the canonical one, if the newly added e-class containing a single node is merged
to another e-class, which happens at the merge phase of the application equality saturation step.
"""
nodes::Vector{VecExpr}
"If proofs are enabled, this holds the `EGraphProof` additional unionfind."
proof::Union{EGraphProof,Nothing}
"Hashcons the constants in the e-graph"
constants::Dict{UInt64,Any}
"Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass."
Expand All @@ -139,11 +153,13 @@ end
EGraph(expr)
Construct an EGraph from a starting symbolic expression `expr`.
"""
function EGraph{ExpressionType,Analysis}(; needslock::Bool = false) where {ExpressionType,Analysis}
function EGraph{ExpressionType,Analysis}(; needslock::Bool = false, proof::Bool = false) where {ExpressionType,Analysis}
EGraph{ExpressionType,Analysis}(
UnionFind(),
Dict{IdKey,EClass{Analysis}}(),
Dict{VecExpr,Id}(),
VecExpr[],
proof ? EGraphProof() : nothing,
Dict{UInt64,Any}(),
Pair{VecExpr,Id}[],
UniqueQueue{Pair{VecExpr,Id}}(),
Expand Down Expand Up @@ -215,6 +231,17 @@ function Base.show(io::IO, g::EGraph)
end


function print_proof(g::EGraph)
# Print memo
println("explain_find:")
println.(g.proof.explain_find)
println("uncanon_memo: ")
for (n, id) in g.proof.uncanon_memo
println("\t", to_expr(g, n), " => ", reinterpret(Int, id))
end
end
export print_proof

"""
Returns the canonical e-class id for a given e-class.
"""
Expand Down Expand Up @@ -280,6 +307,8 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
end

id = push!(g.uf) # create new singleton eclass
@assert length(g.nodes) == id - 1
push!(g.nodes, n)

if v_isexpr(n)
for c_id in v_children(n)
Expand All @@ -295,6 +324,10 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
modify!(g, eclass)
push!(g.pending, n => id)

if !isnothing(g.proof)
add!(g.proof, n, id, id)
end

return id
end

Expand Down Expand Up @@ -345,25 +378,34 @@ end

"""
Given an [`EGraph`](@ref) and two e-class ids, set
the two e-classes as equal.
the two e-classes as equal. `rule_idx` argument is optional, it's used in
proof production to justify why two e-classes were merged together.
By default `rule_idx` is equal to 0, and it means that if proofs are enabled,
then the union was performed by congruence closure invariant maintenance (rebuilding).
"""
function Base.union!(
g::EGraph{ExpressionType,AnalysisType},
enode_id1::Id,
enode_id2::Id,
rule_idx::Int = 0,
)::Bool where {ExpressionType,AnalysisType}
g.clean = false

id_1 = IdKey(find(g, enode_id1))
id_2 = IdKey(find(g, enode_id2))

id_1 == id_2 && return false
# TODO if ids already equal should add an alternate rewrite call to proof.

# Make sure class 2 has fewer parents
if length(g.classes[id_1].parents) < length(g.classes[id_2].parents)
id_1, id_2 = id_2, id_1
end

if !isnothing(g.proof)
union!(g.proof, enode_id1, enode_id2, rule_idx)
end

union!(g.uf, id_1.val, id_2.val)

eclass_2 = pop!(g.classes, id_2)::EClass
Expand Down Expand Up @@ -434,7 +476,8 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
if haskey(g.memo, node)
old_class_id = g.memo[node]
g.memo[node] = eclass_id
did_something = union!(g, old_class_id, eclass_id)
# any_new_rhs should be false
did_something = union!(g, old_class_id, eclass_id, 0) # By congruence
# TODO unique! can node dedup be moved here? compare performance
# did_something && unique!(g[eclass_id].nodes)
n_unions += did_something
Expand Down
158 changes: 158 additions & 0 deletions src/EGraphs/proof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
export ProofConnection, ProofNode, EGraphProof, find_flat_proof

mutable struct ProofConnection
"""
Justification can be
- 0 if the connection is justified by congruence
- Positive integer is index of the rule in theory, applied left-to-right.
- Negative integer is same as above, applied right-to-left.
The absolute value is thus the rule id.
"""
justification::Int
# Next is equal to itself on leaves of the proof tree
# i.e. only the identity (congruence) is a valid proof
next::Id
current::Id
end

function Base.show(io::IO, p::ProofConnection)
j = abs(p.justification)
p.justification == 0 && return print(io, "($(p.current) ≡ $(p.next))")
p.justification < 0 && return print(io, "($(p.current) <-$j- $(p.next))")
print(io, "($(p.current) -$j-> $(p.next))")
end


mutable struct ProofNode
existence_node::Id
# TODO is this the parent in the unionfind?
parent_connection::ProofConnection
# TODO Always includes parent ??????
neighbours::Vector{ProofConnection}
end

function Base.show(io::IO, p::ProofNode)
print(io, "ProofNode(")
print(io, p.existence_node, ", ")
print(io, p.parent_connection, ", ")
print(io, p.neighbours, ")")
end


Base.@kwdef struct EGraphProof
explain_find::Vector{ProofNode} = ProofNode[]
uncanon_memo::Dict{VecExpr,Id} = Dict{VecExpr,Id}()
end

# TODO find better name for existence_node and set
function add!(proof::EGraphProof, n::VecExpr, set::Id, existence_node::Id)
# Insert in the uncanonical memo
# TODO explain why
proof.uncanon_memo[n] = set

# New proof node does not have any neighbours
# Parent connection is by congruence, to the same id
proof_node = ProofNode(existence_node, ProofConnection(0, set, set), ProofConnection[])
push!(proof.explain_find, proof_node)
set
end

# Returns true if it did something
function make_leader(proof::EGraphProof, node::Id)::Bool
proof_node = proof.explain_find[node]
# Next is equal to itself on leaves of the proof tree
# i.e. only the identity (congruence) is a valid proof
# TODO we should change the type
next = proof_node.parent_connection.next
next == node && return false

make_leader(proof, next)
# You need to re-fetch it if there's a circular proof?
# TODO adrian please expand.
proof_node = proof.explain_find[node]
old_parent_connection = proof_node.parent_connection
# Reverse the justification
new_parent_connection = ProofConnection(-old_parent_connection.justification, node, old_parent_connection.next)

proof.explain_find[next].parent_connection = new_parent_connection

true
end

function Base.union!(proof::EGraphProof, node1::Id, node2::Id, rule_idx::Int)
# TODO maybe should have extra argument called `rhs_new` in egg that is true when called from
# application of rules where the instantiation of the rhs creates new e-classes
# TODO if new_rhs set_existance_reason of node2 to node1

# Make node1 the root
make_leader(proof, node1)

proof_node1 = proof.explain_find[node1]
proof_node2 = proof.explain_find[node2]

proof.explain_find[node1].parent_connection.next = node2

pconnection = ProofConnection(abs(rule_idx), node2, node1)
other_pconnection = ProofConnection(-(abs(rule_idx)), node1, node2)


push!(proof_node1.neighbours, pconnection)
push!(proof_node2.neighbours, other_pconnection)

# TODO WAT???
proof_node1.parent_connection = pconnection
end

@inline isroot(pn::ProofNode) = isroot(pn.parent_connection)
@inline isroot(pc::ProofConnection) = pc.current === pc.next

function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)
# We're doing a lowest common ancestor search.
# We cache the IDs we have seen
seen_set = Set{Id}()
# Store the nodes seen from node1 and node2 in order
walk_from1 = ProofNode[]
walk_from2 = ProofNode[]

# No existence_node would ever have id 0
lca = UInt(0)
curr = proof.explain_find[node1]

# Walk up to the root
while true
push!(seen_set, curr.existence_node)
isroot(curr) ? break : push!(walk_from1, curr)
curr = proof.explain_find[curr.parent_connection.next]
end

curr = proof.explain_find[node2]
@show curr
# Walks up until an element of seen_set or root is found.
while true
println("WALKING 2")
@show curr.existence_node
@show seen_set
if curr.existence_node in seen_set
lca = curr.existence_node
@show lca
break
end

isroot(curr) ? break : push!(walk_from2, curr)
curr = proof.explain_find[curr.parent_connection.next]
end

ret = ProofNode[]
@show lca
# There's no LCA => there's no proof.
lca == 0 && return ret

for w in walk_from1
push!(ret, w)
w.existence_node == lca && break
end

# TODO maybe reverse
append!(ret, walk_from2)
ret
end
2 changes: 1 addition & 1 deletion src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ function eqsat_apply!(
break
end

res.l !== 0 && res.r !== 0 && union!(g, res.l, res.r)
res.l !== 0 && res.r !== 0 && union!(g, res.l, res.r, rule_idx)
end
if params.goal(g)
@debug "Goal reached"
Expand Down
1 change: 1 addition & 0 deletions test/egraphs/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ end
params = SaturationParams(timeout = 15)
saturate!(g, t, params)
extr = extract!(g, astsize)
@show extr
@test extr == :((12 * a) * b) ||
extr == :(12 * (a * b)) ||
extr == :(12 * (b * a)) ||
Expand Down
79 changes: 79 additions & 0 deletions test/egraphs/proof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Metatheory, Test

@testset begin
"Basic proofs by hand"
g = EGraph(; proof = true)

id_a = addexpr!(g, :a)

# print_proof(g)

id_b = addexpr!(g, :b)

union!(g, id_a, id_b, 1)

print_proof(g)

proof = find_flat_proof(g.proof, id_a, id_b)
@test length(proof) == 1

proof = find_flat_proof(g.proof, id_b, id_a)
@test length(proof) == 1


id_c = addexpr!(g, :c)
union!(g, id_b, id_c, 2)

print_proof(g)
proof = find_flat_proof(g.proof, id_a, id_b)
@test length(proof) == 2

id_d = addexpr!(g, :d)

union!(g, id_a, id_d, 3)
print_proof(g)

proof = find_flat_proof(g.proof, id_a, id_d)
@test length(proof) == 1


proof = find_flat_proof(g.proof, id_c, id_d)
@test length(proof) == 3

id_e = addexpr!(g, :e)
@test isempty(find_flat_proof(g.proof, id_a, id_e))
end

@testset "Basic rewriting proofs" begin
r = @rule f(~x) --> g(~x)
g = EGraph(; proof = true)
id_a = addexpr!(g, :a)
id_fa = addexpr!(g, :(f(a)))
id_ga = addexpr!(g, :(g(a)))

saturate!(g, RewriteRule[r])

g

print_proof(g)
proof = find_flat_proof(g.proof, id_fa, id_ga)
@test length(proof) == 1

# =====================

r = @rule :x == :y
g = EGraph(; proof = true)
id_x = addexpr!(g, :x)
id_y = addexpr!(g, :y)
id_fx = addexpr!(g, :(f(x)))
id_fy = addexpr!(g, :(f(y)))

saturate!(g, RewriteRule[r])

g

print_proof(g)
proof = find_flat_proof(g.proof, id_fx, id_fy)
@test length(proof) == 1
@test only(proof).parent_connection.justification === 0 # by congruence
end