From d44784d254c256a1cfc92fdb6b7fa8aea6bcd680 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sat, 14 Dec 2024 23:29:11 +0900 Subject: [PATCH] inference: don't allocate `TryCatchFrame` for `compute_trycatch(::IRCode)` `TryCatchFrame` is only required for the abstract interpretation and is not necessary in `compute_trycatch` within slot2ssa.jl. --- Compiler/src/inferencestate.jl | 34 ++++++++++++++++++++++------------ Compiler/src/ssair/slot2ssa.jl | 5 ++--- Compiler/test/inference.jl | 4 ++-- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/Compiler/src/inferencestate.jl b/Compiler/src/inferencestate.jl index 6988e74310fc5..c42e291760c39 100644 --- a/Compiler/src/inferencestate.jl +++ b/Compiler/src/inferencestate.jl @@ -219,16 +219,25 @@ const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization required const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization required const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization required -mutable struct TryCatchFrame +abstract type Handler end +mutable struct TryCatchFrame <: Handler exct scopet const enter_idx::Int scope_uses::Vector{Int} TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx) end +TryCatchFrame(stmt::EnterNode, pc::Int) = TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc) +struct SimpleHandler <: Handler + enter_idx::Int +end +SimpleHandler(::EnterNode, pc::Int) = SimpleHandler(pc) +get_enter_idx(handler::Handler) = get_enter_idx_impl(handler) +get_enter_idx_impl((; enter_idx)::SimpleHandler) = enter_idx +get_enter_idx_impl((; enter_idx)::TryCatchFrame) = enter_idx -struct HandlerInfo - handlers::Vector{TryCatchFrame} +struct HandlerInfo{T<:Handler} + handlers::Vector{T} handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc end @@ -261,7 +270,7 @@ mutable struct InferenceState currbb::Int currpc::Int ip::BitSet#=TODO BoundedMinPrioritySet=# # current active instruction pointers - handler_info::Union{Nothing,HandlerInfo} + handler_info::Union{Nothing,HandlerInfo{TryCatchFrame}} ssavalue_uses::Vector{BitSet} # ssavalue sparsity and restart info # TODO: Could keep this sparsely by doing structural liveness analysis ahead of time. bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet @@ -318,7 +327,7 @@ mutable struct InferenceState currbb = currpc = 1 ip = BitSet(1) # TODO BitSetBoundedMinPrioritySet(1) - handler_info = compute_trycatch(code) + handler_info = compute_trycatch(TryCatchFrame, code) nssavalues = src.ssavaluetypes::Int ssavalue_uses = find_ssavalue_uses(code, nssavalues) nstmts = length(code) @@ -421,10 +430,11 @@ is_inferred(result::InferenceResult) = result.result !== nothing was_reached(sv::InferenceState, pc::Int) = sv.ssavaluetypes[pc] !== NOT_FOUND -compute_trycatch(ir::IRCode) = compute_trycatch(ir.stmts.stmt, ir.cfg.blocks) +compute_trycatch(ir::IRCode) = compute_trycatch(SimpleHandler, ir) +compute_trycatch(Handler::Type{<:Handler}, ir::IRCode) = compute_trycatch(Handler, ir.stmts.stmt, ir.cfg.blocks) """ - compute_trycatch(code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo} + compute_trycatch(Handler::Type{<:Handler}, code, [, bbs]) -> handler_info::Union{Nothing,HandlerInfo{Handler}} Given the code of a function, compute, at every statement, the current try/catch handler, and the current exception stack top. This function returns @@ -433,9 +443,9 @@ a tuple of: 1. `handler_info.handler_at`: A statement length vector of tuples `(catch_handler, exception_stack)`, which are indices into `handlers` - 2. `handler_info.handlers`: A `TryCatchFrame` vector of handlers + 2. `handler_info.handlers`: A `Handler` vector of handlers """ -function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing) +function compute_trycatch(Handler::Type{<:Handler}, code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothing}=nothing) # The goal initially is to record the frame like this for the state at exit: # 1: (enter 3) # == 0 # 3: (expr) # == 1 @@ -454,10 +464,10 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi stmt = code[pc] if isa(stmt, EnterNode) (;handlers, handler_at) = handler_info = - (handler_info === nothing ? HandlerInfo(TryCatchFrame[], fill((0, 0), n)) : handler_info) + (handler_info === nothing ? HandlerInfo{Handler}(Handler[], fill((0, 0), n)) : handler_info) l = stmt.catch_dest (bbs !== nothing) && (l = first(bbs[l].stmts)) - push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc)) + push!(handlers, Handler(stmt, pc)) handler_id = length(handlers) handler_at[pc + 1] = (handler_id, 0) push!(ip, pc + 1) @@ -526,7 +536,7 @@ function compute_trycatch(code::Vector{Any}, bbs::Union{Vector{BasicBlock},Nothi end cur_hand = cur_stacks[1] for i = 1:l - cur_hand = handler_at[handlers[cur_hand].enter_idx][1] + cur_hand = handler_at[get_enter_idx(handlers[cur_hand])][1] end cur_stacks = (cur_hand, cur_stacks[2]) cur_stacks == (0, 0) && break diff --git a/Compiler/src/ssair/slot2ssa.jl b/Compiler/src/ssair/slot2ssa.jl index 6fc87934d3bc5..3213e47ccefa9 100644 --- a/Compiler/src/ssair/slot2ssa.jl +++ b/Compiler/src/ssair/slot2ssa.jl @@ -569,7 +569,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState, end # Record the correct exception handler for all critical sections - handler_info = compute_trycatch(code) + handler_info = compute_trycatch(SimpleHandler, code) phi_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)] live_slots = Vector{Int}[Int[] for _ = 1:length(ir.cfg.blocks)] @@ -801,8 +801,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState, has_pinode[id] = false enter_idx = idx while (handler = gethandler(handler_info, enter_idx)) !== nothing - (; enter_idx) = handler - leave_block = block_for_inst(cfg, (code[enter_idx]::EnterNode).catch_dest) + leave_block = block_for_inst(cfg, (code[get_enter_idx(handler)]::EnterNode).catch_dest) cidx = findfirst((; slot)::NewPhiCNode2->slot_id(slot)==id, new_phic_nodes[leave_block]) if cidx !== nothing node = thisdef ? UpsilonNode(thisval) : UpsilonNode() diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index c8b599adb1323..bed64ad225829 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -4435,8 +4435,8 @@ let x = Tuple{Int,Any}[ #=19=# (0, Expr(:pop_exception, Core.SSAValue(2))) #=20=# (0, Core.ReturnNode(Core.SlotNumber(3))) ] - (;handler_at, handlers) = Compiler.compute_trycatch(last.(x)) - @test map(x->x[1] == 0 ? 0 : handlers[x[1]].enter_idx, handler_at) == first.(x) + (;handler_at, handlers) = Compiler.compute_trycatch(Compiler.SimpleHandler, last.(x)) + @test map(x->x[1] == 0 ? 0 : Compiler.get_enter_idx(handlers[x[1]]), handler_at) == first.(x) end @test only(Base.return_types((Bool,)) do y