Skip to content

Commit

Permalink
started a clean up of revDerivUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 22, 2023
1 parent 7eea93b commit ceac9ca
Showing 1 changed file with 48 additions and 98 deletions.
146 changes: 48 additions & 98 deletions SciLean/Core/FunctionTransformations/RevDerivUpdate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ variable

noncomputable
def revDerivUpdate
(f : X → Y) (x : X) : Y×(Y→K→X→X) :=
(f x, fun dy k dx => dx + k • semiAdjoint K (cderiv K f x) dy)
(f : X → Y) (x : X) : Y×(Y→X→X) :=
let ydf := revCDeriv K f x
(ydf.1, fun dy dx => dx + ydf.2 dy)

-- @[fprop]
-- theorem asdf (f : X → Y) (x : X) (k : K) (hf : HasAdjDiff K f)
Expand All @@ -31,29 +32,29 @@ namespace revDerivUpdate

variable (X)
theorem id_rule
: revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' k dx => dx + k • dx') :=
: revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' dx => dx + dx') :=
by
unfold revDerivUpdate
funext _; ftrans; ftrans
funext _; ftrans


theorem const_rule (y : Y)
: revDerivUpdate K (fun _ : X => y) = fun x => (y, fun _ k dx => dx) :=
: revDerivUpdate K (fun _ : X => y) = fun x => (y, fun _ dx => dx) :=
by
unfold revDerivUpdate
funext _; ftrans; ftrans
funext _; ftrans
variable {X}

variable (E)
theorem proj_rule (i : ι)
: revDerivUpdate K (fun (x : (i:ι) → E i) => x i)
=
fun x =>
(x i, fun dxi k dx j => if h : i=j then dx j + k • h ▸ dxi else dx j) :=
(x i, fun dxi dx j => if h : i=j then dx j + h ▸ dxi else dx j) :=
by
unfold revDerivUpdate
funext _; ftrans; ftrans;
simp; funext dxi k dx j; simp; sorry_proof
simp; funext dxi dx j; simp; sorry_proof
variable {E}

theorem comp_rule
Expand All @@ -63,32 +64,17 @@ theorem comp_rule
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K f ydg.1
let zdf := revCDeriv K f ydg.1
(zdf.1,
fun dz k dx =>
let dy := zdf.2 dz 1 0
ydg.2 dy k dx) :=
fun dz dx =>
let dy := zdf.2 dz
ydg.2 dy dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
funext _; ftrans; ftrans; simp

theorem comp_rule'
(f : Y → Z) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: revDerivUpdate K (fun x : X => f (g x))
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun x' => f (ydg.1 + semiAdjoint K (ydg.2 · 1 0) x')) 0
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
funext _; simp; ftrans


theorem let_rule
(f : X → Y → Z) (g : X → Y)
Expand All @@ -99,34 +85,17 @@ theorem let_rule
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun (xy : X×Y) => f xy.1 xy.2) (x,ydg.1)
(zdf.1,
fun dz k dx =>
let dxdy := zdf.2 dz k (dx, 0)
let dx := ydg.2 dxdy.2 1 dxdy.1
fun dz dx =>
let dxdy := zdf.2 dz (dx, 0)
let dx := ydg.2 dxdy.2 dxdy.1
dx) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
funext x; ftrans; ftrans; simp
funext dz k dx
funext dz dx
simp[add_assoc]
have h : IsLinearMap K (semiAdjoint K (cderiv K g x)) := sorry_proof
rw[h.map_smul]

theorem let_rule'
(f : X → Y → Z) (g : X → Y)
(hf : HasAdjDiff K (fun (xy : X×Y) => f xy.1 xy.2)) (hg : HasAdjDiff K g)
: revDerivUpdate K (fun x : X => let y := g x; f x y)
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun x' => f (x + x') (ydg.1 + semiAdjoint K (ydg.2 · 1 0) x')) 0
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
funext x; simp; ftrans


@[inline]
Expand All @@ -144,14 +113,13 @@ theorem pi_rule
let xdf := fun i =>
(revDerivUpdate K fun (x : X) => f x i) x
(fun i => (xdf i).1,
fun dy k dx => fun_fold (fun i => (xdf i).2 (dy i) k) dx)
fun dy dx => fun_fold (fun i => (xdf i).2 (dy i)) dx)
:=
by
have _ := fun i => (hf i).1
have _ := fun i => (hf i).2
unfold revDerivUpdate
funext _; ftrans; ftrans; simp
funext dy dx
sorry_proof


Expand Down Expand Up @@ -271,22 +239,19 @@ theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K f x
((ydg.1,zdf.1), fun dyz k dx => zdf.2 dyz.2 k (ydg.2 dyz.1 k dx)) :=
((ydg.1,zdf.1), fun dyz dx => zdf.2 dyz.2 (ydg.2 dyz.1 dx)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; simp; ftrans; ftrans; simp[add_assoc]
unfold revDerivUpdate; ftrans; simp[add_assoc]


theorem Prod.mk.arg_fstsnd.revDerivUpdate_rule_simple
: revDerivUpdate K (fun xy : X × Y => (xy.1, xy.2))
=
fun xy =>
(xy, fun (dx',dy') k (dx,dy) => (dx+k•dx', dy+k•dy')) :=
(xy, fun (dx',dy') (dx,dy) => (dx+dx', dy+dy')) :=
by
unfold revDerivUpdate;
funext (x,y); simp
funext (dx',dy') k (dx,dy); ftrans; ftrans; simp
ftrans; simp; rfl


-- Prod.fst --------------------------------------------------------------------
Expand All @@ -299,21 +264,21 @@ theorem Prod.fst.arg_self.revDerivUpdate_rule
=
fun x =>
let yzdf := revDerivUpdate K f x
(yzdf.1.1, fun dy k dx => yzdf.2 (dy,0) k dx) :=
(yzdf.1.1, fun dy dx => yzdf.2 (dy,0) dx) :=
by
have ⟨_,_⟩ := hf
unfold revDerivUpdate; ftrans; ftrans; simp
unfold revDerivUpdate;
ftrans


theorem Prod.fst.arg_self.revDerivUpdate_rule_simple
: revDerivUpdate K (fun xy : X×Y => xy.1)
=
fun xy =>
(xy.1, fun dx' k (dx,dy) => (dx+k•dx', dy)) :=
(xy.1, fun dx' (dx,dy) => (dx+dx', dy)) :=
by
unfold revDerivUpdate;
funext (x,y); simp; ftrans;
funext dx' k (dx,dy); ftrans; ftrans; simp
unfold revDerivUpdate;
ftrans; funext x; simp;funext dy (dx₁,dx₂); simp



-- Prod.snd --------------------------------------------------------------------
Expand All @@ -326,7 +291,7 @@ theorem Prod.snd.arg_self.revDerivUpdate_rule
=
fun x =>
let yzdf := revDerivUpdate K f x
(yzdf.1.2, fun dz k dx => yzdf.2 (0,dz) k dx) :=
(yzdf.1.2, fun dz dx => yzdf.2 (0,dz) dx) :=
by
have ⟨_,_⟩ := hf
unfold revDerivUpdate; ftrans; ftrans; simp
Expand All @@ -335,11 +300,10 @@ theorem Prod.snd.arg_self.revDerivUpdate_simple_rule
: revDerivUpdate K (fun xy : X×Y => xy.2)
=
fun xy =>
(xy.2, fun dy' k (dx,dy) => (dx, dy + k•dy')) :=
(xy.2, fun dy' (dx,dy) => (dx, dy + dy')) :=
by
unfold revDerivUpdate;
funext (x,y); simp; ftrans;
funext dy' k (dx,dy); ftrans; ftrans; simp
unfold revDerivUpdate;
ftrans; funext x; simp;funext dy (dx₁,dx₂); simp


-- Function.comp ---------------------------------------------------------------
Expand Down Expand Up @@ -389,11 +353,10 @@ theorem HAdd.hAdd.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let ydg := revDerivUpdate K g x
(ydf.1 + ydg.1, fun dy k dx => ydg.2 dy k (ydf.2 dy k dx)) :=
(ydf.1 + ydg.1, fun dy dx => ydg.2 dy (ydf.2 dy dx)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; simp; ftrans; ftrans; simp[add_assoc]
unfold revDerivUpdate
ftrans; funext x; simp[add_assoc]


-- HSub.hSub -------------------------------------------------------------------
Expand All @@ -407,16 +370,10 @@ theorem HSub.hSub.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let ydg := revDerivUpdate K g x
(ydf.1 - ydg.1, fun dy k dx => ydf.2 dy k (ydg.2 dy (-k) dx)) :=
(ydf.1 - ydg.1, fun dy dx => ydg.2 (-dy) (ydf.2 dy dx)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; simp; ftrans; ftrans; funext x; simp; funext dy k dx;
simp[add_assoc]
rw[sub_eq_add_neg]
rw[smul_add]
rw[smul_neg]
rw[add_comm]
unfold revDerivUpdate
ftrans; funext x; simp; sorry_proof


-- Neg.neg ---------------------------------------------------------------------
Expand All @@ -429,9 +386,9 @@ theorem Neg.neg.arg_a0.revDerivUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K f x
(-ydf.1, fun dy k dx => ydf.2 dy (-k) dx) :=
(-ydf.1, fun dy dx => ydf.2 (-dy) dx) :=
by
unfold revDerivUpdate; simp; ftrans; ftrans
unfold revDerivUpdate; funext x; ftrans; simp; sorry_proof


-- HMul.hmul -------------------------------------------------------------------
Expand All @@ -447,15 +404,10 @@ theorem HMul.hMul.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 * zdg.1, fun dy k dx => ydf.2 dy (k * conj zdg.1) (zdg.2 dy (k * conj ydf.1) dx)) :=
(ydf.1 * zdg.1, fun dy dx => ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy) dx)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; simp; ftrans; ftrans;
funext x;
simp
funext dy k dx
simp[smul_smul, add_assoc]
unfold revDerivUpdate;
funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc]; sorry_proof



Expand All @@ -472,13 +424,11 @@ theorem HSMul.hSMul.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 • zdg.1, fun dy k dx => ydf.2 (inner zdg.1 dy) k (zdg.2 dy (k*conj ydf.1) dx)) :=

by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; simp; ftrans; ftrans;
funext x; simp; funext dy k dx
simp[add_assoc,smul_smul]
unfold revDerivUpdate;
funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc]; sorry_proof


-- HDiv.hDiv -------------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down

0 comments on commit ceac9ca

Please sign in to comment.