Skip to content

Commit

Permalink
feat: propagate new token names to avoid setting already-set fields
Browse files Browse the repository at this point in the history
  • Loading branch information
anfelor committed Apr 1, 2024
1 parent 3a6d24f commit d2ef785
Showing 1 changed file with 24 additions and 27 deletions.
51 changes: 24 additions & 27 deletions src/Lean/Compiler/IR/ExpandResetReuse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def isSelfSet (ctx : Context) (x : VarId) (i : Nat) (y : Arg) : Bool :=
| _ => false

/-- Set fields of `y` to `zs`. We avoid assignments that are already set. -/
def setFields (ctx : Context) (y : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
def setFields (ctx : Context) (y oldAlloc : VarId) (zs : Array Arg) (b : FnBody) : FnBody :=
zs.size.fold (init := b) fun i b =>
if isSelfSet ctx y i (zs.get! i) then b
if isSelfSet ctx oldAlloc i (zs.get! i) then b
else FnBody.set y i (zs.get! i) b

/-- Given `uset x[i] := y`, return true iff `y := uproj[i] x` -/
Expand Down Expand Up @@ -143,15 +143,12 @@ def mkJoin (x : VarId) (t : IRType) (b : FnBody) (v : (VarId → FnBody) → FnB
| _ => do
let j ← mkFreshJoinPoint
let z ← mkFresh
-- We use the given VarId for the joinpoint, which avoids the need to rename the tokens.
-- This is especially important since otherwise the ProjectionMap would find projections
-- out of the old variables and thus break reuse specialization.
return FnBody.jdecl j #[mkParam z false t]
(b.replaceVar x z)
(v (fun z => mkJmp j #[Arg.var z]))

/-- Reuse specialization. -/
def specializeReuse (reused token : VarId) (c : CtorInfo) (u : Bool) (t : IRType) (xs : Array Arg) (b : FnBody) : M FnBody := do
def specializeReuse (reused token oldAlloc : VarId) (c : CtorInfo) (u : Bool) (t : IRType) (xs : Array Arg) (b : FnBody) : M FnBody := do
let ctx ← read
let null? ← mkFresh
let newAlloc ← mkFresh
Expand All @@ -161,7 +158,7 @@ def specializeReuse (reused token : VarId) (c : CtorInfo) (u : Bool) (t : IRType
(FnBody.vdecl newAlloc t (Expr.ctor c xs)
(jmp newAlloc))
((if u then FnBody.setTag token c.cidx else id)
(setFields ctx token xs
(setFields ctx token oldAlloc xs
(jmp token)))))

/-- Increment all live children and decrement y. -/
Expand All @@ -184,38 +181,38 @@ def specializeReset (token oldAlloc : VarId) (mask : Mask) (b : FnBody) : M FnBo
(jmp z2)))
(fastPath (jmp oldAlloc))))

partial def searchAndSpecialize : FnBody → Array FnBody → Array VarId → M FnBody
| FnBody.vdecl x _ (Expr.reset n y) b, bs, tokens => do
partial def searchAndSpecialize : FnBody → Array FnBody → Array VarId → HashMap VarId VarId → M FnBody
| FnBody.vdecl x _ (Expr.reset n y) b, bs, tokens, subst => do
let (bs, mask) := eraseProjIncFor n y bs
let b ← searchAndSpecialize b #[] (tokens.push x)
let b ← searchAndSpecialize b #[] (tokens.push x) (subst.insert x y)
let b ← specializeReset x y mask b
return reshape bs b
| FnBody.vdecl z t (Expr.reuse w c u zs) b, bs, tokens => do
let b ← searchAndSpecialize b #[] tokens
let b ← specializeReuse z w c u t zs b
return reshape bs b
| FnBody.dec z n c p b, bs, tokens =>
if Array.contains tokens z then return FnBody.del z b
else do
let b ← searchAndSpecialize b #[] tokens
return reshape bs (FnBody.dec z n c p b)
| FnBody.jdecl j xs v b, bs, tokens => do
let v ← searchAndSpecialize v #[] tokens
let b ← searchAndSpecialize b #[] tokens
-- | FnBody.vdecl z t (Expr.reuse w c u zs) b, bs, tokens, subst => do
-- let b ← searchAndSpecialize b #[] tokens subst
-- let b ← specializeReuse z w (subst.find! w) c u t zs b
-- return reshape bs b
| FnBody.dec z n c p b, bs, tokens, subst =>
if tokens.contains z then return FnBody.del z b
else do
let b ← searchAndSpecialize b #[] tokens subst
return reshape bs (FnBody.dec z n c p b)
| FnBody.jdecl j xs v b, bs, tokens, subst => do
let v ← searchAndSpecialize v #[] tokens subst
let b ← searchAndSpecialize b #[] tokens subst
return reshape bs (FnBody.jdecl j xs v b)
| FnBody.case tid x xType alts, bs, tokens => do
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => searchAndSpecialize b #[] tokens
| FnBody.case tid x xType alts, bs, tokens, subst => do
let alts ← alts.mapM fun alt => alt.mmodifyBody fun b => searchAndSpecialize b #[] tokens subst
return reshape bs (FnBody.case tid x xType alts)
| b, bs, tokens =>
| b, bs, tokens, subst =>
if b.isTerminal then return reshape bs b
else searchAndSpecialize b.body (push bs b) tokens
else searchAndSpecialize b.body (push bs b) tokens subst

def main (d : Decl) : Decl :=
match d with
| .fdecl (body := b) .. =>
let m := mkProjMap d
let nextIdx := d.maxIndex + 1
let bNew := (searchAndSpecialize b #[] #[] { projMap := m }).run' nextIdx
let bNew := (searchAndSpecialize b #[] #[] HashMap.empty { projMap := m }).run' nextIdx
d.updateBody! bNew
| d => d

Expand Down

0 comments on commit d2ef785

Please sign in to comment.