diff --git a/src/Lean/Compiler/IR/ExpandResetReuse.lean b/src/Lean/Compiler/IR/ExpandResetReuse.lean index 1aa90d9fbb1f..a927b3adf69d 100644 --- a/src/Lean/Compiler/IR/ExpandResetReuse.lean +++ b/src/Lean/Compiler/IR/ExpandResetReuse.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2019 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura, Anton Lorenzen +Authors: Leonardo de Moura -/ prelude import Lean.Compiler.IR.CompilerM @@ -9,10 +9,7 @@ import Lean.Compiler.IR.NormIds import Lean.Compiler.IR.FreeVars namespace Lean.IR.ExpandResetReuse - -/-- Mapping from variable to projections. - We use this in reuse specialization to avoid setting fields that are already set. --/ +/-- Mapping from variable to projections -/ abbrev ProjMap := HashMap VarId Expr namespace CollectProjMap abbrev Collector := ProjMap → ProjMap @@ -42,15 +39,7 @@ structure Context where abbrev Mask := Array (Option VarId) -/-- Auxiliary function for eraseProjIncFor. - Traverse bs left to right to find pairs of - ``` - let z := proj[i] y - inc z n c - ``` - If `n=1` remove the `inc` instruction and if `n>1` replace `inc z n c` with `inc z (n-1) c`. - Additionally, we track the variables `z` that have been found in the mask. --/ +/-- Auxiliary function for eraseProjIncFor -/ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask := let done (_ : Unit) := (bs ++ keep.reverse, mask) let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b) @@ -82,7 +71,7 @@ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (ke | _ => done () | _ => done () -/-- Try to erase one `inc` instruction on projections of `y` occurring in the tail of `bs`. +/-- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`. Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/ def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask := eraseProjIncForAux y bs (mkArray n none) #[] @@ -93,8 +82,6 @@ abbrev M := ReaderT Context (StateM Nat) def mkFreshJoinPoint : M JoinPointId := modifyGet fun n => ({ idx := n }, n + 1) -/-- If the reused cell is unique, we can reuse its memory. - Then we have to manually release all fields that are not live. -/ def releaseUnreadFields (y : VarId) (mask : Mask) : M (FnBody → FnBody) := mask.size.foldM (init := id) fun i b => match mask.get! i with @@ -103,6 +90,9 @@ def releaseUnreadFields (y : VarId) (mask : Mask) : M (FnBody → FnBody) := let fld ← mkFresh pure (FnBody.vdecl fld IRType.object (Expr.proj i y) ∘ (FnBody.dec fld 1 true false ∘ b)) +def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody := + zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b + /-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/ def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool := match y with @@ -112,12 +102,6 @@ def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool := | _ => false | _ => false -/-- Set fields of `y` to `zs`. We avoid assignments that are already set. -/ -def setFields (ctx : Context) (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody := - zs.size.fold (init := b) fun i b => - if isSelfSet ctx y i (zs.get! i) then b - else FnBody.set y i (zs.get! i) b - /-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/ def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool := match ctx.projMap.find? y with @@ -130,12 +114,29 @@ def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Boo | some (Expr.sproj m j w) => n == m && j == i && w == x | _ => false -/-- The empty reuse token returned for non-unique cells. -/ +/-- Remove unnecessary `set/uset/sset` operations -/ +partial def removeSelfSet (ctx : Context) : FnBody → FnBody + | FnBody.set x i y b => + if isSelfSet ctx x i y then removeSelfSet ctx b + else FnBody.set x i y (removeSelfSet ctx b) + | FnBody.uset x i y b => + if isSelfUSet ctx x i y then removeSelfSet ctx b + else FnBody.uset x i y (removeSelfSet ctx b) + | FnBody.sset x n i y t b => + if isSelfSSet ctx x n i y then removeSelfSet ctx b + else FnBody.sset x n i y t (removeSelfSet ctx b) + | FnBody.setTag x c b => FnBody.setTag x c (removeSelfSet ctx b) + | e => e + +/- New version of `mkFastPath` that uses drop-specialisation and reuse-specialisation + instead of expanding the resets to avoid code blowup. +-/ + def null := Expr.lit (LitVal.num 0) -/-- Create a new join point, where the declaration `v` obtains a function that will generate - a jump to the join point with the variable as an argument. We optimize the case, where - the binding is just a return and float it into the declaration. -/ +/- Create a new join point, where the declaration `v` obtains a function that will generate + a jump to the join point with the variable as an argument. We optimize the case, where + the binding is just a return and float it into the declaration. -/ def mkJoin (x : VarId) (t : IRType) (b : FnBody) (v : (VarId → FnBody) → FnBody) : M FnBody := match b with | FnBody.ret _ => @@ -143,39 +144,69 @@ def mkJoin (x : VarId) (t : IRType) (b : FnBody) (v : (VarId → FnBody) → FnB | _ => do let j ← mkFreshJoinPoint let z ← mkFresh - -- We use the given VarId for the joinpoint, which avoids the need to rename the tokens. - -- This is especially important since otherwise the ProjectionMap would find projections - -- out of the old variables and thus break reuse specialization. return FnBody.jdecl j #[mkParam z false t] (b.replaceVar x z) (v (fun z => mkJmp j #[Arg.var z])) -/-- Reuse specialization. -/ -def specializeReuse (reused token : VarId) (c : CtorInfo) (u : Bool) (t : IRType) (xs : Array Arg) (b : FnBody) : M FnBody := do +/- Reuse specialisation -/ +def tryReuse (reused token : VarId) (c : CtorInfo) (u : Bool) (t : IRType) (xs : Array Arg) (b : FnBody) : M FnBody := do let ctx ← read let null? ← mkFresh - let newAlloc ← mkFresh + let z ← mkFresh mkJoin reused t b fun jmp => (FnBody.vdecl null? IRType.uint8 (Expr.isNull token) (mkIf null? - (FnBody.vdecl newAlloc t (Expr.ctor c xs) - (jmp newAlloc)) - ((if u then FnBody.setTag token c.cidx else id) - (setFields ctx token xs - (jmp token))))) + (FnBody.vdecl z t (Expr.ctor c xs) + (jmp z)) + (removeSelfSet ctx + ((if u then FnBody.setTag token c.cidx else id) + (setFields token xs + (jmp token)))))) + +/- Apply reuse specialisation for a reuse instruction -/ +partial def reuseToTryReuse (x y : VarId) : FnBody → M FnBody + | FnBody.dec z n c p b => + if x == z then return FnBody.del y b + else do + let b ← reuseToTryReuse x y b + return FnBody.dec z n c p b + | FnBody.vdecl z t v b => + match v with + | Expr.reuse w c u zs => + if x == w then + tryReuse z y c u t zs b + else do + let b ← reuseToTryReuse x y b + return FnBody.vdecl z t v b + | _ => do + let b ← reuseToTryReuse x y b + return FnBody.vdecl z t v b + | FnBody.case tid z zType alts => do + let alts ← alts.mapM fun alt => alt.mmodifyBody (reuseToTryReuse x y) + return FnBody.case tid z zType alts + | FnBody.jdecl j xs v b => do + let v ← reuseToTryReuse x y v + let b ← reuseToTryReuse x y b + return FnBody.jdecl j xs v b + | e => + if e.isTerminal then return e + else do + let (instr, b) := e.split + let b ← reuseToTryReuse x y b + return instr.setBody b -/-- Increment all live children and decrement y. -/ def adjustReferenceCountsSlowPath (y : VarId) (mask : Mask) (b : FnBody) := let b := FnBody.dec y 1 true false b mask.foldl (init := b) fun b m => match m with | some z => FnBody.inc z 1 true false b | none => b -/- Drop specialization -/ -def specializeReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBody := do +/- Drop specialisation -/ +def tryReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBody := do let shared? ← mkFresh let z2 ← mkFresh let fastPath ← releaseUnreadFields oldAlloc mask + let b ← reuseToTryReuse token token b mkJoin token IRType.object b fun jmp => (FnBody.vdecl shared? IRType.uint8 (Expr.isShared oldAlloc) (mkIf shared? @@ -184,38 +215,33 @@ def specializeReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBo (jmp z2))) (fastPath (jmp oldAlloc)))) -partial def searchAndSpecialize : FnBody → Array FnBody → Array VarId → M FnBody - | FnBody.vdecl x _ (Expr.reset n y) b, bs, tokens => do - let (bs, mask) := eraseProjIncFor n y bs - let b ← searchAndSpecialize b #[] (tokens.push x) - let b ← specializeReset x y mask b - return reshape bs b - | FnBody.vdecl z t (Expr.reuse w c u zs) b, bs, tokens => do - let b ← searchAndSpecialize b #[] tokens - let b ← specializeReuse z w c u t zs b - return reshape bs b - | FnBody.dec z n c p b, bs, tokens => - if Array.contains tokens z then return FnBody.del z b - else do - let b ← searchAndSpecialize b #[] tokens - return reshape bs (FnBody.dec z n c p b) - | FnBody.jdecl j xs v b, bs, tokens => do - let v ← searchAndSpecialize v #[] tokens - let b ← searchAndSpecialize b #[] tokens +mutual +partial def specialize (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do + let (bs, mask) := eraseProjIncFor n y bs + let b ← tryReset x y mask b + searchAndExpand b bs + +partial def searchAndExpand : FnBody → Array FnBody → M FnBody + | FnBody.vdecl x _ (Expr.reset n y) b, bs => + specialize bs x n y b + | FnBody.jdecl j xs v b, bs => do + let v ← searchAndExpand v #[] + let b ← searchAndExpand b #[] return reshape bs (FnBody.jdecl j xs v b) - | FnBody.case tid x xType alts, bs, tokens => do - let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => searchAndSpecialize b #[] tokens + | FnBody.case tid x xType alts, bs => do + let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => searchAndExpand b #[] return reshape bs (FnBody.case tid x xType alts) - | b, bs, tokens => + | b, bs => if b.isTerminal then return reshape bs b - else searchAndSpecialize b.body (push bs b) tokens + else searchAndExpand b.body (push bs b) +end def main (d : Decl) : Decl := match d with | .fdecl (body := b) .. => let m := mkProjMap d let nextIdx := d.maxIndex + 1 - let bNew := (searchAndSpecialize b #[] #[] { projMap := m }).run' nextIdx + let bNew := (searchAndExpand b #[] { projMap := m }).run' nextIdx d.updateBody! bNew | d => d