Skip to content

Commit

Permalink
lsimp v2 now correctly removes trivial lets
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Oct 1, 2023
1 parent 970c3d6 commit 9fa60a3
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 98 deletions.
8 changes: 4 additions & 4 deletions SciLean/Core/FunctionTransformations/RevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
unfold revCDeriv
funext _; ftrans; ftrans; simp
funext _; ftrans; ftrans


theorem comp_rule_at
Expand Down Expand Up @@ -346,7 +346,7 @@ by
have _ := fun i => (hg i).1
have _ := fun i => (hg i).2
unfold revCDeriv
funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues
funext _; -- ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues
simp
sorry_proof

Expand Down Expand Up @@ -395,7 +395,7 @@ by
have _ := fun i => (hg i).1
have _ := fun i => (hg i).2
unfold revCDeriv
funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues
funext _; -- ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues
simp
sorry_proof

Expand All @@ -418,7 +418,7 @@ by
have _ := fun i => (hg i).1
have _ := fun i => (hg i).2
unfold revCDeriv
funext _; -- ftrans - semiAdjoint.pi_rule fails because of some universe issues
funext _; -- ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues
simp
sorry_proof

Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/FunctionTransformations/RevDerivUpdate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
unfold revDerivUpdate
funext _; ftrans; ftrans; simp
funext _; ftrans; ftrans; -- simp
funext dy dx
sorry_proof

Expand Down
10 changes: 5 additions & 5 deletions SciLean/Core/FunctionTransformations/RevFDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ theorem comp_rule
by
unfold revFDeriv
funext _
ftrans; ftrans; simp
ext; simp

ftrans; -- ftrans; simp
-- ext; simp
sorry_proof

theorem let_rule
(f : X → Y → Z) (g : X → Y)
Expand Down Expand Up @@ -127,8 +127,8 @@ theorem comp_rule_at
ydg.2 dy) :=
by
unfold revFDeriv
ftrans; ftrans; simp; ext; simp

ftrans; -- ftrans; simp; ext; simp
sorry_proof

theorem let_rule_at
(f : X → Y → Z) (g : X → Y) (x : X)
Expand Down
10 changes: 5 additions & 5 deletions SciLean/Core/Monads/StateT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ theorem _root_.getThe.arg.revDerivValM_rule
pure ((← getThe S), fun ds => modifyThe S (fun ds' => ds + ds'))) :=
by
simp[getThe, MonadStateOf.get, StateT.get,revDerivValM, revDerivM, pure, StateT.pure, bind, StateT.bind, setThe, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet]
ftrans; simp
ftrans

-- MonadState.get --------------------------------------------------------------
--------------------------------------------------------------------------------
Expand All @@ -328,7 +328,7 @@ theorem _root_.MonadState.get.arg.revDerivValM_rule
pure ((← get), fun ds => modify (fun ds' => ds + ds'))) :=
by
simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,revDerivValM, revDerivM, pure, StateT.pure, bind, StateT.bind, setThe, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet, modifyGet]
ftrans; simp
ftrans


-- setThe ----------------------------------------------------------------------
Expand Down Expand Up @@ -357,7 +357,7 @@ theorem _root_.setThe.arg_s.revDerivM_rule
pure dx)) :=
by
simp[setThe, set, StateT.set, revDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure]
ftrans; simp
ftrans


-- MonadStateOf.set ------------------------------------------------------------
Expand Down Expand Up @@ -386,7 +386,7 @@ theorem _root_.MonadStateOf.set.arg_a0.revDerivM_rule
pure dx)) :=
by
simp[setThe, set, StateT.set, revDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure, get]
ftrans; simp
ftrans

-- modifyThe ----------------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down Expand Up @@ -445,7 +445,7 @@ theorem _root_.modify.arg_f.revDerivM_rule
pure dxs.1)) :=
by
simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,revDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, setThe, set, StateT.set, get, pure, StateT.pure, modify]
ftrans; simp
ftrans

end RevDerivMonad

Expand Down
123 changes: 71 additions & 52 deletions SciLean/Tactic/LSimp2/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,20 @@ private def filterLetValues (vals vars : Array Expr) : Array Expr × Array Nat :
is := is.push i
(r, is)

theorem let_simp_congr {α β} (f : α → β) {x x' : α} {y' : β}
theorem let_simp_congr {α β} (f : α → β) (x x' : α) (y' : β)
(hx : x = x') (hf : f x' = y') : f x = y' := by rw[hx,hf]

private def mkLetRemoveCongr (f hx hf : Expr) : MetaM Expr :=
mkAppM ``let_simp_congr #[f,hx,hf]
theorem let_remove_congr {α β} (f : α → β) (x x' : α) (y' : β)
(hx : x = x') (hf : f x' = y') : (let y := x; f y) = y' := by rw[hx,hf]

private def mkLetRemoveCongr (f x x' y' hx hf : Expr) : MetaM Expr :=
mkAppM ``let_simp_congr #[f,x,x',y',hx,hf]

private def mkCast (e : Expr) (E' : Expr) : MetaM Expr := do
let proof ← mkFreshExprMVar (← mkEq (←inferType e) E')
proof.mvarId!.refl
mkAppM ``cast #[proof, e]



partial def simp (e : Expr) : M Result := withIncRecDepth do
Expand Down Expand Up @@ -808,60 +817,70 @@ where

simpLet (e : Expr) : M Result := do
let Expr.letE n t v b _ := e | unreachable!
if (← Simp.getConfig).zeta then
return { expr := b.instantiate1 v }
if (← Simp.getConfig).zeta || ¬b.hasLooseBVars then
-- return { expr := b.instantiate1 v }
return ← simp (b.instantiate1 v)
else
match (← getSimpLetCase n t b) with
| SimpLetCase.dep => return { expr := (← dsimp e) }
| SimpLetCase.nondep =>
let rv ← simp v
letTelescope rv.expr fun fvars v' => do
-- this does not seem to work :(
-- if removeLet v' then
-- let bv := b.instantiate1 v'
-- let rbv ← simp bv
-- let e' ← mkLetFVars fvars rbv.expr
-- return { expr := e', proof? := some (← mkLetRemoveCongr (Expr.lam n t b .default) (← rv.getProof) (← mkLetFVars fvars (← rbv.getProof))) }
-- else
match ← splitByCtors? v' with
| .none =>
withLocalDeclD n t fun x => do
let bx := b.instantiate1 x
let rbx ← simp bx
let hb? ← match rbx.proof? with
| none => pure none
| some h => pure (some (← mkLambdaFVars #[x] h))
let e' ← mkLetFVars fvars (mkLet n t v' (← rbx.expr.abstractM #[x]))
match rv.proof?, hb? with
| none, none => return { expr := e' }
| some h, none => return { expr := e', proof? := some (← mkLetValCongr (← mkLambdaFVars #[x] rbx.expr) h) }
| _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) }
| .some (vs', projs, mk') =>
let names := (Array.range vs'.size).map fun i => n.appendAfter (toString i)
let types ← liftM <| vs'.mapM inferType
withLocalDecls' names .default types fun xs => do
-- let (xs', is) := filterLetValues vs' xs
let bx := b.instantiate1 (mk'.beta xs)
let rbx ← simp bx
let hb? ← match rbx.proof? with
| none => pure none
| some h =>
let h' ←
withLocalDeclD n t fun x => do
mkLambdaFVars #[x] (h.replaceFVars xs (projs.map (fun proj => proj.beta #[x])))
pure (some h')
let e' ←
withLetDecls names vs' fun fvars' =>
mkLetFVars (fvars ++ fvars') (rbx.expr.replaceFVars xs fvars')
match rv.proof?, hb? with
| none, none => return { expr := e' }
| some h, none =>
let b' ←
withLocalDeclD n t fun x => do
mkLambdaFVars #[x] (rbx.expr.replaceFVars xs (projs.map (fun proj => proj.beta #[x])))
return { expr := e', proof? := some (← mkLetValCongr b' h) }
| _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) }
| SimpLetCase.nondepDepVar =>

if removeLet rv.expr then
let e' := b.instantiate1 rv.expr
let proof? ← do
match rv.proof? with
| none => pure none
| some h => pure <| .some (← mkCongrArg (.lam n t b .default) h)
let r : Result := {
expr := e'
proof? := proof?
}
let r' ← simp e'
return ← mkEqTrans r r'

else if rv.expr.isLet then
letTelescope rv.expr fun fvars v' => do
let e' ← mkLetFVars fvars (.letE n t v' b false)
let proof? ← do
match rv.proof? with
| none => pure none
| some h => pure <| .some (← mkLetValCongr (.lam n t b .default) h)
let r : Result := {
expr := e'
proof? := proof?
}
let r' ← simp e'
return ← mkEqTrans r r'

else if let .some (vs, projs, mk) ← splitByCtors? rv.expr then
let names := (Array.range vs.size).map fun i => n.appendAfter (toString i)
let e' ←
withLetDecls names vs fun fvars' =>
mkLetFVars fvars' (b.instantiate1 (mk.beta fvars'))
let r : Result ← do
let b' ←
withLocalDeclD n t fun x => do
mkLambdaFVars #[x] (b.instantiate1 (mk.beta <| projs.map (fun proj => proj.beta #[x])))
let proof ← mkLetValCongr b' (← rv.getProof)
pure (Result.mk (expr := e') (proof? := .some proof) 0)
let r' ← simp e'
return ← mkEqTrans r r'

else
withLocalDeclD n t fun x => do
let bx := b.instantiate1 x
let rbx ← simp bx
let hb? ← match rbx.proof? with
| none => pure none
| some h => pure (some (← mkLambdaFVars #[x] h))
let e' := mkLet n t rv.expr (← rbx.expr.abstractM #[x])
match rv.proof?, hb? with
| none, none => return { expr := e' }
| some h, none => return { expr := e', proof? := some (← mkLetValCongr (← mkLambdaFVars #[x] rbx.expr) h) }
| _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) }

| SimpLetCase.nondepDepVar =>
let v' ← dsimp v
withLocalDeclD n t fun x => do
let bx := b.instantiate1 x
Expand Down
105 changes: 99 additions & 6 deletions SciLean/Tactic/LSimp2/Test.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,108 @@ open SciLean
-- set_option trace.Meta.Tactic.simp.unify true in
-- set_option trace.Meta.Tactic.lsimp.post true in

opaque foo1 {α} (a : α) : α := a

@[simp]
theorem foo1_id {α} (a : α) : foo1 a = a := sorry

set_option trace.Meta.Tactic.simp.rewrite true in
#check
(let a := foo1 (fun x => x) 1
let b := foo1 (fun x => x) (foo1 (foo1 a))
b)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})
lsimp (config := {zeta:=false, singlePass := true})


def ar : Array Nat := #[1,2,3,4,5]


#check
(
let b := 10
id b
)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})



#check
(let a := (hold 1 + 1, hold 10)
let b := hold 2
let c :=
let d := a
id (a.1 + b + d.1 + a.2)
a.1 + b + c + a.2)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})


#check
(
let a :=
let i := (3, 4)
let c :=
let a := hold 1
hold i.1
c
let w := a + a
w)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})


#check
(
let a :=
let i : Fin 5 := ⟨3, by simp⟩
let c :=
let a := hold 1
id (ar[i])
c
let w := a + a
w)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})


#check
(fun x : Nat =>
let a :=
let b :=
let a := 10
((a,20), 30)
let a := hold 10
(a,hold 20)
let i := (3, 4)
let c :=
let a := b.1 + 10
id (1 + 2 + 3 + i.1)
c
let w := a + a
w)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})

#check
(fun x : Nat =>
let a :=
let c :=
let a := 10 + 10
(1 + 2 + 3 + a)
c
let w := a + a
w)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})


#check
(fun x : Nat =>
let a :=
let b :=
let a := hold 10
((a,hold 20), hold 30)
let i : Fin 5 := ⟨3, by simp⟩
let c :=
let a := b.1.2 + 10
Expand All @@ -30,7 +124,7 @@ def ar : Array Nat := #[1,2,3,4,5]
w)
rewrite_by
lsimp (config := {zeta:=false, singlePass := true})
lsimp (config := {zeta:=false, singlePass := true})



def foo :=
Expand All @@ -51,7 +145,7 @@ def foo :=
let w := z + 0
w)
rewrite_by
lsimp (config := {zeta:=false})
lsimp (config := {zeta:=false, singlePass := true})


set_option trace.Meta.Tactic.simp.rewrite true in
Expand Down Expand Up @@ -88,6 +182,5 @@ example
let w := z + y;
w :=
by
conv => lhs; lsimp (config := {zeta := false})

conv => lhs; lsimp (config := {zeta := false, singlePass := true})

Loading

0 comments on commit 9fa60a3

Please sign in to comment.