Skip to content

Commit

Permalink
Make an inference hot-path slightly faster (#44421)
Browse files Browse the repository at this point in the history
This aims to improve performance of inference slightly by removing
a dynamic dispatch from calls to `widenwrappedconditional`, which
appears in various hot paths and showed up in profiling of inference.

There's two changes here:

1. Improve inlining for calls to functions of the form
```
f(x::Int) = 1
f(@nospecialize(x::Any)) = 2
```
Previously, we would peel of the `x::Int` case and then
generate a dynamic dispatch for the `x::Any` case. After
this change, we directly emit an `:invoke` for the `x::Any`
case (as well as enabling inlining of it in general).

2. Refactor `widenwrappedconditional` itself to avoid a signature
with a union in it, since ironically union splitting cannot currently
deal with that (it can only split unions if they're manifest in the
call arguments).
  • Loading branch information
Keno authored Mar 3, 2022
1 parent ea1b9cf commit 96d6d86
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 30 deletions.
73 changes: 52 additions & 21 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int,
push!(from_bbs, length(state.new_cfg_blocks))
# TODO: Right now we unconditionally generate a fallback block
# in case of subtyping errors - This is probably unnecessary.
if i != length(cases) || (!fully_covered || !params.trust_inference)
if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig)))
# This block will have the next condition or the final else case
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
Expand Down Expand Up @@ -481,7 +481,8 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
cond = true
aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector
@assert length(aparams) == length(mparams)
if i != length(cases) || !fully_covered || !params.trust_inference
if i != length(cases) || !fully_covered ||
(!params.trust_inference && isdispatchtuple(cases[i].sig))
for i in 1:length(aparams)
a, m = aparams[i], mparams[i]
# If this is always true, we don't need to check for it
Expand Down Expand Up @@ -538,7 +539,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
bb += 1
# We're now in the fall through block, decide what to do
if fully_covered
if !params.trust_inference
if !params.trust_inference && isdispatchtuple(cases[end].sig)
e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR)
insert_node_here!(compact, NewInstruction(e, Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
Expand Down Expand Up @@ -1170,7 +1171,10 @@ function analyze_single_call!(
cases = InliningCase[]
local only_method = nothing # keep track of whether there is one matching method
local meth::MethodLookupResult
local fully_covered = true
local handled_all_cases = true
local any_covers_full = false
local revisit_idx = nothing

for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
Expand All @@ -1179,7 +1183,7 @@ function analyze_single_call!(
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
fully_covered = false
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
Expand All @@ -1192,16 +1196,43 @@ function analyze_single_call!(
only_method = false
end
end
for match in meth
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
fully_covered &= match.fully_covers
for (j, match) in enumerate(meth)
any_covers_full |= match.fully_covers
if !isdispatchtuple(match.spec_types)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
end
end
end

# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple

atype = argtypes_to_type(argtypes)
if length(cases) == 0 && only_method isa Method
if handled_all_cases && revisit_idx !== nothing
# If there's only one case that's not a dispatchtuple, we can
# still unionsplit by visiting all the other cases first.
# This is useful for code like:
# foo(x::Int) = 1
# foo(@nospecialize(x::Any)) = 2
# where we where only a small number of specific dispatchable
# cases are split off from an ::Any typed fallback.
(i, j) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple.
# -- But don't try it if we already tried to handle the match in the revisit_idx
# case, because that'll (necessarily) be the same method.
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
atype, only_method.sig)::SimpleVector
Expand All @@ -1213,10 +1244,10 @@ function analyze_single_call!(
item = analyze_method!(match, argtypes, flag, state)
item === nothing && return nothing
push!(cases, InliningCase(match.spec_types, item))
fully_covered = match.fully_covers
any_covers_full = handled_all_cases = match.fully_covers
end

handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params)
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
end

# similar to `analyze_single_call!`, but with constant results
Expand All @@ -1227,7 +1258,8 @@ function handle_const_call!(
(; call, results) = cinfo
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local fully_covered = true
local handled_all_cases = true
local any_covers_full = false
local j = 0
for i in 1:length(infos)
meth = infos[i].results
Expand All @@ -1237,22 +1269,22 @@ function handle_const_call!(
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
fully_covered = false
handled_all_cases = false
continue
end
for match in meth
j += 1
result = results[j]
any_covers_full |= match.fully_covers
if isa(result, ConstResult)
case = const_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, InferenceResult)
fully_covered &= handle_inf_result!(result, argtypes, flag, state, cases)
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases)
else
@assert result === nothing
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
handled_all_cases &= isdispatchtuple(match.spec_types) && handle_match!(match, argtypes, flag, state, cases)
end
fully_covered &= match.fully_covers
end
end

Expand All @@ -1265,17 +1297,16 @@ function handle_const_call!(
validate_sparams(mi.sparam_vals) || return nothing
item === nothing && return nothing
push!(cases, InliningCase(mi.specTypes, item))
fully_covered = atype <: mi.specTypes
any_covers_full = handled_all_cases = atype <: mi.specTypes
end

handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params)
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase})
spec_types = match.spec_types
isdispatchtuple(spec_types) || return false
item = analyze_method!(match, argtypes, flag, state)
item === nothing && return false
_any(case->case.sig === spec_types, cases) && return true
Expand Down
18 changes: 10 additions & 8 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,17 @@ end
@inline tchanged(@nospecialize(n), @nospecialize(o)) = o === NOT_FOUND || (n !== NOT_FOUND && !(n o))
@inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n::VarState, o::VarState)))

widenconditional(@nospecialize typ) = typ
function widenconditional(typ::AnyConditional)
if typ.vtype === Union{}
return Const(false)
elseif typ.elsetype === Union{}
return Const(true)
else
return Bool
function widenconditional(@nospecialize typ)
if isa(typ, AnyConditional)
if typ.vtype === Union{}
return Const(false)
elseif typ.elsetype === Union{}
return Const(true)
else
return Bool
end
end
return typ
end
widenconditional(t::LimitedAccuracy) = error("unhandled LimitedAccuracy")

Expand Down
9 changes: 9 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,12 @@ end
let src = code_typed1(f44200)
@test count(x -> isa(x, Core.PiNode), src.code) == 0
end

# Test that peeling off one case from (::Any) doesn't introduce
# a dynamic dispatch.
@noinline f_peel(x::Int) = Base.inferencebarrier(1)
@noinline f_peel(@nospecialize(x::Any)) = Base.inferencebarrier(2)
g_call_peel(x) = f_peel(x)
let src = code_typed1(g_call_peel, Tuple{Any})
@test count(isinvoke(:f_peel), src.code) == 2
end
2 changes: 1 addition & 1 deletion test/worlds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ f_gen265(x::Type{Int}) = 3
# intermediate worlds by later additions to the method table that
# would have capped those specializations if they were still valid
f26506(@nospecialize(x)) = 1
g26506(x) = f26506(x[1])
g26506(x) = Base.inferencebarrier(f26506)(x[1])
z = Any["ABC"]
f26506(x::Int) = 2
g26506(z) # Places an entry for f26506(::String) in mt.name.cache
Expand Down

0 comments on commit 96d6d86

Please sign in to comment.