Skip to content

Commit

Permalink
simp lemma (revCDeriv K f x).1 = f x
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 22, 2023
1 parent 06a84f3 commit 6f8fa1a
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 7 deletions.
5 changes: 5 additions & 0 deletions SciLean/Core/FunctionTransformations/RevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ noncomputable
def scalarGradient
(f : X → K) (x : X) : X := (revCDeriv K f x).2 1

@[simp, ftrans_simp]
theorem revCDeriv_fst (f : X → Y) (x : X)
: (revCDeriv K f x).1 = f x :=
by
rfl

@[ftrans]
theorem semiAdjoint.arg_a3.cderiv_rule
Expand Down
18 changes: 16 additions & 2 deletions SciLean/Core/FunctionTransformations/RevDerivProj.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ def revDerivProjUpdate
(ydf'.1, fun i de dx => dx + ydf'.2 i de)



@[simp, ftrans_simp]
theorem revDerivProj_fst (f : X → E) (x : X)
: (revDerivProj K f x).1 = f x :=
by
rfl

@[simp, ftrans_simp]
theorem revDerivProjUpdate_fst (f : X → E) (x : X)
: (revDerivProjUpdate K f x).1 = f x :=
by
rfl


--------------------------------------------------------------------------------


Expand Down Expand Up @@ -180,10 +194,10 @@ theorem revDerivProjUpdate.comp_rule
let zdf' := revDerivProj K f ydg'.1
(zdf'.1,
fun i de dx =>
ydg'.2 (zdf'.2 i de) 1 dx) :=
ydg'.2 (zdf'.2 i de) dx) :=
by
funext x
simp[revDerivProjUpdate,revDerivProj.comp_rule]
simp[revDerivProjUpdate,revDerivProj.comp_rule _ _ _ hf hg]
constructor
. sorry
.
Expand Down
6 changes: 6 additions & 0 deletions SciLean/Core/FunctionTransformations/RevDerivUpdate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def revDerivUpdate
-- theorem asdf (f : X → Y) (x : X) (k : K) (hf : HasAdjDiff K f)
-- : HasSemiAdjoint K (fun y => (revDerivUpdate K f x).2 y k 0) := by unfold revDerivUpdate; ftrans; fprop

@[simp, ftrans_simp]
theorem revDerivProj_fst (f : X → Y) (x : X)
: (revDerivUpdate K f x).1 = f x :=
by
rfl


namespace revDerivUpdate

Expand Down
6 changes: 3 additions & 3 deletions SciLean/Core/Monads/Id.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ noncomputable
instance : RevDerivMonad K Id Id where
revDerivM f := revCDeriv K f
HasAdjDiffM f := HasAdjDiff K f
revDerivM_pure f := by simp[pure]
revDerivM_pure f := by simp[pure,revCDeriv]
revDerivM_bind := by intros; simp; ftrans; rfl
revDerivM_pair y := by intros; simp; ftrans
revDerivM_pair y := by intros; simp; ftrans; simp[revCDeriv]
HasAdjDiffM_pure := by simp[pure]
HasAdjDiffM_bind := by simp[bind]; fprop
HasAdjDiffM_pair y :=
Expand Down Expand Up @@ -131,7 +131,7 @@ theorem Bind.bind.arg_a0a1.revDerivM_rule_on_Id
let dx' := ydg'.2 dxy'.2
dxy'.1 + dx') :=
by
simp[revDerivM]; ftrans
simp[revDerivM]; ftrans; simp[revCDeriv]

-- @[ftrans]
-- This theorem causes some downstream issue in simp when applying congruence lemmas
Expand Down
1 change: 0 additions & 1 deletion test/basic_gradients.lean
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ example (w : K ^ (Idx' (-5) 5 × Idx' (-5) 5))
⊞ i => ∑ (j : (Idx' (-5) 5 × Idx' (-5) 5)), w[(j.2,j.1)] * dy[(-j.2.1 +ᵥ i.fst, -j.1.1 +ᵥ i.snd)] :=
by
conv => lhs; unfold gradient; ftrans
sorry_proof



2 changes: 1 addition & 1 deletion test/issues/25.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ example
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 1 0) (x' - x))) x
let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 0) (x' - x))) x
zdf :=
by
have ⟨_,_⟩ := hf
Expand Down

0 comments on commit 6f8fa1a

Please sign in to comment.