Skip to content

Commit

Permalink
Inline statically known method errors. (#54972)
Browse files Browse the repository at this point in the history
This replaces the `Expr(:call, ...)` with a call of a new builtin
`Core.throw_methoderror`

This is useful because it makes very clear if something is a static
method error or a plain dynamic dispatch that always errors.
Tools such as AllocCheck or juliac can notice that this is not a genuine
dynamic dispatch, and prevent it from becoming a false positive
compile-time error.

Dependent on #55705

---------

Co-authored-by: Cody Tapscott <[email protected]>
  • Loading branch information
gbaraldi and topolarity authored Sep 17, 2024
1 parent f808606 commit 61c044c
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 65 deletions.
41 changes: 21 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
rettype = exctype = Any
all_effects = Effects()
else
if (matches isa MethodMatches ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches)))
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
all_effects = Effects(all_effects; nothrow=false)
exctype = exctype ₚ MethodError
Expand Down Expand Up @@ -275,21 +274,23 @@ struct MethodMatches
applicable::Vector{Any}
info::MethodMatchInfo
valid_worlds::WorldRange
mt::MethodTable
fullmatch::Bool
end
any_ambig(info::MethodMatchInfo) = info.results.ambig
any_ambig(result::MethodLookupResult) = result.ambig
any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)

struct UnionSplitMethodMatches
applicable::Vector{Any}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
end
any_ambig(m::UnionSplitMethodMatches) = any(any_ambig, m.info.matches)
any_ambig(info::UnionSplitInfo) = any(any_ambig, info.matches)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(info.fullmatches)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)

function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
max_union_splitting::Int = InferenceParams(interp).max_union_splitting,
Expand All @@ -307,7 +308,7 @@ is_union_split_eligible(𝕃::AbstractLattice, argtypes::Vector{Any}, max_union_
function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any},
@nospecialize(atype), max_methods::Int)
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
infos = MethodMatchInfo[]
infos = MethodLookupResult[]
applicable = Any[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
Expand All @@ -323,29 +324,29 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
if matches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
push!(infos, MethodMatchInfo(matches))
push!(infos, matches)
for m in matches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
found = false
mt_found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
mt_found = true
break
end
end
if !found
if !mt_found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
end
info = UnionSplitInfo(infos)
info = UnionSplitInfo(infos, mts, fullmatches)
return UnionSplitMethodMatches(
applicable, applicable_argtypes, info, valid_worlds, mts, fullmatches)
applicable, applicable_argtypes, info, valid_worlds)
end

function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(atype), max_methods::Int)
Expand All @@ -360,10 +361,9 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
info = MethodMatchInfo(matches)
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
return MethodMatches(
matches.matches, info, matches.valid_worlds, mt, fullmatch)
info = MethodMatchInfo(matches, mt, fullmatch)
return MethodMatches(matches.matches, info, matches.valid_worlds)
end

"""
Expand Down Expand Up @@ -584,9 +584,10 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
matches::UnionSplitMethodMatches
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
Expand Down
52 changes: 30 additions & 22 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ struct InliningCase
end

struct UnionSplit
fully_covered::Bool
handled_all_cases::Bool # All possible dispatches are included in the cases
fully_covered::Bool # All handled cases are fully covering
atype::DataType
cases::Vector{InliningCase}
bbs::Vector{Int}
UnionSplit(fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
new(fully_covered, atype, cases, Int[])
UnionSplit(handled_all_cases::Bool, fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
new(handled_all_cases, fully_covered, atype, cases, Int[])
end

struct InliningEdgeTracker
Expand Down Expand Up @@ -215,7 +216,7 @@ end

function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
state::CFGInliningState, params::OptimizationParams)
(; fully_covered, #=atype,=# cases, bbs) = union_split
(; handled_all_cases, fully_covered, #=atype,=# cases, bbs) = union_split
inline_into_block!(state, block_for_inst(ir, idx))
from_bbs = Int[]
delete!(state.split_targets, length(state.new_cfg_blocks))
Expand All @@ -235,7 +236,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
end
end
push!(from_bbs, length(state.new_cfg_blocks))
if !(i == length(cases) && fully_covered)
if !(i == length(cases) && (handled_all_cases && fully_covered))
# This block will have the next condition or the final else case
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
Expand All @@ -244,7 +245,10 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, union_split::UnionSplit,
end
end
# The edge from the fallback block.
fully_covered || push!(from_bbs, length(state.new_cfg_blocks))
# NOTE This edge is only required for `!handled_all_cases` and not `!fully_covered`,
# since in the latter case we inline `Core.throw_methoderror` into the fallback
# block, which is must-throw, making the subsequent code path unreachable.
!handled_all_cases && push!(from_bbs, length(state.new_cfg_blocks))
# This block will be the block everyone returns to
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx), from_bbs, orig_succs))
join_bb = length(state.new_cfg_blocks)
Expand Down Expand Up @@ -523,7 +527,7 @@ assuming their order stays the same post-discovery in `ml_matches`.
function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any},
union_split::UnionSplit, boundscheck::Symbol,
todo_bbs::Vector{Tuple{Int,Int}}, interp::AbstractInterpreter)
(; fully_covered, atype, cases, bbs) = union_split
(; handled_all_cases, fully_covered, atype, cases, bbs) = union_split
stmt, typ, line = compact.result[idx][:stmt], compact.result[idx][:type], compact.result[idx][:line]
join_bb = bbs[end]
pn = PhiNode()
Expand All @@ -538,7 +542,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
cond = true
nparams = fieldcount(atype)
@assert nparams == fieldcount(mtype)
if !(i == ncases && fully_covered)
if !(i == ncases && fully_covered && handled_all_cases)
for i = 1:nparams
aft, mft = fieldtype(atype, i), fieldtype(mtype, i)
# If this is always true, we don't need to check for it
Expand Down Expand Up @@ -597,14 +601,18 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::
end
bb += 1
# We're now in the fall through block, decide what to do
if !fully_covered
if !handled_all_cases
ssa = insert_node_here!(compact, NewInstruction(stmt, typ, line))
push!(pn.edges, bb)
push!(pn.values, ssa)
insert_node_here!(compact, NewInstruction(GotoNode(join_bb), Any, line))
finish_current_bb!(compact, 0)
elseif !fully_covered
insert_node_here!(compact, NewInstruction(Expr(:call, GlobalRef(Core, :throw_methoderror), argexprs...), Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
finish_current_bb!(compact, 0)
ncases == 0 && return insert_node_here!(compact, NewInstruction(nothing, Any, line))
end

# We're now in the join block.
return insert_node_here!(compact, NewInstruction(pn, typ, line))
end
Expand Down Expand Up @@ -1348,10 +1356,6 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
handled_all_cases = false
continue
end
local split_fully_covered = false
for (j, match) in enumerate(meth)
Expand Down Expand Up @@ -1392,22 +1396,26 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
handled_all_cases &= handle_any_const_result!(cases,
result, match, argtypes, info, flag, state; allow_typevars=true)
end
if !fully_covered
atype = argtypes_to_type(sig.argtypes)
# We will emit an inline MethodError so we need a backedge to the MethodTable
add_uncovered_edges!(state.edges, info, atype)
end
elseif !isempty(cases)
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

return cases, (handled_all_cases & fully_covered), joint_effects
return cases, handled_all_cases, fully_covered, joint_effects
end

function handle_call!(todo::Vector{Pair{Int,Any}},
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature,
state::InliningState)
cases = compute_inlining_cases(info, flag, sig, state)
cases === nothing && return nothing
cases, all_covered, joint_effects = cases
cases, handled_all_cases, fully_covered, joint_effects = cases
atype = argtypes_to_type(sig.argtypes)
handle_cases!(todo, ir, idx, stmt, atype, cases, all_covered, joint_effects)
handle_cases!(todo, ir, idx, stmt, atype, cases, handled_all_cases, fully_covered, joint_effects)
end

function handle_match!(cases::Vector{InliningCase},
Expand Down Expand Up @@ -1496,19 +1504,19 @@ function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallIn
end

function handle_cases!(todo::Vector{Pair{Int,Any}}, ir::IRCode, idx::Int, stmt::Expr,
@nospecialize(atype), cases::Vector{InliningCase}, all_covered::Bool,
@nospecialize(atype), cases::Vector{InliningCase}, handled_all_cases::Bool, fully_covered::Bool,
joint_effects::Effects)
# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if all_covered && length(cases) == 1
if fully_covered && handled_all_cases && length(cases) == 1
handle_single_case!(todo, ir, idx, stmt, cases[1].item)
elseif length(cases) > 0
elseif length(cases) > 0 || handled_all_cases
isa(atype, DataType) || return nothing
for case in cases
isa(case.sig, DataType) || return nothing
end
push!(todo, idx=>UnionSplit(all_covered, atype, cases))
push!(todo, idx=>UnionSplit(handled_all_cases, fully_covered, atype, cases))
else
add_flag!(ir[SSAValue(idx)], flags_for_effects(joint_effects))
end
Expand Down
17 changes: 14 additions & 3 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ not a call to a generic function.
"""
struct MethodMatchInfo <: CallInfo
results::MethodLookupResult
mt::MethodTable
fullmatch::Bool
end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
getresult_impl(::MethodMatchInfo, ::Int) = nothing
add_uncovered_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype)) = (!info.fullmatch && push!(edges, info.mt, atype); )

"""
info::UnionSplitInfo <: CallInfo
Expand All @@ -48,20 +51,27 @@ each partition (`info.matches::Vector{MethodMatchInfo}`).
This info is illegal on any statement that is not a call to a generic function.
"""
struct UnionSplitInfo <: CallInfo
matches::Vector{MethodMatchInfo}
matches::Vector{MethodLookupResult}
mts::Vector{MethodTable}
fullmatches::Vector{Bool}
end

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.matches
n += nmatches(mminfo)
n += length(mminfo)
end
return n
end
nsplit_impl(info::UnionSplitInfo) = length(info.matches)
getsplit_impl(info::UnionSplitInfo, idx::Int) = getsplit_impl(info.matches[idx], 1)
getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx]
getresult_impl(::UnionSplitInfo, ::Int) = nothing
function add_uncovered_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype))
for (mt, fullmatch) in zip(info.mts, info.fullmatches)
!fullmatch && push!(edges, mt, atype)
end
end

abstract type ConstResult end

Expand Down Expand Up @@ -105,6 +115,7 @@ end
nsplit_impl(info::ConstCallInfo) = nsplit(info.call)
getsplit_impl(info::ConstCallInfo, idx::Int) = getsplit(info.call, idx)
getresult_impl(info::ConstCallInfo, idx::Int) = info.results[idx]
add_uncovered_edges_impl(edges::Vector{Any}, info::ConstCallInfo, @nospecialize(atype)) = add_uncovered_edges!(edges, info.call, atype)

"""
info::MethodResultPure <: CallInfo
Expand Down
7 changes: 3 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2983,9 +2983,9 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
if isa(matches, MethodMatches)
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
fully_covering(matches) || add_mt_backedge!(sv, matches.info.mt, atype)
else
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
for (thisfullmatch, mt) in zip(matches.info.fullmatches, matches.info.mts)
thisfullmatch || add_mt_backedge!(sv, mt, atype)
end
end
Expand All @@ -3001,8 +3001,7 @@ function abstract_applicable(interp::AbstractInterpreter, argtypes::Vector{Any},
add_backedge!(sv, edge)
end

if isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
if !fully_covering(matches) || any_ambig(matches)
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
rt = Bool
end
Expand Down
6 changes: 6 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ abstract type CallInfo end

nsplit(info::CallInfo) = nsplit_impl(info)::Union{Nothing,Int}
getsplit(info::CallInfo, idx::Int) = getsplit_impl(info, idx)::MethodLookupResult
add_uncovered_edges!(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = add_uncovered_edges_impl(edges, info, atype)

getresult(info::CallInfo, idx::Int) = getresult_impl(info, idx)

# must implement `nsplit`, `getsplit`, and `add_uncovered_edges!` to opt in to inlining
nsplit_impl(::CallInfo) = nothing
getsplit_impl(::CallInfo, ::Int) = error("unexpected call into `getsplit`")
add_uncovered_edges_impl(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = error("unexpected call into `add_uncovered_edges!`")

# must implement `getresult` to opt in to extended lattice return information
getresult_impl(::CallInfo, ::Int) = nothing

@specialize
9 changes: 9 additions & 0 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ end
CC.nsplit_impl(info::NoinlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoinlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoinlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::NoinlineCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)

function CC.abstract_call(interp::NoinlineInterpreter,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int)
Expand All @@ -431,6 +432,8 @@ end
@inline function inlined_usually(x, y, z)
return x * y + z
end
foo_split(x::Float64) = 1
foo_split(x::Int) = 2

# check if the inlining algorithm works as expected
let src = code_typed1((Float64,Float64,Float64)) do x, y, z
Expand All @@ -444,6 +447,7 @@ let NoinlineModule = Module()
main_func(x, y, z) = inlined_usually(x, y, z)
@eval NoinlineModule noinline_func(x, y, z) = $inlined_usually(x, y, z)
@eval OtherModule other_func(x, y, z) = $inlined_usually(x, y, z)
@eval NoinlineModule bar_split_error() = $foo_split(Core.compilerbarrier(:type, nothing))

interp = NoinlineInterpreter(Set((NoinlineModule,)))

Expand Down Expand Up @@ -473,6 +477,11 @@ let NoinlineModule = Module()
@test count(isinvoke(:inlined_usually), src.code) == 0
@test count(iscall((src, inlined_usually)), src.code) == 0
end

let src = code_typed1(NoinlineModule.bar_split_error)
@test count(iscall((src, foo_split)), src.code) == 0
@test count(iscall((src, Core.throw_methoderror)), src.code) > 0
end
end

# Make sure that Core.Compiler has enough NamedTuple infrastructure
Expand Down
Loading

0 comments on commit 61c044c

Please sign in to comment.