Skip to content

Commit

Permalink
feat: add comments and remove quadratic behaviour in ExpandResetReuse
Browse files Browse the repository at this point in the history
  • Loading branch information
anfelor committed Apr 1, 2024
1 parent 57dd7d5 commit 3a6d24f
Showing 1 changed file with 64 additions and 90 deletions.
154 changes: 64 additions & 90 deletions src/Lean/Compiler/IR/ExpandResetReuse.lean
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
Authors: Leonardo de Moura, Anton Lorenzen
-/
prelude
import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.NormIds
import Lean.Compiler.IR.FreeVars

namespace Lean.IR.ExpandResetReuse
/-- Mapping from variable to projections -/

/-- Mapping from variable to projections.
We use this in reuse specialization to avoid setting fields that are already set.
-/
abbrev ProjMap := HashMap VarId Expr
namespace CollectProjMap
abbrev Collector := ProjMap → ProjMap
Expand Down Expand Up @@ -39,7 +42,15 @@ structure Context where

abbrev Mask := Array (Option VarId)

/-- Auxiliary function for eraseProjIncFor -/
/-- Auxiliary function for eraseProjIncFor.
Traverse bs left to right to find pairs of
```
let z := proj[i] y
inc z n c
```
If `n=1` remove the `inc` instruction and if `n>1` replace `inc z n c` with `inc z (n-1) c`.
Additionally, we track the variables `z` that have been found in the mask.
-/
partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (keep : Array FnBody) : Array FnBody × Mask :=
let done (_ : Unit) := (bs ++ keep.reverse, mask)
let keepInstr (b : FnBody) := eraseProjIncForAux y bs.pop mask (keep.push b)
Expand Down Expand Up @@ -71,7 +82,7 @@ partial def eraseProjIncForAux (y : VarId) (bs : Array FnBody) (mask : Mask) (ke
| _ => done ()
| _ => done ()

/-- Try to erase `inc` instructions on projections of `y` occurring in the tail of `bs`.
/-- Try to erase one `inc` instruction on projections of `y` occurring in the tail of `bs`.
Return the updated `bs` and a bit mask specifying which `inc`s have been removed. -/
def eraseProjIncFor (n : Nat) (y : VarId) (bs : Array FnBody) : Array FnBody × Mask :=
eraseProjIncForAux y bs (mkArray n none) #[]
Expand All @@ -82,6 +93,8 @@ abbrev M := ReaderT Context (StateM Nat)
def mkFreshJoinPoint : M JoinPointId :=
modifyGet fun n => ({ idx := n }, n + 1)

/-- If the reused cell is unique, we can reuse its memory.
Then we have to manually release all fields that are not live. -/
def releaseUnreadFields (y : VarId) (mask : Mask) : M (FnBody → FnBody) :=
mask.size.foldM (init := id) fun i b =>
match mask.get! i with
Expand All @@ -90,9 +103,6 @@ def releaseUnreadFields (y : VarId) (mask : Mask) : M (FnBody → FnBody) :=
let fld ← mkFresh
pure (FnBody.vdecl fld IRType.object (Expr.proj i y) ∘ (FnBody.dec fld 1 true false ∘ b))

def setFields (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold (init := b) fun i b => FnBody.set y i (zs.get! i) b

/-- Given `set x[i] := y`, return true iff `y := proj[i] x` -/
def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
match y with
Expand All @@ -102,6 +112,12 @@ def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
| _ => false
| _ => false

/-- Set fields of `y` to `zs`. We avoid assignments that are already set. -/
def setFields (ctx : Context) (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold (init := b) fun i b =>
if isSelfSet ctx y i (zs.get! i) then b
else FnBody.set y i (zs.get! i) b

/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
def isSelfUSet (ctx : Context) (x : VarId) (i : Nat) (y : VarId) : Bool :=
match ctx.projMap.find? y with
Expand All @@ -114,99 +130,52 @@ def isSelfSSet (ctx : Context) (x : VarId) (n : Nat) (i : Nat) (y : VarId) : Boo
| some (Expr.sproj m j w) => n == m && j == i && w == x
| _ => false

/-- Remove unnecessary `set/uset/sset` operations -/
partial def removeSelfSet (ctx : Context) : FnBody → FnBody
| FnBody.set x i y b =>
if isSelfSet ctx x i y then removeSelfSet ctx b
else FnBody.set x i y (removeSelfSet ctx b)
| FnBody.uset x i y b =>
if isSelfUSet ctx x i y then removeSelfSet ctx b
else FnBody.uset x i y (removeSelfSet ctx b)
| FnBody.sset x n i y t b =>
if isSelfSSet ctx x n i y then removeSelfSet ctx b
else FnBody.sset x n i y t (removeSelfSet ctx b)
| FnBody.setTag x c b => FnBody.setTag x c (removeSelfSet ctx b)
| e => e

/- New version of `mkFastPath` that uses drop-specialisation and reuse-specialisation
instead of expanding the resets to avoid code blowup.
-/

/-- The empty reuse token returned for non-unique cells. -/
def null := Expr.lit (LitVal.num 0)

/- Create a new join point, where the declaration `v` obtains a function that will generate
a jump to the join point with the variable as an argument. We optimize the case, where
the binding is just a return and float it into the declaration. -/
/-- Create a new join point, where the declaration `v` obtains a function that will generate
a jump to the join point with the variable as an argument. We optimize the case, where
the binding is just a return and float it into the declaration. -/
def mkJoin (x : VarId) (t : IRType) (b : FnBody) (v : (VarId → FnBody) → FnBody) : M FnBody :=
match b with
| FnBody.ret _ =>
return v fun z => FnBody.ret (Arg.var z)
| _ => do
let j ← mkFreshJoinPoint
let z ← mkFresh
-- We use the given VarId for the joinpoint, which avoids the need to rename the tokens.
-- This is especially important since otherwise the ProjectionMap would find projections
-- out of the old variables and thus break reuse specialization.
return FnBody.jdecl j #[mkParam z false t]
(b.replaceVar x z)
(v (fun z => mkJmp j #[Arg.var z]))

/- Reuse specialisation -/
def tryReuse (reused token : VarId) (c : CtorInfo) (u : Bool) (t : IRType) (xs : Array Arg) (b : FnBody) : M FnBody := do
/-- Reuse specialization. -/
def specializeReuse (reused token : VarId) (c : CtorInfo) (u : Bool) (t : IRType) (xs : Array Arg) (b : FnBody) : M FnBody := do
let ctx ← read
let null? ← mkFresh
let z ← mkFresh
let newAlloc ← mkFresh
mkJoin reused t b fun jmp =>
(FnBody.vdecl null? IRType.uint8 (Expr.isNull token)
(mkIf null?
(FnBody.vdecl z t (Expr.ctor c xs)
(jmp z))
(removeSelfSet ctx
((if u then FnBody.setTag token c.cidx else id)
(setFields token xs
(jmp token))))))

/- Apply reuse specialisation for a reuse instruction -/
partial def reuseToTryReuse (x y : VarId) : FnBody → M FnBody
| FnBody.dec z n c p b =>
if x == z then return FnBody.del y b
else do
let b ← reuseToTryReuse x y b
return FnBody.dec z n c p b
| FnBody.vdecl z t v b =>
match v with
| Expr.reuse w c u zs =>
if x == w then
tryReuse z y c u t zs b
else do
let b ← reuseToTryReuse x y b
return FnBody.vdecl z t v b
| _ => do
let b ← reuseToTryReuse x y b
return FnBody.vdecl z t v b
| FnBody.case tid z zType alts => do
let alts ← alts.mapM fun alt => alt.mmodifyBody (reuseToTryReuse x y)
return FnBody.case tid z zType alts
| FnBody.jdecl j xs v b => do
let v ← reuseToTryReuse x y v
let b ← reuseToTryReuse x y b
return FnBody.jdecl j xs v b
| e =>
if e.isTerminal then return e
else do
let (instr, b) := e.split
let b ← reuseToTryReuse x y b
return instr.setBody b
(FnBody.vdecl newAlloc t (Expr.ctor c xs)
(jmp newAlloc))
((if u then FnBody.setTag token c.cidx else id)
(setFields ctx token xs
(jmp token)))))

/-- Increment all live children and decrement y. -/
def adjustReferenceCountsSlowPath (y : VarId) (mask : Mask) (b : FnBody) :=
let b := FnBody.dec y 1 true false b
mask.foldl (init := b) fun b m => match m with
| some z => FnBody.inc z 1 true false b
| none => b

/- Drop specialisation -/
def tryReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBody := do
/- Drop specialization -/
def specializeReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBody := do
let shared? ← mkFresh
let z2 ← mkFresh
let fastPath ← releaseUnreadFields oldAlloc mask
let b ← reuseToTryReuse token token b
mkJoin token IRType.object b fun jmp =>
(FnBody.vdecl shared? IRType.uint8 (Expr.isShared oldAlloc)
(mkIf shared?
Expand All @@ -215,33 +184,38 @@ def tryReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBody := d
(jmp z2)))
(fastPath (jmp oldAlloc))))

mutual
partial def specialize (bs : Array FnBody) (x : VarId) (n : Nat) (y : VarId) (b : FnBody) : M FnBody := do
let (bs, mask) := eraseProjIncFor n y bs
let b ← tryReset x y mask b
searchAndExpand b bs

partial def searchAndExpand : FnBody → Array FnBody → M FnBody
| FnBody.vdecl x _ (Expr.reset n y) b, bs =>
specialize bs x n y b
| FnBody.jdecl j xs v b, bs => do
let v ← searchAndExpand v #[]
let b ← searchAndExpand b #[]
partial def searchAndSpecialize : FnBody → Array FnBody → Array VarId → M FnBody
| FnBody.vdecl x _ (Expr.reset n y) b, bs, tokens => do
let (bs, mask) := eraseProjIncFor n y bs
let b ← searchAndSpecialize b #[] (tokens.push x)
let b ← specializeReset x y mask b
return reshape bs b
| FnBody.vdecl z t (Expr.reuse w c u zs) b, bs, tokens => do
let b ← searchAndSpecialize b #[] tokens
let b ← specializeReuse z w c u t zs b
return reshape bs b
| FnBody.dec z n c p b, bs, tokens =>
if Array.contains tokens z then return FnBody.del z b
else do
let b ← searchAndSpecialize b #[] tokens
return reshape bs (FnBody.dec z n c p b)
| FnBody.jdecl j xs v b, bs, tokens => do
let v ← searchAndSpecialize v #[] tokens
let b ← searchAndSpecialize b #[] tokens
return reshape bs (FnBody.jdecl j xs v b)
| FnBody.case tid x xType alts, bs => do
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => searchAndExpand b #[]
| FnBody.case tid x xType alts, bs, tokens => do
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => searchAndSpecialize b #[] tokens
return reshape bs (FnBody.case tid x xType alts)
| b, bs =>
| b, bs, tokens =>
if b.isTerminal then return reshape bs b
else searchAndExpand b.body (push bs b)
end
else searchAndSpecialize b.body (push bs b) tokens

def main (d : Decl) : Decl :=
match d with
| .fdecl (body := b) .. =>
let m := mkProjMap d
let nextIdx := d.maxIndex + 1
let bNew := (searchAndExpand b #[] { projMap := m }).run' nextIdx
let bNew := (searchAndSpecialize b #[] #[] { projMap := m }).run' nextIdx
d.updateBody! bNew
| d => d

Expand Down

0 comments on commit 3a6d24f

Please sign in to comment.