From 6f8fa1a6a80e0ec60fca7f6084f29fde1d087ce3 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Wed, 22 Nov 2023 11:11:27 -0500 Subject: [PATCH] simp lemma `(revCDeriv K f x).1 = f x` --- .../FunctionTransformations/RevCDeriv.lean | 5 +++++ .../FunctionTransformations/RevDerivProj.lean | 18 ++++++++++++++++-- .../RevDerivUpdate.lean | 6 ++++++ SciLean/Core/Monads/Id.lean | 6 +++--- test/basic_gradients.lean | 1 - test/issues/25.lean | 2 +- 6 files changed, 31 insertions(+), 7 deletions(-) diff --git a/SciLean/Core/FunctionTransformations/RevCDeriv.lean b/SciLean/Core/FunctionTransformations/RevCDeriv.lean index dd335ced..6e879ad6 100644 --- a/SciLean/Core/FunctionTransformations/RevCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevCDeriv.lean @@ -44,6 +44,11 @@ noncomputable def scalarGradient (f : X → K) (x : X) : X := (revCDeriv K f x).2 1 +@[simp, ftrans_simp] +theorem revCDeriv_fst (f : X → Y) (x : X) + : (revCDeriv K f x).1 = f x := +by + rfl @[ftrans] theorem semiAdjoint.arg_a3.cderiv_rule diff --git a/SciLean/Core/FunctionTransformations/RevDerivProj.lean b/SciLean/Core/FunctionTransformations/RevDerivProj.lean index f23e8fc5..b5b828a5 100644 --- a/SciLean/Core/FunctionTransformations/RevDerivProj.lean +++ b/SciLean/Core/FunctionTransformations/RevDerivProj.lean @@ -37,6 +37,20 @@ def revDerivProjUpdate (ydf'.1, fun i de dx => dx + ydf'.2 i de) + +@[simp, ftrans_simp] +theorem revDerivProj_fst (f : X → E) (x : X) + : (revDerivProj K f x).1 = f x := +by + rfl + +@[simp, ftrans_simp] +theorem revDerivProjUpdate_fst (f : X → E) (x : X) + : (revDerivProjUpdate K f x).1 = f x := +by + rfl + + -------------------------------------------------------------------------------- @@ -180,10 +194,10 @@ theorem revDerivProjUpdate.comp_rule let zdf' := revDerivProj K f ydg'.1 (zdf'.1, fun i de dx => - ydg'.2 (zdf'.2 i de) 1 dx) := + ydg'.2 (zdf'.2 i de) dx) := by funext x - simp[revDerivProjUpdate,revDerivProj.comp_rule] + simp[revDerivProjUpdate,revDerivProj.comp_rule _ _ _ hf hg] constructor . sorry . diff --git a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean index bf622b54..7358cf72 100644 --- a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean +++ b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean @@ -23,6 +23,12 @@ def revDerivUpdate -- theorem asdf (f : X → Y) (x : X) (k : K) (hf : HasAdjDiff K f) -- : HasSemiAdjoint K (fun y => (revDerivUpdate K f x).2 y k 0) := by unfold revDerivUpdate; ftrans; fprop +@[simp, ftrans_simp] +theorem revDerivProj_fst (f : X → Y) (x : X) + : (revDerivUpdate K f x).1 = f x := +by + rfl + namespace revDerivUpdate diff --git a/SciLean/Core/Monads/Id.lean b/SciLean/Core/Monads/Id.lean index 004ccdae..4ae3f2d0 100644 --- a/SciLean/Core/Monads/Id.lean +++ b/SciLean/Core/Monads/Id.lean @@ -30,9 +30,9 @@ noncomputable instance : RevDerivMonad K Id Id where revDerivM f := revCDeriv K f HasAdjDiffM f := HasAdjDiff K f - revDerivM_pure f := by simp[pure] + revDerivM_pure f := by simp[pure,revCDeriv] revDerivM_bind := by intros; simp; ftrans; rfl - revDerivM_pair y := by intros; simp; ftrans + revDerivM_pair y := by intros; simp; ftrans; simp[revCDeriv] HasAdjDiffM_pure := by simp[pure] HasAdjDiffM_bind := by simp[bind]; fprop HasAdjDiffM_pair y := @@ -131,7 +131,7 @@ theorem Bind.bind.arg_a0a1.revDerivM_rule_on_Id let dx' := ydg'.2 dxy'.2 dxy'.1 + dx') := by - simp[revDerivM]; ftrans + simp[revDerivM]; ftrans; simp[revCDeriv] -- @[ftrans] -- This theorem causes some downstream issue in simp when applying congruence lemmas diff --git a/test/basic_gradients.lean b/test/basic_gradients.lean index 79bca5d4..458d44d5 100644 --- a/test/basic_gradients.lean +++ b/test/basic_gradients.lean @@ -230,7 +230,6 @@ example (w : K ^ (Idx' (-5) 5 × Idx' (-5) 5)) ⊞ i => ∑ (j : (Idx' (-5) 5 × Idx' (-5) 5)), w[(j.2,j.1)] * dy[(-j.2.1 +ᵥ i.fst, -j.1.1 +ᵥ i.snd)] := by conv => lhs; unfold gradient; ftrans - sorry_proof diff --git a/test/issues/25.lean b/test/issues/25.lean index fd364b0e..d78bea91 100644 --- a/test/issues/25.lean +++ b/test/issues/25.lean @@ -15,7 +15,7 @@ example = fun x => let ydg := revDerivUpdate K g x - let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 1 0) (x' - x))) x + let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 0) (x' - x))) x zdf := by have ⟨_,_⟩ := hf