diff --git a/SciLean/Core/FunctionTransformations/RevCDeriv.lean b/SciLean/Core/FunctionTransformations/RevCDeriv.lean index 808e2ad0..30710488 100644 --- a/SciLean/Core/FunctionTransformations/RevCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevCDeriv.lean @@ -1146,6 +1146,20 @@ by sorry_proof +-- this should not apply for `a0 = (fun x => x)` +-- @[ftrans] +theorem SciLean.cderiv.arg_a3.semiAdjoint_rule + (f : X → Y) (x : X) (a0 : W → X) (ha0 : HasSemiAdjoint K a0) + : semiAdjoint K (fun w => cderiv K f x (a0 w)) + = + fun dy => + let dx := semiAdjoint K (cderiv K f x) dy + semiAdjoint K a0 dx := +by + sorry_proof + + +set_option trace.Meta.Tactic.simp.rewrite true in @[ftrans] theorem SciLean.semiAdjoint.arg_a3.revCDeriv_rule (f : X → Y) (a0 : W → Y) (hf : HasSemiAdjoint K f) (ha0 : HasAdjDiff K a0) diff --git a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean index 1797ceb9..b0044e2f 100644 --- a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean +++ b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean @@ -1,10 +1,4 @@ -import SciLean.Core.FunctionPropositions.HasAdjDiffAt -import SciLean.Core.FunctionPropositions.HasAdjDiff - -import SciLean.Core.FunctionTransformations.SemiAdjoint - -import SciLean.Tactic.LetNormalize - +import SciLean.Core.FunctionTransformations.RevCDeriv set_option linter.unusedVariables false @@ -32,7 +26,7 @@ namespace revDerivUpdate -------------------------------------------------------------------------------- variable (X) -theorem id_rule +theorem id_rule : revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' k dx => dx + k • dx') := by unfold revDerivUpdate @@ -44,9 +38,9 @@ theorem const_rule (y : Y) by unfold revDerivUpdate funext _; ftrans; ftrans -variable{X} +variable {X} -variable(E) +variable (E) theorem proj_rule (i : ι) : revDerivUpdate K (fun (x : (i:ι) → E i) => x i) = @@ -75,6 +69,21 @@ by unfold revDerivUpdate funext _; ftrans; ftrans; simp +theorem comp_rule' + (f : Y → Z) (g : X → Y) + (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) + : revDerivUpdate K (fun x : X => f (g x)) + = + fun x => + let ydg := revDerivUpdate K g x + let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 1 0) (x' -x))) x + zdf := +by + have ⟨_,_⟩ := hf + have ⟨_,_⟩ := hg + unfold revDerivUpdate + funext _; simp; ftrans + theorem let_rule (f : X → Y → Z) (g : X → Y) @@ -99,6 +108,21 @@ by have h : IsLinearMap K (semiAdjoint K (cderiv K g x)) := sorry_proof rw[h.map_smul] +theorem let_rule' + (f : X → Y → Z) (g : X → Y) + (hf : HasAdjDiff K (fun (xy : X×Y) => f xy.1 xy.2)) (hg : HasAdjDiff K g) + : revDerivUpdate K (fun x : X => let y := g x; f x y) + = + fun x => + let ydg := revDerivUpdate K g x + let zdf := revDerivUpdate K (fun x' => f x' (ydg.1 + semiAdjoint K (ydg.2 · 1 0) (x' - x))) x + zdf := +by + have ⟨_,_⟩ := hf + have ⟨_,_⟩ := hg + unfold revDerivUpdate + funext x; simp; ftrans + @[inline] def fun_fold {ι : Type _} [EnumType ι] (f : ι → X → X) (x₀ : X) : X := Id.run do @@ -216,7 +240,7 @@ end SciLean -------------------------------------------------------------------------------- -- Function Rules -------------------------------------------------------------- ---------------------------------------------------------------------------------b +-------------------------------------------------------------------------------- open SciLean @@ -230,7 +254,7 @@ variable {E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)] --- Prod.mk -----------------------------------v--------------------------------- +-- Prod.mk --------------------------------------------------------------------- -------------------------------------------------------------------------------- @[ftrans] @@ -247,7 +271,18 @@ by have ⟨_,_⟩ := hf have ⟨_,_⟩ := hg unfold revDerivUpdate; simp; ftrans; ftrans; simp[add_assoc] + +theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule_simple + : revDerivUpdate K (fun xy : X × Y => (xy.1, xy.2)) + = + fun xy => + (xy, fun (dx',dy') k (dx,dy) => (dx+k•dx', dy+k•dy')) := +by + unfold revDerivUpdate; + funext (x,y); simp + funext (dx',dy') k (dx,dy); ftrans; ftrans; simp + -- Prod.fst -------------------------------------------------------------------- -------------------------------------------------------------------------------- @@ -265,6 +300,17 @@ by unfold revDerivUpdate; ftrans; ftrans; simp +theorem Prod.fst.arg_self.revDerivUpdate_rule_simple + : revDerivUpdate K (fun xy : X×Y => xy.1) + = + fun xy => + (xy.1, fun dx' k (dx,dy) => (dx+k•dx', dy)) := +by + unfold revDerivUpdate; + funext (x,y); simp; ftrans; + funext dx' k (dx,dy); ftrans; ftrans; simp + + -- Prod.snd -------------------------------------------------------------------- -------------------------------------------------------------------------------- @@ -280,6 +326,16 @@ by have ⟨_,_⟩ := hf unfold revDerivUpdate; ftrans; ftrans; simp +theorem Prod.snd.arg_self.revDerivUpdate_simple_rule + : revDerivUpdate K (fun xy : X×Y => xy.2) + = + fun xy => + (xy.2, fun dy' k (dx,dy) => (dx, dy + k•dy')) := +by + unfold revDerivUpdate; + funext (x,y); simp; ftrans; + funext dy' k (dx,dy); ftrans; ftrans; simp + -- Function.comp --------------------------------------------------------------- --------------------------------------------------------------------------------