From 24cfe6e2cc779c2b2d05dae3d2dce1da1424970a Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:36:54 +0900 Subject: [PATCH] inference: refine branched `Conditional` types (#55216) Separated from JuliaLang/julia#40880. This subtle adjustment allows for more accurate type inference in the following kind of cases: ```julia function condition_object_update2(x) cond = x isa Int if cond # `cond` is known to be `Const(true)` within this branch return !cond ? nothing : x # ::Int else return cond ? nothing : 1 # ::Int end end @test Base.infer_return_type(condition_object_update2, (Any,)) == Int ``` Also cleans up typelattice.jl a bit. --- base/compiler/abstractinterpretation.jl | 60 +++++++++++------ base/compiler/typelattice.jl | 23 ------- test/compiler/inference.jl | 87 ++++++++++++------------- 3 files changed, 82 insertions(+), 88 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index f34b1e716a271..d441d994034a5 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -3067,7 +3067,7 @@ end @inline function abstract_eval_basic_statement(interp::AbstractInterpreter, @nospecialize(stmt), pc_vartable::VarTable, frame::InferenceState) if isa(stmt, NewvarNode) - changes = StateUpdate(stmt.slot, VarState(Bottom, true), pc_vartable, false) + changes = StateUpdate(stmt.slot, VarState(Bottom, true), false) return BasicStmtChange(changes, nothing, Union{}) elseif !isa(stmt, Expr) (; rt, exct) = abstract_eval_statement(interp, stmt, pc_vartable, frame) @@ -3082,7 +3082,7 @@ end end lhs = stmt.args[1] if isa(lhs, SlotNumber) - changes = StateUpdate(lhs, VarState(rt, false), pc_vartable, false) + changes = StateUpdate(lhs, VarState(rt, false), false) elseif isa(lhs, GlobalRef) handle_global_assignment!(interp, frame, lhs, rt) elseif !isa(lhs, SSAValue) @@ -3092,7 +3092,7 @@ end elseif hd === :method fname = stmt.args[1] if isa(fname, SlotNumber) - changes = StateUpdate(fname, VarState(Any, false), pc_vartable, false) + changes = StateUpdate(fname, VarState(Any, false), false) end return BasicStmtChange(changes, nothing, Union{}) elseif (hd === :code_coverage_effect || ( @@ -3242,7 +3242,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) @goto branch elseif isa(stmt, GotoIfNot) condx = stmt.cond - condxslot = ssa_def_slot(condx, frame) + condslot = ssa_def_slot(condx, frame) condt = abstract_eval_value(interp, condx, currstate, frame) if condt === Bottom ssavaluetypes[currpc] = Bottom @@ -3250,10 +3250,10 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) @goto find_next_bb end orig_condt = condt - if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condxslot, SlotNumber) + if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condslot, SlotNumber) # if this non-`Conditional` object is a slot, we form and propagate # the conditional constraint on it - condt = Conditional(condxslot, Const(true), Const(false)) + condt = Conditional(condslot, Const(true), Const(false)) end condval = maybe_extract_const_bool(condt) nothrow = (condval !== nothing) || โŠ‘(๐•ƒแตข, orig_condt, Bool) @@ -3299,21 +3299,31 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) # We continue with the true branch, but process the false # branch here. if isa(condt, Conditional) - else_change = conditional_change(๐•ƒแตข, currstate, condt.elsetype, condt.slot) + else_change = conditional_change(๐•ƒแตข, currstate, condt, #=then_or_else=#false) if else_change !== nothing - false_vartable = stoverwrite1!(copy(currstate), else_change) + elsestate = copy(currstate) + stoverwrite1!(elsestate, else_change) + elseif condslot isa SlotNumber + elsestate = copy(currstate) else - false_vartable = currstate + elsestate = currstate end - changed = update_bbstate!(๐•ƒแตข, frame, falsebb, false_vartable) - then_change = conditional_change(๐•ƒแตข, currstate, condt.thentype, condt.slot) + if condslot isa SlotNumber # refine the type of this conditional object itself for this else branch + stoverwrite1!(elsestate, condition_object_change(currstate, condt, condslot, #=then_or_else=#false)) + end + else_changed = update_bbstate!(๐•ƒแตข, frame, falsebb, elsestate) + then_change = conditional_change(๐•ƒแตข, currstate, condt, #=then_or_else=#true) + thenstate = currstate if then_change !== nothing - stoverwrite1!(currstate, then_change) + stoverwrite1!(thenstate, then_change) + end + if condslot isa SlotNumber # refine the type of this conditional object itself for this then branch + stoverwrite1!(thenstate, condition_object_change(currstate, condt, condslot, #=then_or_else=#true)) end else - changed = update_bbstate!(๐•ƒแตข, frame, falsebb, currstate) + else_changed = update_bbstate!(๐•ƒแตข, frame, falsebb, currstate) end - if changed + if else_changed handle_control_backedge!(interp, frame, currpc, stmt.dest) push!(W, falsebb) end @@ -3412,13 +3422,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) nothing end -function conditional_change(๐•ƒแตข::AbstractLattice, state::VarTable, @nospecialize(typ), slot::Int) - vtype = state[slot] +function conditional_change(๐•ƒแตข::AbstractLattice, currstate::VarTable, condt::Conditional, then_or_else::Bool) + vtype = currstate[condt.slot] oldtyp = vtype.typ - if iskindtype(typ) + newtyp = then_or_else ? condt.thentype : condt.elsetype + if iskindtype(newtyp) # this code path corresponds to the special handling for `isa(x, iskindtype)` check # implemented within `abstract_call_builtin` - elseif โŠ‘(๐•ƒแตข, ignorelimited(typ), ignorelimited(oldtyp)) + elseif โŠ‘(๐•ƒแตข, ignorelimited(newtyp), ignorelimited(oldtyp)) # approximate test for `typ โˆฉ oldtyp` being better than `oldtyp` # since we probably formed these types with `typesubstract`, # the comparison is likely simple @@ -3428,9 +3439,18 @@ function conditional_change(๐•ƒแตข::AbstractLattice, state::VarTable, @nospecia if oldtyp isa LimitedAccuracy # typ is better unlimited, but we may still need to compute the tmeet with the limit # "causes" since we ignored those in the comparison - typ = tmerge(๐•ƒแตข, typ, LimitedAccuracy(Bottom, oldtyp.causes)) + newtyp = tmerge(๐•ƒแตข, newtyp, LimitedAccuracy(Bottom, oldtyp.causes)) end - return StateUpdate(SlotNumber(slot), VarState(typ, vtype.undef), state, true) + return StateUpdate(SlotNumber(condt.slot), VarState(newtyp, vtype.undef), true) +end + +function condition_object_change(currstate::VarTable, condt::Conditional, + condslot::SlotNumber, then_or_else::Bool) + vtype = currstate[slot_id(condslot)] + newcondt = Conditional(condt.slot, + then_or_else ? condt.thentype : Union{}, + then_or_else ? Union{} : condt.elsetype) + return StateUpdate(condslot, VarState(newcondt, vtype.undef), false) end # make as much progress on `frame` as possible (by handling cycles) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 5987c30be2b91..52f4a6922328c 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -159,7 +159,6 @@ end struct StateUpdate var::SlotNumber vtype::VarState - state::VarTable conditional::Bool end @@ -724,28 +723,6 @@ function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional: return nothing end -function stupdate!(lattice::AbstractLattice, state::VarTable, changes::StateUpdate) - changed = false - changeid = slot_id(changes.var) - for i = 1:length(state) - if i == changeid - newtype = changes.vtype - else - newtype = changes.state[i] - end - invalidated = invalidate_slotwrapper(newtype, changeid, changes.conditional) - if invalidated !== nothing - newtype = invalidated - end - oldtype = state[i] - if schanged(lattice, newtype, oldtype) - state[i] = smerge(lattice, oldtype, newtype) - changed = true - end - end - return changed -end - function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable) changed = false for i = 1:length(state) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 8b0c6669ffb89..ed7b177a1e99d 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2151,78 +2151,75 @@ end @testset "branching on conditional object" begin # simple - @test Base.return_types((Union{Nothing,Int},)) do a + @test Base.infer_return_type((Union{Nothing,Int},)) do a b = a === nothing return b ? 0 : a # ::Int - end == Any[Int] + end == Int # can use multiple times (as far as the subject of condition hasn't changed) - @test Base.return_types((Union{Nothing,Int},)) do a + @test Base.infer_return_type((Union{Nothing,Int},)) do a b = a === nothing c = b ? 0 : a # c::Int d = !b ? a : 0 # d::Int return c, d # ::Tuple{Int,Int} - end == Any[Tuple{Int,Int}] + end == Tuple{Int,Int} # should invalidate old constraint when the subject of condition has changed - @test Base.return_types((Union{Nothing,Int},)) do a + @test Base.infer_return_type((Union{Nothing,Int},)) do a cond = a === nothing r1 = cond ? 0 : a # r1::Int a = 0 r2 = cond ? a : 1 # r2::Int, not r2::Union{Nothing,Int} return r1, r2 # ::Tuple{Int,Int} - end == Any[Tuple{Int,Int}] + end == Tuple{Int,Int} end # https://github.com/JuliaLang/julia/issues/42090#issuecomment-911824851 # `PartialStruct` shouldn't wrap `Conditional` -let M = Module() - @eval M begin - struct BePartialStruct - val::Int - cond - end - end - - rt = @eval M begin - Base.return_types((Union{Nothing,Int},)) do a - cond = a === nothing - obj = $(Expr(:new, M.BePartialStruct, 42, :cond)) - r1 = getfield(obj, :cond) ? 0 : a # r1::Union{Nothing,Int}, not r1::Int (because PartialStruct doesn't wrap Conditional) - a = $(gensym(:anyvar))::Any - r2 = getfield(obj, :cond) ? a : nothing # r2::Any, not r2::Const(nothing) (we don't need to worry about constraint invalidation here) - return r1, r2 # ::Tuple{Union{Nothing,Int},Any} - end |> only - end - @test rt == Tuple{Union{Nothing,Int},Any} +struct BePartialStruct + val::Int + cond +end +@test Tuple{Union{Nothing,Int},Any} == @eval Base.infer_return_type((Union{Nothing,Int},)) do a + cond = a === nothing + obj = $(Expr(:new, BePartialStruct, 42, :cond)) + r1 = getfield(obj, :cond) ? 0 : a # r1::Union{Nothing,Int}, not r1::Int (because PartialStruct doesn't wrap Conditional) + a = $(gensym(:anyvar))::Any + r2 = getfield(obj, :cond) ? a : nothing # r2::Any, not r2::Const(nothing) (we don't need to worry about constraint invalidation here) + return r1, r2 # ::Tuple{Union{Nothing,Int},Any} end # make sure we never form nested `Conditional` (https://github.com/JuliaLang/julia/issues/46207) -@test Base.return_types((Any,)) do a +@test Base.infer_return_type((Any,)) do a c = isa(a, Integer) 42 === c ? :a : "b" -end |> only === String -@test Base.return_types((Any,)) do a +end == String +@test Base.infer_return_type((Any,)) do a c = isa(a, Integer) c === 42 ? :a : "b" -end |> only === String +end == String -@testset "conditional constraint propagation from non-`Conditional` object" begin - @test Base.return_types((Bool,)) do b - if b - return !b ? nothing : 1 # ::Int - else - return 0 - end - end == Any[Int] - - @test Base.return_types((Any,)) do b - if b - return b # ::Bool - else - return nothing - end - end == Any[Union{Bool,Nothing}] +function condition_object_update1(cond) + if cond # `cond` is known to be `Const(true)` within this branch + return !cond ? nothing : 1 # ::Int + else + return cond ? nothing : 1 # ::Int + end +end +function condition_object_update2(x) + cond = x isa Int + if cond # `cond` is known to be `Const(true)` within this branch + return !cond ? nothing : x # ::Int + else + return cond ? nothing : 1 # ::Int + end +end +@testset "state update for condition object" begin + # refine the type of condition object into constant boolean values on branching + @test Base.infer_return_type(condition_object_update1, (Bool,)) == Int + @test Base.infer_return_type(condition_object_update1, (Any,)) == Int + # refine even when their original type is `Conditional` + @test Base.infer_return_type(condition_object_update2, (Any,)) == Int end @testset "`from_interprocedural!`: translate inter-procedural information" begin