Skip to content

Commit

Permalink
inference: don't allocate TryCatchFrame for `compute_trycatch(::IRC…
Browse files Browse the repository at this point in the history
…ode)`

`TryCatchFrame` is only required for the abstract interpretation and is
not necessary in `compute_trycatch` within slot2ssa.jl.
  • Loading branch information
aviatesk committed Dec 14, 2024
1 parent fe5ed17 commit ae129c6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
33 changes: 20 additions & 13 deletions Compiler/src/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,25 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization required
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization required
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization required

mutable struct TryCatchFrame
abstract type Handler end
mutable struct TryCatchFrame <: Handler
exct
scopet
const enter_idx::Int
scope_uses::Vector{Int}
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
end
TryCatchFrame(stmt::EnterNode, pc::Int) = TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc)
struct SimpleHandler <: Handler
enter_idx::Int
end
SimpleHandler(::EnterNode, pc::Int) = SimpleHandler(pc)
get_enter_idx(handler::Handler) = get_enter_idx_impl(handler)
get_enter_idx_impl((; enter_idx)::SimpleHandler) = enter_idx
get_enter_idx_impl((; enter_idx)::TryCatchFrame) = enter_idx

struct HandlerInfo
handlers::Vector{TryCatchFrame}
struct HandlerInfo{T<:Handler}
handlers::Vector{T}
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

Expand Down Expand Up @@ -261,7 +270,7 @@ mutable struct InferenceState
currbb::Int
currpc::Int
ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers
handler_info::Union{Nothing,HandlerInfo}
handler_info::Union{Nothing,HandlerInfo{TryCatchFrame}}
ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
Expand Down Expand Up @@ -318,7 +327,7 @@ mutable struct InferenceState

currbb = currpc = 1
ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1)
handler_info = compute_trycatch(code)
handler_info = compute_trycatch(TryCatchFrame, code)
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
Expand Down Expand Up @@ -421,10 +430,8 @@ is_inferred(result::InferenceResult) = result.result !== nothing

was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND

compute_trycatch(ir::IRCode) = compute_trycatch(ir.stmts.stmt, ir.cfg.blocks)

"""
compute_trycatch(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo}
compute_trycatch(Handler::Type{<:Handler}, code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo{Handler}}
Given the code of a function, compute, at every statement, the current
try/catch handler, and the current exception stack top. This function returns
Expand All @@ -433,9 +440,9 @@ a tuple of:
1. `handler_info.handler_at`: A statement length vector of tuples
`(catch_handler, exception_stack)`, which are indices into `handlers`
2. `handler_info.handlers`: A `TryCatchFrame` vector of handlers
2. `handler_info.handlers`: A `Handler` vector of handlers
"""
function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing)
function compute_trycatch(Handler::Type{<:Handler}, code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing)
# The goal initially is to record the frame like this for the state at exit:
# 1: (enter 3) # == 0
# 3: (expr) # == 1
Expand All @@ -454,10 +461,10 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi
stmt = code[pc]
if isa(stmt, EnterNode)
(;handlers, handler_at) = handler_info =
(handler_info === nothing ? HandlerInfo(TryCatchFrame[], fill((0, 0), n)) : handler_info)
(handler_info === nothing ? HandlerInfo{Handler}(Handler[], fill((0, 0), n)) : handler_info)
l = stmt.catch_dest
(bbs !== nothing) && (l = first(bbs[l].stmts))
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
push!(handlers, Handler(stmt, pc))
handler_id = length(handlers)
handler_at[pc + 1] = (handler_id, 0)
push!(ip, pc + 1)
Expand Down Expand Up @@ -526,7 +533,7 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi
end
cur_hand = cur_stacks[1]
for i = 1:l
cur_hand = handler_at[handlers[cur_hand].enter_idx][1]
cur_hand = handler_at[get_enter_idx(handlers[cur_hand])][1]
end
cur_stacks = (cur_hand, cur_stacks[2])
cur_stacks == (0, 0) && break
Expand Down
5 changes: 2 additions & 3 deletions Compiler/src/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
end

# Record the correct exception handler for all critical sections
handler_info = compute_trycatch(code)
handler_info = compute_trycatch(SimpleHandler, code)

phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)]
Expand Down Expand Up @@ -801,8 +801,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
has_pinode[id] = false
enter_idx = idx
while (handler = gethandler(handler_info, enter_idx)) !== nothing
(; enter_idx) = handler
leave_block = block_for_inst(cfg, (code[enter_idx]::EnterNode).catch_dest)
leave_block = block_for_inst(cfg, (code[get_enter_idx(handler)]::EnterNode).catch_dest)
cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block])
if cidx !== nothing
node = thisdef ? UpsilonNode(thisval) : UpsilonNode()
Expand Down
4 changes: 2 additions & 2 deletions Compiler/test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4435,8 +4435,8 @@ let x = Tuple{Int,Any}[
#=19=# (0, Expr(:pop_exception, Core.SSAValue(2)))
#=20=# (0, Core.ReturnNode(Core.SlotNumber(3)))
]
(;handler_at, handlers) = Compiler.compute_trycatch(last.(x))
@test map(x->x[1] == 0 ? 0 : handlers[x[1]].enter_idx, handler_at) == first.(x)
(;handler_at, handlers) = Compiler.compute_trycatch(Compiler.SimpleHandler, last.(x))
@test map(x->x[1] == 0 ? 0 : Compiler.get_enter_idx(handlers[x[1]]), handler_at) == first.(x)
end

@test only(Base.return_types((Bool,)) do y
Expand Down

0 comments on commit ae129c6

Please sign in to comment.