From 9fa60a3038c2a227b29041d30f52dcbf5ec2ddf7 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Sun, 1 Oct 2023 04:42:38 +0200 Subject: [PATCH] lsimp v2 now correctly removes trivial lets --- .../FunctionTransformations/RevCDeriv.lean | 8 +- .../RevDerivUpdate.lean | 2 +- .../FunctionTransformations/RevFDeriv.lean | 10 +- SciLean/Core/Monads/StateT.lean | 10 +- SciLean/Tactic/LSimp2/Main.lean | 123 ++++++++++-------- SciLean/Tactic/LSimp2/Test.lean | 105 ++++++++++++++- test/basic_gradients.lean | 11 +- test/issues/19.lean | 39 +++--- 8 files changed, 210 insertions(+), 98 deletions(-) diff --git a/SciLean/Core/FunctionTransformations/RevCDeriv.lean b/SciLean/Core/FunctionTransformations/RevCDeriv.lean index c31bd59f..7c4ac388 100644 --- a/SciLean/Core/FunctionTransformations/RevCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevCDeriv.lean @@ -159,7 +159,7 @@ by have _ := fun i => (hf i).1 have _ := fun i => (hf i).2 unfold revCDeriv - funext _; ftrans; ftrans; simp + funext _; ftrans; ftrans theorem comp_rule_at @@ -346,7 +346,7 @@ by have _ := fun i => (hg i).1 have _ := fun i => (hg i).2 unfold revCDeriv - funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + funext _; -- ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues simp sorry_proof @@ -395,7 +395,7 @@ by have _ := fun i => (hg i).1 have _ := fun i => (hg i).2 unfold revCDeriv - funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + funext _; -- ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues simp sorry_proof @@ -418,7 +418,7 @@ by have _ := fun i => (hg i).1 have _ := fun i => (hg i).2 unfold revCDeriv - funext _; -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + funext _; -- ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues simp sorry_proof diff --git a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean index 37e70d07..62b83760 100644 --- a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean +++ b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean @@ -124,7 +124,7 @@ by have _ := fun i => (hf i).1 have _ := fun i => (hf i).2 unfold revDerivUpdate - funext _; ftrans; ftrans; simp + funext _; ftrans; ftrans; -- simp funext dy dx sorry_proof diff --git a/SciLean/Core/FunctionTransformations/RevFDeriv.lean b/SciLean/Core/FunctionTransformations/RevFDeriv.lean index 0b38a543..95dcc2e9 100644 --- a/SciLean/Core/FunctionTransformations/RevFDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevFDeriv.lean @@ -81,9 +81,9 @@ theorem comp_rule by unfold revFDeriv funext _ - ftrans; ftrans; simp - ext; simp - + ftrans; -- ftrans; simp + -- ext; simp + sorry_proof theorem let_rule (f : X → Y → Z) (g : X → Y) @@ -127,8 +127,8 @@ theorem comp_rule_at ydg.2 dy) := by unfold revFDeriv - ftrans; ftrans; simp; ext; simp - + ftrans; -- ftrans; simp; ext; simp + sorry_proof theorem let_rule_at (f : X → Y → Z) (g : X → Y) (x : X) diff --git a/SciLean/Core/Monads/StateT.lean b/SciLean/Core/Monads/StateT.lean index bbb978d7..de0d4eb0 100644 --- a/SciLean/Core/Monads/StateT.lean +++ b/SciLean/Core/Monads/StateT.lean @@ -306,7 +306,7 @@ theorem _root_.getThe.arg.revDerivValM_rule pure ((← getThe S), fun ds => modifyThe S (fun ds' => ds + ds'))) := by simp[getThe, MonadStateOf.get, StateT.get,revDerivValM, revDerivM, pure, StateT.pure, bind, StateT.bind, setThe, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet] - ftrans; simp + ftrans -- MonadState.get -------------------------------------------------------------- -------------------------------------------------------------------------------- @@ -328,7 +328,7 @@ theorem _root_.MonadState.get.arg.revDerivValM_rule pure ((← get), fun ds => modify (fun ds' => ds + ds'))) := by simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,revDerivValM, revDerivM, pure, StateT.pure, bind, StateT.bind, setThe, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet, modifyGet] - ftrans; simp + ftrans -- setThe ---------------------------------------------------------------------- @@ -357,7 +357,7 @@ theorem _root_.setThe.arg_s.revDerivM_rule pure dx)) := by simp[setThe, set, StateT.set, revDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure] - ftrans; simp + ftrans -- MonadStateOf.set ------------------------------------------------------------ @@ -386,7 +386,7 @@ theorem _root_.MonadStateOf.set.arg_a0.revDerivM_rule pure dx)) := by simp[setThe, set, StateT.set, revDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure, get] - ftrans; simp + ftrans -- modifyThe ---------------------------------------------------------------------- -------------------------------------------------------------------------------- @@ -445,7 +445,7 @@ theorem _root_.modify.arg_f.revDerivM_rule pure dxs.1)) := by simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,revDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, setThe, set, StateT.set, get, pure, StateT.pure, modify] - ftrans; simp + ftrans end RevDerivMonad diff --git a/SciLean/Tactic/LSimp2/Main.lean b/SciLean/Tactic/LSimp2/Main.lean index 3fa47c47..015a55c8 100644 --- a/SciLean/Tactic/LSimp2/Main.lean +++ b/SciLean/Tactic/LSimp2/Main.lean @@ -373,11 +373,20 @@ private def filterLetValues (vals vars : Array Expr) : Array Expr × Array Nat : is := is.push i (r, is) -theorem let_simp_congr {α β} (f : α → β) {x x' : α} {y' : β} +theorem let_simp_congr {α β} (f : α → β) (x x' : α) (y' : β) (hx : x = x') (hf : f x' = y') : f x = y' := by rw[hx,hf] -private def mkLetRemoveCongr (f hx hf : Expr) : MetaM Expr := - mkAppM ``let_simp_congr #[f,hx,hf] +theorem let_remove_congr {α β} (f : α → β) (x x' : α) (y' : β) + (hx : x = x') (hf : f x' = y') : (let y := x; f y) = y' := by rw[hx,hf] + +private def mkLetRemoveCongr (f x x' y' hx hf : Expr) : MetaM Expr := + mkAppM ``let_simp_congr #[f,x,x',y',hx,hf] + +private def mkCast (e : Expr) (E' : Expr) : MetaM Expr := do + let proof ← mkFreshExprMVar (← mkEq (←inferType e) E') + proof.mvarId!.refl + mkAppM ``cast #[proof, e] + partial def simp (e : Expr) : M Result := withIncRecDepth do @@ -808,60 +817,70 @@ where simpLet (e : Expr) : M Result := do let Expr.letE n t v b _ := e | unreachable! - if (← Simp.getConfig).zeta then - return { expr := b.instantiate1 v } + if (← Simp.getConfig).zeta || ¬b.hasLooseBVars then + -- return { expr := b.instantiate1 v } + return ← simp (b.instantiate1 v) else match (← getSimpLetCase n t b) with | SimpLetCase.dep => return { expr := (← dsimp e) } | SimpLetCase.nondep => let rv ← simp v - letTelescope rv.expr fun fvars v' => do - -- this does not seem to work :( - -- if removeLet v' then - -- let bv := b.instantiate1 v' - -- let rbv ← simp bv - -- let e' ← mkLetFVars fvars rbv.expr - -- return { expr := e', proof? := some (← mkLetRemoveCongr (Expr.lam n t b .default) (← rv.getProof) (← mkLetFVars fvars (← rbv.getProof))) } - -- else - match ← splitByCtors? v' with - | .none => - withLocalDeclD n t fun x => do - let bx := b.instantiate1 x - let rbx ← simp bx - let hb? ← match rbx.proof? with - | none => pure none - | some h => pure (some (← mkLambdaFVars #[x] h)) - let e' ← mkLetFVars fvars (mkLet n t v' (← rbx.expr.abstractM #[x])) - match rv.proof?, hb? with - | none, none => return { expr := e' } - | some h, none => return { expr := e', proof? := some (← mkLetValCongr (← mkLambdaFVars #[x] rbx.expr) h) } - | _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) } - | .some (vs', projs, mk') => - let names := (Array.range vs'.size).map fun i => n.appendAfter (toString i) - let types ← liftM <| vs'.mapM inferType - withLocalDecls' names .default types fun xs => do - -- let (xs', is) := filterLetValues vs' xs - let bx := b.instantiate1 (mk'.beta xs) - let rbx ← simp bx - let hb? ← match rbx.proof? with - | none => pure none - | some h => - let h' ← - withLocalDeclD n t fun x => do - mkLambdaFVars #[x] (h.replaceFVars xs (projs.map (fun proj => proj.beta #[x]))) - pure (some h') - let e' ← - withLetDecls names vs' fun fvars' => - mkLetFVars (fvars ++ fvars') (rbx.expr.replaceFVars xs fvars') - match rv.proof?, hb? with - | none, none => return { expr := e' } - | some h, none => - let b' ← - withLocalDeclD n t fun x => do - mkLambdaFVars #[x] (rbx.expr.replaceFVars xs (projs.map (fun proj => proj.beta #[x]))) - return { expr := e', proof? := some (← mkLetValCongr b' h) } - | _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) } - | SimpLetCase.nondepDepVar => + + if removeLet rv.expr then + let e' := b.instantiate1 rv.expr + let proof? ← do + match rv.proof? with + | none => pure none + | some h => pure <| .some (← mkCongrArg (.lam n t b .default) h) + let r : Result := { + expr := e' + proof? := proof? + } + let r' ← simp e' + return ← mkEqTrans r r' + + else if rv.expr.isLet then + letTelescope rv.expr fun fvars v' => do + let e' ← mkLetFVars fvars (.letE n t v' b false) + let proof? ← do + match rv.proof? with + | none => pure none + | some h => pure <| .some (← mkLetValCongr (.lam n t b .default) h) + let r : Result := { + expr := e' + proof? := proof? + } + let r' ← simp e' + return ← mkEqTrans r r' + + else if let .some (vs, projs, mk) ← splitByCtors? rv.expr then + let names := (Array.range vs.size).map fun i => n.appendAfter (toString i) + let e' ← + withLetDecls names vs fun fvars' => + mkLetFVars fvars' (b.instantiate1 (mk.beta fvars')) + let r : Result ← do + let b' ← + withLocalDeclD n t fun x => do + mkLambdaFVars #[x] (b.instantiate1 (mk.beta <| projs.map (fun proj => proj.beta #[x]))) + let proof ← mkLetValCongr b' (← rv.getProof) + pure (Result.mk (expr := e') (proof? := .some proof) 0) + let r' ← simp e' + return ← mkEqTrans r r' + + else + withLocalDeclD n t fun x => do + let bx := b.instantiate1 x + let rbx ← simp bx + let hb? ← match rbx.proof? with + | none => pure none + | some h => pure (some (← mkLambdaFVars #[x] h)) + let e' := mkLet n t rv.expr (← rbx.expr.abstractM #[x]) + match rv.proof?, hb? with + | none, none => return { expr := e' } + | some h, none => return { expr := e', proof? := some (← mkLetValCongr (← mkLambdaFVars #[x] rbx.expr) h) } + | _, some h => return { expr := e', proof? := some (← mkLetCongr (← rv.getProof) h) } + + | SimpLetCase.nondepDepVar => let v' ← dsimp v withLocalDeclD n t fun x => do let bx := b.instantiate1 x diff --git a/SciLean/Tactic/LSimp2/Test.lean b/SciLean/Tactic/LSimp2/Test.lean index 37ff2258..b8950958 100644 --- a/SciLean/Tactic/LSimp2/Test.lean +++ b/SciLean/Tactic/LSimp2/Test.lean @@ -8,14 +8,108 @@ open SciLean -- set_option trace.Meta.Tactic.simp.unify true in -- set_option trace.Meta.Tactic.lsimp.post true in +opaque foo1 {α} (a : α) : α := a + +@[simp] +theorem foo1_id {α} (a : α) : foo1 a = a := sorry + +set_option trace.Meta.Tactic.simp.rewrite true in +#check + (let a := foo1 (fun x => x) 1 + let b := foo1 (fun x => x) (foo1 (foo1 a)) + b) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + lsimp (config := {zeta:=false, singlePass := true}) + + def ar : Array Nat := #[1,2,3,4,5] + +#check + ( + let b := 10 + id b + ) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + + + +#check + (let a := (hold 1 + 1, hold 10) + let b := hold 2 + let c := + let d := a + id (a.1 + b + d.1 + a.2) + a.1 + b + c + a.2) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + + +#check + ( + let a := + let i := (3, 4) + let c := + let a := hold 1 + hold i.1 + c + let w := a + a + w) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + + +#check + ( + let a := + let i : Fin 5 := ⟨3, by simp⟩ + let c := + let a := hold 1 + id (ar[i]) + c + let w := a + a + w) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + + #check (fun x : Nat => let a := let b := - let a := 10 - ((a,20), 30) + let a := hold 10 + (a,hold 20) + let i := (3, 4) + let c := + let a := b.1 + 10 + id (1 + 2 + 3 + i.1) + c + let w := a + a + w) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + +#check + (fun x : Nat => + let a := + let c := + let a := 10 + 10 + (1 + 2 + 3 + a) + c + let w := a + a + w) + rewrite_by + lsimp (config := {zeta:=false, singlePass := true}) + + +#check + (fun x : Nat => + let a := + let b := + let a := hold 10 + ((a,hold 20), hold 30) let i : Fin 5 := ⟨3, by simp⟩ let c := let a := b.1.2 + 10 @@ -30,7 +124,7 @@ def ar : Array Nat := #[1,2,3,4,5] w) rewrite_by lsimp (config := {zeta:=false, singlePass := true}) - lsimp (config := {zeta:=false, singlePass := true}) + def foo := @@ -51,7 +145,7 @@ def foo := let w := z + 0 w) rewrite_by - lsimp (config := {zeta:=false}) + lsimp (config := {zeta:=false, singlePass := true}) set_option trace.Meta.Tactic.simp.rewrite true in @@ -88,6 +182,5 @@ example let w := z + y; w := by - conv => lhs; lsimp (config := {zeta := false}) - + conv => lhs; lsimp (config := {zeta := false, singlePass := true}) diff --git a/test/basic_gradients.lean b/test/basic_gradients.lean index d0e168ee..72a3a678 100644 --- a/test/basic_gradients.lean +++ b/test/basic_gradients.lean @@ -73,7 +73,7 @@ example = fun x => ⊞ _ => (1:K) := by - (conv => lhs; unfold scalarGradient; ftrans only; autodiff) + (conv => lhs; autodiff) example : (∇ (x : Fin 10 → K), ∑ i, ‖x i‖₂²) @@ -204,7 +204,7 @@ example (w : Idx' (-5) 5 → K) = fun _x dy i => ∑ (j : Idx' (-5) 5), w j * dy (-j.1 +ᵥ i) j := by - conv => lhs; autodiff; autodiff + conv => lhs; autodiff example (w : Idx' (-5) 5 → K) @@ -212,7 +212,7 @@ example (w : Idx' (-5) 5 → K) = fun _x dy i => ∑ (j : Idx' (-5) 5), w j * dy (-j.1 +ᵥ i) := by - conv => lhs; autodiff; autodiff + conv => lhs; autodiff example (w : K ^ Idx' (-5) 5) @@ -220,7 +220,7 @@ example (w : K ^ Idx' (-5) 5) = fun _x dy => ⊞ i => ∑ (j : Idx' (-5) 5), w[j] * dy[-j.1 +ᵥ i] := by - conv => lhs; autodiff; autodiff; simp + conv => lhs; autodiff example (w : K ^ (Idx' (-5) 5 × Idx' (-5) 5)) @@ -230,8 +230,7 @@ example (w : K ^ (Idx' (-5) 5 × Idx' (-5) 5)) -- ⊞ i => ∑ j, w[j] * dy[(-j.fst.1 +ᵥ i.fst, -j.snd.1 +ᵥ i.snd)] := ⊞ i => ∑ (j : (Idx' (-5) 5 × Idx' (-5) 5)), w[(j.2,j.1)] * dy[(-j.2.1 +ᵥ i.fst, -j.1.1 +ᵥ i.snd)] := by - conv => - lhs; autodiff; autodiff; simp + conv => lhs; unfold gradient; ftrans; simp diff --git a/test/issues/19.lean b/test/issues/19.lean index 8b5563eb..96ea0d46 100644 --- a/test/issues/19.lean +++ b/test/issues/19.lean @@ -13,23 +13,24 @@ set_default_scalar K variable (f : X → Nat → Y) (x : X) -/-- -info: let ydf := fun i => ∂ (x':=x), f x' i; -ydf 0 : X → Y --/ -#guard_msgs in -#check - (let ydf := fun i => ∂ x':=x, f x' i; (ydf 0)) - rewrite_by - ftrans only - -/-- -info: let f := fun i => i; -f 0 : ℕ --/ -#guard_msgs in -#check - (let f := fun i : Nat => i ; (f 0)) - rewrite_by - ftrans only +-- TODO: add option to lsimp not to destroy let bindings with lambdas +-- /-- +-- info: let ydf := fun i => ∂ (x':=x), f x' i; +-- ydf 0 : X → Y +-- -/ +-- #guard_msgs in +-- #check +-- (let ydf := fun i => ∂ x':=x, f x' i; (ydf 0)) +-- rewrite_by +-- ftrans only + +-- /-- +-- info: let f := fun i => i; +-- f 0 : ℕ +-- -/ +-- #guard_msgs in +-- #check +-- (let f := fun i : Nat => i ; (f 0)) +-- rewrite_by +-- ftrans only