diff --git a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean index 414118ad..b2b5365f 100644 --- a/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean +++ b/SciLean/Core/FunctionTransformations/RevDerivUpdate.lean @@ -15,8 +15,9 @@ variable noncomputable def revDerivUpdate - (f : X → Y) (x : X) : Y×(Y→K→X→X) := - (f x, fun dy k dx => dx + k • semiAdjoint K (cderiv K f x) dy) + (f : X → Y) (x : X) : Y×(Y→X→X) := + let ydf := revCDeriv K f x + (ydf.1, fun dy dx => dx + ydf.2 dy) -- @[fprop] -- theorem asdf (f : X → Y) (x : X) (k : K) (hf : HasAdjDiff K f) @@ -31,17 +32,17 @@ namespace revDerivUpdate variable (X) theorem id_rule - : revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' k dx => dx + k • dx') := + : revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' dx => dx + dx') := by unfold revDerivUpdate - funext _; ftrans; ftrans + funext _; ftrans theorem const_rule (y : Y) - : revDerivUpdate K (fun _ : X => y) = fun x => (y, fun _ k dx => dx) := + : revDerivUpdate K (fun _ : X => y) = fun x => (y, fun _ dx => dx) := by unfold revDerivUpdate - funext _; ftrans; ftrans + funext _; ftrans variable {X} variable (E) @@ -49,11 +50,11 @@ theorem proj_rule (i : ι) : revDerivUpdate K (fun (x : (i:ι) → E i) => x i) = fun x => - (x i, fun dxi k dx j => if h : i=j then dx j + k • h ▸ dxi else dx j) := + (x i, fun dxi dx j => if h : i=j then dx j + h ▸ dxi else dx j) := by unfold revDerivUpdate funext _; ftrans; ftrans; - simp; funext dxi k dx j; simp; sorry_proof + simp; funext dxi dx j; simp; sorry_proof variable {E} theorem comp_rule @@ -63,32 +64,17 @@ theorem comp_rule = fun x => let ydg := revDerivUpdate K g x - let zdf := revDerivUpdate K f ydg.1 + let zdf := revCDeriv K f ydg.1 (zdf.1, - fun dz k dx => - let dy := zdf.2 dz 1 0 - ydg.2 dy k dx) := + fun dz dx => + let dy := zdf.2 dz + ydg.2 dy dx) := by have ⟨_,_⟩ := hf have ⟨_,_⟩ := hg unfold revDerivUpdate funext _; ftrans; ftrans; simp -theorem comp_rule' - (f : Y → Z) (g : X → Y) - (hf : HasAdjDiff K f) (hg : HasAdjDiff K g) - : revDerivUpdate K (fun x : X => f (g x)) - = - fun x => - let ydg := revDerivUpdate K g x - let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 1 0) x')) 0 - zdf := -by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate - funext _; simp; ftrans - theorem let_rule (f : X → Y → Z) (g : X → Y) @@ -99,34 +85,17 @@ theorem let_rule let ydg := revDerivUpdate K g x let zdf := revDerivUpdate K (fun (xy : X×Y) => f xy.1 xy.2) (x,ydg.1) (zdf.1, - fun dz k dx => - let dxdy := zdf.2 dz k (dx, 0) - let dx := ydg.2 dxdy.2 1 dxdy.1 + fun dz dx => + let dxdy := zdf.2 dz (dx, 0) + let dx := ydg.2 dxdy.2 dxdy.1 dx) := by have ⟨_,_⟩ := hf have ⟨_,_⟩ := hg unfold revDerivUpdate funext x; ftrans; ftrans; simp - funext dz k dx + funext dz dx simp[add_assoc] - have h : IsLinearMap K (semiAdjoint K (cderiv K g x)) := sorry_proof - rw[h.map_smul] - -theorem let_rule' - (f : X → Y → Z) (g : X → Y) - (hf : HasAdjDiff K (fun (xy : X×Y) => f xy.1 xy.2)) (hg : HasAdjDiff K g) - : revDerivUpdate K (fun x : X => let y := g x; f x y) - = - fun x => - let ydg := revDerivUpdate K g x - let zdf := revDerivUpdate K (fun x' => f (x + x') (ydg.1 + semiAdjoint K (ydg.2 · 1 0) x')) 0 - zdf := -by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate - funext x; simp; ftrans @[inline] @@ -144,14 +113,13 @@ theorem pi_rule let xdf := fun i => (revDerivUpdate K fun (x : X) => f x i) x (fun i => (xdf i).1, - fun dy k dx => fun_fold (fun i => (xdf i).2 (dy i) k) dx) + fun dy dx => fun_fold (fun i => (xdf i).2 (dy i)) dx) := by have _ := fun i => (hf i).1 have _ := fun i => (hf i).2 unfold revDerivUpdate funext _; ftrans; ftrans; simp - funext dy dx sorry_proof @@ -271,22 +239,19 @@ theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule fun x => let ydg := revDerivUpdate K g x let zdf := revDerivUpdate K f x - ((ydg.1,zdf.1), fun dyz k dx => zdf.2 dyz.2 k (ydg.2 dyz.1 k dx)) := + ((ydg.1,zdf.1), fun dyz dx => zdf.2 dyz.2 (ydg.2 dyz.1 dx)) := by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate; simp; ftrans; ftrans; simp[add_assoc] + unfold revDerivUpdate; ftrans; simp[add_assoc] theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule_simple : revDerivUpdate K (fun xy : X × Y => (xy.1, xy.2)) = fun xy => - (xy, fun (dx',dy') k (dx,dy) => (dx+k•dx', dy+k•dy')) := + (xy, fun (dx',dy') (dx,dy) => (dx+dx', dy+dy')) := by unfold revDerivUpdate; - funext (x,y); simp - funext (dx',dy') k (dx,dy); ftrans; ftrans; simp + ftrans; simp; rfl -- Prod.fst -------------------------------------------------------------------- @@ -299,21 +264,21 @@ theorem Prod.fst.arg_self.revDerivUpdate_rule = fun x => let yzdf := revDerivUpdate K f x - (yzdf.1.1, fun dy k dx => yzdf.2 (dy,0) k dx) := + (yzdf.1.1, fun dy dx => yzdf.2 (dy,0) dx) := by - have ⟨_,_⟩ := hf - unfold revDerivUpdate; ftrans; ftrans; simp + unfold revDerivUpdate; + ftrans theorem Prod.fst.arg_self.revDerivUpdate_rule_simple : revDerivUpdate K (fun xy : X×Y => xy.1) = fun xy => - (xy.1, fun dx' k (dx,dy) => (dx+k•dx', dy)) := + (xy.1, fun dx' (dx,dy) => (dx+dx', dy)) := by - unfold revDerivUpdate; - funext (x,y); simp; ftrans; - funext dx' k (dx,dy); ftrans; ftrans; simp + unfold revDerivUpdate; + ftrans; funext x; simp;funext dy (dx₁,dx₂); simp + -- Prod.snd -------------------------------------------------------------------- @@ -326,7 +291,7 @@ theorem Prod.snd.arg_self.revDerivUpdate_rule = fun x => let yzdf := revDerivUpdate K f x - (yzdf.1.2, fun dz k dx => yzdf.2 (0,dz) k dx) := + (yzdf.1.2, fun dz dx => yzdf.2 (0,dz) dx) := by have ⟨_,_⟩ := hf unfold revDerivUpdate; ftrans; ftrans; simp @@ -335,11 +300,10 @@ theorem Prod.snd.arg_self.revDerivUpdate_simple_rule : revDerivUpdate K (fun xy : X×Y => xy.2) = fun xy => - (xy.2, fun dy' k (dx,dy) => (dx, dy + k•dy')) := + (xy.2, fun dy' (dx,dy) => (dx, dy + dy')) := by - unfold revDerivUpdate; - funext (x,y); simp; ftrans; - funext dy' k (dx,dy); ftrans; ftrans; simp + unfold revDerivUpdate; + ftrans; funext x; simp;funext dy (dx₁,dx₂); simp -- Function.comp --------------------------------------------------------------- @@ -389,11 +353,10 @@ theorem HAdd.hAdd.arg_a0a1.revDerivUpdate_rule fun x => let ydf := revDerivUpdate K f x let ydg := revDerivUpdate K g x - (ydf.1 + ydg.1, fun dy k dx => ydg.2 dy k (ydf.2 dy k dx)) := + (ydf.1 + ydg.1, fun dy dx => ydg.2 dy (ydf.2 dy dx)) := by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate; simp; ftrans; ftrans; simp[add_assoc] + unfold revDerivUpdate + ftrans; funext x; simp[add_assoc] -- HSub.hSub ------------------------------------------------------------------- @@ -407,16 +370,10 @@ theorem HSub.hSub.arg_a0a1.revDerivUpdate_rule fun x => let ydf := revDerivUpdate K f x let ydg := revDerivUpdate K g x - (ydf.1 - ydg.1, fun dy k dx => ydf.2 dy k (ydg.2 dy (-k) dx)) := + (ydf.1 - ydg.1, fun dy dx => ydg.2 (-dy) (ydf.2 dy dx)) := by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate; simp; ftrans; ftrans; funext x; simp; funext dy k dx; - simp[add_assoc] - rw[sub_eq_add_neg] - rw[smul_add] - rw[smul_neg] - rw[add_comm] + unfold revDerivUpdate + ftrans; funext x; simp; sorry_proof -- Neg.neg --------------------------------------------------------------------- @@ -429,9 +386,9 @@ theorem Neg.neg.arg_a0.revDerivUpdate_rule = fun x => let ydf := revDerivUpdate K f x - (-ydf.1, fun dy k dx => ydf.2 dy (-k) dx) := + (-ydf.1, fun dy dx => ydf.2 (-dy) dx) := by - unfold revDerivUpdate; simp; ftrans; ftrans + unfold revDerivUpdate; funext x; ftrans; simp; sorry_proof -- HMul.hmul ------------------------------------------------------------------- @@ -447,15 +404,10 @@ theorem HMul.hMul.arg_a0a1.revDerivUpdate_rule fun x => let ydf := revDerivUpdate K f x let zdg := revDerivUpdate K g x - (ydf.1 * zdg.1, fun dy k dx => ydf.2 dy (k * conj zdg.1) (zdg.2 dy (k * conj ydf.1) dx)) := + (ydf.1 * zdg.1, fun dy dx => ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy) dx)) := by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate; simp; ftrans; ftrans; - funext x; - simp - funext dy k dx - simp[smul_smul, add_assoc] + unfold revDerivUpdate; + funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc]; sorry_proof @@ -472,13 +424,11 @@ theorem HSMul.hSMul.arg_a0a1.revDerivUpdate_rule fun x => let ydf := revDerivUpdate K f x let zdg := revDerivUpdate K g x - (ydf.1 • zdg.1, fun dy k dx => ydf.2 (inner zdg.1 dy) k (zdg.2 dy (k*conj ydf.1) dx)) := + by - have ⟨_,_⟩ := hf - have ⟨_,_⟩ := hg - unfold revDerivUpdate; simp; ftrans; ftrans; - funext x; simp; funext dy k dx - simp[add_assoc,smul_smul] + unfold revDerivUpdate; + funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc]; sorry_proof + -- HDiv.hDiv ------------------------------------------------------------------- --------------------------------------------------------------------------------