Skip to content

Commit

Permalink
few extra rules for revDerivUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Oct 3, 2023
1 parent e56fea2 commit ab4f521
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 12 deletions.
14 changes: 14 additions & 0 deletions SciLean/Core/FunctionTransformations/RevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,20 @@ by
sorry_proof


-- this should not apply for `a0 = (fun x => x)`
-- @[ftrans]
theorem SciLean.cderiv.arg_a3.semiAdjoint_rule
(f : X → Y) (x : X) (a0 : W → X) (ha0 : HasSemiAdjoint K a0)
: semiAdjoint K (fun w => cderiv K f x (a0 w))
=
fun dy =>
let dx := semiAdjoint K (cderiv K f x) dy
semiAdjoint K a0 dx :=
by
sorry_proof


set_option trace.Meta.Tactic.simp.rewrite true in
@[ftrans]
theorem SciLean.semiAdjoint.arg_a3.revCDeriv_rule
(f : X → Y) (a0 : W → Y) (hf : HasSemiAdjoint K f) (ha0 : HasAdjDiff K a0)
Expand Down
80 changes: 68 additions & 12 deletions SciLean/Core/FunctionTransformations/RevDerivUpdate.lean
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
import SciLean.Core.FunctionPropositions.HasAdjDiffAt
import SciLean.Core.FunctionPropositions.HasAdjDiff

import SciLean.Core.FunctionTransformations.SemiAdjoint

import SciLean.Tactic.LetNormalize

import SciLean.Core.FunctionTransformations.RevCDeriv

set_option linter.unusedVariables false

Expand Down Expand Up @@ -32,7 +26,7 @@ namespace revDerivUpdate
--------------------------------------------------------------------------------

variable (X)
theorem id_rule
theorem id_rule
: revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' k dx => dx + k • dx') :=
by
unfold revDerivUpdate
Expand All @@ -44,9 +38,9 @@ theorem const_rule (y : Y)
by
unfold revDerivUpdate
funext _; ftrans; ftrans
variable{X}
variable {X}

variable(E)
variable (E)
theorem proj_rule (i : ι)
: revDerivUpdate K (fun (x : (i:ι) → E i) => x i)
=
Expand Down Expand Up @@ -75,6 +69,21 @@ by
unfold revDerivUpdate
funext _; ftrans; ftrans; simp

theorem comp_rule'
(f : Y → Z) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: revDerivUpdate K (fun x : X => f (g x))
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 1 0) (x' -x))) x
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
funext _; simp; ftrans


theorem let_rule
(f : X → Y → Z) (g : X → Y)
Expand All @@ -99,6 +108,21 @@ by
have h : IsLinearMap K (semiAdjoint K (cderiv K g x)) := sorry_proof
rw[h.map_smul]

theorem let_rule'
(f : X → Y → Z) (g : X → Y)
(hf : HasAdjDiff K (fun (xy : X×Y) => f xy.1 xy.2)) (hg : HasAdjDiff K g)
: revDerivUpdate K (fun x : X => let y := g x; f x y)
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun x' => f x' (ydg.1 + semiAdjoint K (ydg.2 · 1 0) (x' - x))) x
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
funext x; simp; ftrans


@[inline]
def fun_fold {ι : Type _} [EnumType ι] (f : ι → X → X) (x₀ : X) : X := Id.run do
Expand Down Expand Up @@ -216,7 +240,7 @@ end SciLean

--------------------------------------------------------------------------------
-- Function Rules --------------------------------------------------------------
--------------------------------------------------------------------------------b
--------------------------------------------------------------------------------

open SciLean

Expand All @@ -230,7 +254,7 @@ variable
{E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)]


-- Prod.mk -----------------------------------v---------------------------------
-- Prod.mk ---------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans]
Expand All @@ -247,7 +271,18 @@ by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; simp; ftrans; ftrans; simp[add_assoc]


theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule_simple
: revDerivUpdate K (fun xy : X × Y => (xy.1, xy.2))
=
fun xy =>
(xy, fun (dx',dy') k (dx,dy) => (dx+k•dx', dy+k•dy')) :=
by
unfold revDerivUpdate;
funext (x,y); simp
funext (dx',dy') k (dx,dy); ftrans; ftrans; simp


-- Prod.fst --------------------------------------------------------------------
--------------------------------------------------------------------------------
Expand All @@ -265,6 +300,17 @@ by
unfold revDerivUpdate; ftrans; ftrans; simp


theorem Prod.fst.arg_self.revDerivUpdate_rule_simple
: revDerivUpdate K (fun xy : X×Y => xy.1)
=
fun xy =>
(xy.1, fun dx' k (dx,dy) => (dx+k•dx', dy)) :=
by
unfold revDerivUpdate;
funext (x,y); simp; ftrans;
funext dx' k (dx,dy); ftrans; ftrans; simp


-- Prod.snd --------------------------------------------------------------------
--------------------------------------------------------------------------------

Expand All @@ -280,6 +326,16 @@ by
have ⟨_,_⟩ := hf
unfold revDerivUpdate; ftrans; ftrans; simp

theorem Prod.snd.arg_self.revDerivUpdate_simple_rule
: revDerivUpdate K (fun xy : X×Y => xy.2)
=
fun xy =>
(xy.2, fun dy' k (dx,dy) => (dx, dy + k•dy')) :=
by
unfold revDerivUpdate;
funext (x,y); simp; ftrans;
funext dy' k (dx,dy); ftrans; ftrans; simp


-- Function.comp ---------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down

0 comments on commit ab4f521

Please sign in to comment.