Skip to content

Commit

Permalink
optimizer: run SROA multiple times to handle more nested loads
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Dec 22, 2021
1 parent bceef47 commit 738df81
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 43 deletions.
121 changes: 88 additions & 33 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])

compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)

function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
field = stmt.args[3]
try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) =
try_compute_field(ir, stmt.args[3])

function try_compute_field(ir::Union{IncrementalCompact,IRCode}, @nospecialize(field))
# fields are usually literals, handle them manually
if isa(field, QuoteNode)
field = field.value
Expand All @@ -44,8 +46,7 @@ function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr
return nothing
end
end
isa(field, Union{Int, Symbol}) || return nothing
return field
return isa(field, Union{Int, Symbol}) ? field : nothing
end

function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
Expand Down Expand Up @@ -167,7 +168,7 @@ function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospec
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
callback = (@nospecialize(x), @nospecialize(idx)) -> false)
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand Down Expand Up @@ -335,10 +336,29 @@ struct LiftedValue
end
const LiftedLeaves = IdDict{Any, Union{Nothing,LiftedValue}}

# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

mutable struct NestedLoads
maybe::Union{Nothing,SPCSet}
NestedLoads() = new(nothing)
end
function record_nested_load!(nested_loads::NestedLoads, pc::Int)
maybe = nested_loads.maybe
maybe === nothing && (maybe = nested_loads.maybe = SPCSet())
push!(maybe::SPCSet, pc)
end
function is_nested_load(nested_loads::NestedLoads, pc::Int)
maybe = nested_loads.maybe
maybe === nothing && return false
return pc in maybe::SPCSet
end

# try to compute lifted values that can replace `getfield(x, field)` call
# where `x` is an immutable struct that are defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact,
@nospecialize(result_t), field::Int, leaves::Vector{Any})
function lift_leaves!(compact::IncrementalCompact, leaves::Vector{Any},
@nospecialize(result_t), field::Int, nested_loads::NestedLoads)
# For every leaf, the lifted value
lifted_leaves = LiftedLeaves()
maybe_undef = false
Expand Down Expand Up @@ -388,11 +408,19 @@ function lift_leaves(compact::IncrementalCompact,
ocleaf = simple_walk(compact, ocleaf)
end
ocdef, _ = walk_to_def(compact, ocleaf)
if isexpr(ocdef, :new_opaque_closure) && isa(field, Int) && 1 field length(ocdef.args)-5
if isexpr(ocdef, :new_opaque_closure) && 1 field length(ocdef.args)-5
lift_arg!(compact, leaf, cache_key, ocdef, 5+field, lifted_leaves)
continue
end
return nothing
elseif is_known_call(def, getfield, compact)
if isa(leaf, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(def.args[2], compact)))
if ismutabletype(struct_typ)
record_nested_load!(nested_loads, leaf.id)
end
end
return nothing
else
typ = argextype(leaf, compact)
if !isa(typ, Const)
Expand Down Expand Up @@ -586,7 +614,7 @@ function perform_lifting!(compact::IncrementalCompact,
end
val = lifted_val.x
if isa(val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
callback = (@nospecialize(x), @nospecialize(idx)) -> true
val = simple_walk(compact, val, callback)
end
push!(new_node.values, val)
Expand Down Expand Up @@ -617,10 +645,6 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}

"""
sroa_pass!(ir::IRCode) -> newir::IRCode
Expand All @@ -639,10 +663,11 @@ its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of succeeding dead code elimination.
"""
function sroa_pass!(ir::IRCode)
function sroa_pass!(ir::IRCode, optional_opts::Bool = true)
compact = IncrementalCompact(ir)
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
nested_loads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Immutable`
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
Expand Down Expand Up @@ -670,7 +695,7 @@ function sroa_pass!(ir::IRCode)
preserved_arg = stmt.args[pidx]
isa(preserved_arg, SSAValue) || continue
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
callback = function (@nospecialize(x), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand Down Expand Up @@ -698,7 +723,9 @@ function sroa_pass!(ir::IRCode)
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
mid, defuse = get!(defuses, defidx) do
SPCSet(), SSADefUse()
end
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
end
Expand All @@ -708,16 +735,17 @@ function sroa_pass!(ir::IRCode)
compact[idx] = form_new_preserves(stmt, preserved, new_preserves)
end
continue
# TODO: This isn't the best place to put these
elseif is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
continue
elseif is_known_call(stmt, (===), compact)
lift_comparison!(compact, idx, stmt, lifting_cache)
continue
# elseif is_known_call(stmt, isa, compact)
# TODO do a similar optimization as `lift_comparison!` for `===`
else
if optional_opts
# TODO: This isn't the best place to put these
if is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
elseif is_known_call(stmt, (===), compact)
lift_comparison!(compact, idx, stmt, lifting_cache)
# elseif is_known_call(stmt, isa, compact)
# TODO do a similar optimization as `lift_comparison!` for `===`
end
end
continue
end

Expand All @@ -743,7 +771,7 @@ function sroa_pass!(ir::IRCode)
if ismutabletype(struct_typ)
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
callback = function (@nospecialize(x), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand All @@ -753,7 +781,9 @@ function sroa_pass!(ir::IRCode)
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
mid, defuse = get!(defuses, def.id) do
SPCSet(), SSADefUse()
end
if is_setfield
push!(defuse.defs, idx)
else
Expand All @@ -775,7 +805,7 @@ function sroa_pass!(ir::IRCode)
isempty(leaves) && continue

result_t = argextype(SSAValue(idx), compact)
lifted_result = lift_leaves(compact, result_t, field, leaves)
lifted_result = lift_leaves!(compact, leaves, result_t, field, nested_loads)
lifted_result === nothing && continue
lifted_leaves, any_undef = lifted_result

Expand Down Expand Up @@ -811,21 +841,25 @@ function sroa_pass!(ir::IRCode)
used_ssas = copy(compact.used_ssas)
simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1)
ir = complete(compact)
sroa_mutables!(ir, defuses, used_ssas)
return ir
return sroa_mutables!(ir, defuses, used_ssas, nested_loads)
else
simple_dce!(compact)
return complete(compact)
end
end

function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
function sroa_mutables!(ir::IRCode,
defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int},
nested_loads::NestedLoads)
# Compute domtree, needed below, now that we have finished compacting the IR.
# This needs to be after we iterate through the IR with `IncrementalCompact`
# because removing dead blocks can invalidate the domtree.
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)

for (idx, (intermediaries, defuse)) in defuses
nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable`
local any_eliminated = false
# NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield`
for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true)
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
Expand All @@ -840,7 +874,19 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
nleaves == nuses_total || continue
# Find the type for this allocation
defexpr = ir[SSAValue(idx)]
isexpr(defexpr, :new) || continue
isa(defexpr, Expr) || continue
if !isexpr(defexpr, :new)
if is_known_call(defexpr, getfield, ir)
val = defexpr.args[2]
if isa(val, SSAValue)
struct_typ = unwrap_unionall(widenconst(argextype(val, ir)))
if ismutabletype(struct_typ)
record_nested_load!(nested_mloads, idx)
end
end
end
continue
end
newidx = idx
typ = ir.stmts[newidx][:type]
if isa(typ, UnionAll)
Expand Down Expand Up @@ -910,6 +956,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse
# Now go through all uses and rewrite them
for stmt in du.uses
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
if !any_eliminated
any_eliminated |= (is_nested_load(nested_loads, stmt) ||
is_nested_load(nested_mloads, stmt))
end
end
if !isbitstype(ftyp)
if preserve_uses !== nothing
Expand Down Expand Up @@ -946,6 +996,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse

@label skip
end
if any_eliminated
return sroa_pass!(compact!(ir), false)
else
return ir
end
end

function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any})
Expand Down
51 changes: 41 additions & 10 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[
struct ImmutableXYZ; x; y; z; end
mutable struct MutableXYZ; x; y; z; end

struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end

# should optimize away very basic cases
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
Expand Down Expand Up @@ -198,9 +201,8 @@ let src = code_typed1((Bool,Bool,Any,Any)) do c1, c2, x, y
@test any(isnew, src.code)
end

# should include a simple alias analysis
struct ImmutableOuter{T}; x::T; y::T; z::T; end
mutable struct MutableOuter{T}; x::T; y::T; z::T; end
# alias analysis
# --------------
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = ImmutableOuter(xyz, xyz, xyz)
Expand All @@ -227,9 +229,11 @@ let src = code_typed1((Any,Any,Any)) do x, y, z
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end

# FIXME our analysis isn't yet so powerful at this moment: may be unable to handle nested objects well
# OK: mutable(immutable(...)) case
# FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until
# any nested mutable `getfield` calls become no longer eliminatable:
# it's probably not the most efficient option and we may want to introduce some sort of
# alias analysis and eliminates all the loads at once.
# mutable(immutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
t = (xyz,)
Expand Down Expand Up @@ -260,21 +264,48 @@ let # this is a simple end to end test case, which demonstrates allocation elimi
# compiled code for `simple_sroa`, otherwise everything can be folded even without SROA
@test @allocated(simple_sroa(s)) == 0
end
# FIXME: immutable(mutable(...)) case
# immutable(mutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = ImmutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end
@test_broken !any(isnew, src.code)
@test !any(isnew, src.code)
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
# FIXME: mutable(mutable(...)) case
# mutable(mutable(...)) case
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
outer = MutableOuter(xyz, xyz, xyz)
outer.x.x, outer.y.y, outer.z.z
end
@test_broken !any(isnew, src.code)
@test !any(isnew, src.code)
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
let src = code_typed1((Any,Any,Any)) do x, y, z
xyz = MutableXYZ(x, y, z)
inner = MutableOuter(xyz, xyz, xyz)
outer = MutableOuter(inner, inner, inner)
outer.x.x.x, outer.y.y.y, outer.z.z.z
end
@test !any(isnew, src.code)
@test any(src.code) do @nospecialize x
iscall((src, tuple), x) &&
x.args[2:end] == Any[#=x=# Core.Argument(2), #=y=# Core.Argument(3), #=y=# Core.Argument(4)]
end
end
let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it should be able
# to fully eliminate this insanely nested example
src = code_typed1((Int,)) do x
(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref(Ref((x))))))))))))[][][][][][][][][][]
end
@test !any(isnew, src.code)
end

# should work nicely with inlining to optimize away a complicated case
Expand Down

0 comments on commit 738df81

Please sign in to comment.