From d1c5f3b70a9967df9480e88bae3a161b521b42e4 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Thu, 23 Nov 2023 11:24:13 -0500 Subject: [PATCH] lambda calculus rules for revDerivUpdate --- .../FunctionTransformations/RevDeriv.lean | 91 ++++++++++++++++++- 1 file changed, 87 insertions(+), 4 deletions(-) diff --git a/SciLean/Core/FunctionTransformations/RevDeriv.lean b/SciLean/Core/FunctionTransformations/RevDeriv.lean index 861fda40..e0963547 100644 --- a/SciLean/Core/FunctionTransformations/RevDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevDeriv.lean @@ -141,7 +141,6 @@ by unfold revDeriv funext _; ftrans; ftrans - theorem const_rule (y : Y) : revDeriv K (fun _ : X => y) = fun x => (y, fun _ => 0) := by @@ -149,7 +148,7 @@ by funext _; ftrans; ftrans variable{X} -variable(E) +variable(EI) theorem proj_rule (i : I) : revDeriv K (fun (x : (i:I) → EI i) => x i) = @@ -158,8 +157,7 @@ theorem proj_rule (i : I) by unfold revDeriv funext _; ftrans; ftrans -variable {E} - +variable {EI} theorem comp_rule (f : Y → Z) (g : X → Y) @@ -216,9 +214,94 @@ by have _ := fun i => (hf i).2 unfold revDeriv funext _; ftrans; ftrans + sorry_proof end revDeriv + -------------------------------------------------------------------------------- -- Lambda calculus rules for revDerivUpdate ------------------------------------ -------------------------------------------------------------------------------- + +namespace revDerivUpdate + +variable (X) +theorem id_rule + : revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' dx => dx + dx') := +by + unfold revDerivUpdate + simp [revDeriv.id_rule] + + +theorem const_rule (y : Y) + : revDerivUpdate K (fun _ : X => y) = fun x => (y, fun _ dx => dx) := +by + unfold revDerivUpdate + simp [revDeriv.const_rule] + +variable {X} + +variable (EI) +theorem proj_rule (i : I) + : revDerivUpdate K (fun (x : (i:I) → EI i) => x i) + = + fun x => + (x i, fun dxi dx j => if h : i=j then dx j + h ▸ dxi else dx j) := +by + unfold revDerivUpdate + simp [revDeriv.proj_rule] + funext _; ftrans; ftrans; + simp; funext dxi dx j; simp; sorry_proof +variable {EI} + +theorem comp_rule + (f : Y → Z) (g : X → Y) + (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) + : revDerivUpdate K (fun x : X => f (g x)) + = + fun x => + let ydg := revDerivUpdate K g x + let zdf := revDeriv K f ydg.1 + (zdf.1, + fun dz dx => + let dy := zdf.2 dz + ydg.2 dy dx) := +by + unfold revDerivUpdate + simp [revDeriv.comp_rule _ _ _ hf hg] + +theorem let_rule + (f : X → Y → Z) (g : X → Y) + (hf : HasAdjDiff K (fun (xy : X×Y) => f xy.1 xy.2)) (hg : HasAdjDiff K g) + : revDerivUpdate K (fun x : X => let y := g x; f x y) + = + fun x => + let ydg := revDerivUpdate K g x + let zdf := revDerivUpdate K (fun (xy : X×Y) => f xy.1 xy.2) (x,ydg.1) + (zdf.1, + fun dz dx => + let dxdy := zdf.2 dz (dx, 0) + let dx := ydg.2 dxdy.2 dxdy.1 + dx) := +by + unfold revDerivUpdate + simp [revDeriv.let_rule _ _ _ hf hg, revDerivUpdate,add_assoc] + +theorem pi_rule + (f : X → (i : I) → EI i) (hf : ∀ i, HasAdjDiff K (f · i)) + : (revDerivUpdate K fun (x : X) (i : I) => f x i) + = + fun x => + let xdf := revDerivProjUpdate K f x + (fun i => xdf.1 i, + fun dy dx => Id.run do + let mut dx := dx + for i in fullRange I do + dx := xdf.2 ⟨i,()⟩ (dy i) dx + dx) := +by + unfold revDerivUpdate + simp [revDeriv.pi_rule _ _ hf, revDerivUpdate] + sorry_proof + +end revDerivUpdate