Skip to content

Commit

Permalink
lambda calculus rules for revDerivUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 23, 2023
1 parent 7d33f8b commit d1c5f3b
Showing 1 changed file with 87 additions and 4 deletions.
91 changes: 87 additions & 4 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,14 @@ by
unfold revDeriv
funext _; ftrans; ftrans


theorem const_rule (y : Y)
: revDeriv K (fun _ : X => y) = fun x => (y, fun _ => 0) :=
by
unfold revDeriv
funext _; ftrans; ftrans
variable{X}

variable(E)
variable(EI)
theorem proj_rule (i : I)
: revDeriv K (fun (x : (i:I) → EI i) => x i)
=
Expand All @@ -158,8 +157,7 @@ theorem proj_rule (i : I)
by
unfold revDeriv
funext _; ftrans; ftrans
variable {E}

variable {EI}

theorem comp_rule
(f : Y → Z) (g : X → Y)
Expand Down Expand Up @@ -216,9 +214,94 @@ by
have _ := fun i => (hf i).2
unfold revDeriv
funext _; ftrans; ftrans
sorry_proof

end revDeriv


--------------------------------------------------------------------------------
-- Lambda calculus rules for revDerivUpdate ------------------------------------
--------------------------------------------------------------------------------

namespace revDerivUpdate

variable (X)
theorem id_rule
: revDerivUpdate K (fun x : X => x) = fun x => (x, fun dx' dx => dx + dx') :=
by
unfold revDerivUpdate
simp [revDeriv.id_rule]


theorem const_rule (y : Y)
: revDerivUpdate K (fun _ : X => y) = fun x => (y, fun _ dx => dx) :=
by
unfold revDerivUpdate
simp [revDeriv.const_rule]

variable {X}

variable (EI)
theorem proj_rule (i : I)
: revDerivUpdate K (fun (x : (i:I) → EI i) => x i)
=
fun x =>
(x i, fun dxi dx j => if h : i=j then dx j + h ▸ dxi else dx j) :=
by
unfold revDerivUpdate
simp [revDeriv.proj_rule]
funext _; ftrans; ftrans;
simp; funext dxi dx j; simp; sorry_proof
variable {EI}

theorem comp_rule
(f : Y → Z) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: revDerivUpdate K (fun x : X => f (g x))
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDeriv K f ydg.1
(zdf.1,
fun dz dx =>
let dy := zdf.2 dz
ydg.2 dy dx) :=
by
unfold revDerivUpdate
simp [revDeriv.comp_rule _ _ _ hf hg]

theorem let_rule
(f : X → Y → Z) (g : X → Y)
(hf : HasAdjDiff K (fun (xy : X×Y) => f xy.1 xy.2)) (hg : HasAdjDiff K g)
: revDerivUpdate K (fun x : X => let y := g x; f x y)
=
fun x =>
let ydg := revDerivUpdate K g x
let zdf := revDerivUpdate K (fun (xy : X×Y) => f xy.1 xy.2) (x,ydg.1)
(zdf.1,
fun dz dx =>
let dxdy := zdf.2 dz (dx, 0)
let dx := ydg.2 dxdy.2 dxdy.1
dx) :=
by
unfold revDerivUpdate
simp [revDeriv.let_rule _ _ _ hf hg, revDerivUpdate,add_assoc]

theorem pi_rule
(f : X → (i : I) → EI i) (hf : ∀ i, HasAdjDiff K (f · i))
: (revDerivUpdate K fun (x : X) (i : I) => f x i)
=
fun x =>
let xdf := revDerivProjUpdate K f x
(fun i => xdf.1 i,
fun dy dx => Id.run do
let mut dx := dx
for i in fullRange I do
dx := xdf.2 ⟨i,()⟩ (dy i) dx
dx) :=
by
unfold revDerivUpdate
simp [revDeriv.pi_rule _ _ hf, revDerivUpdate]
sorry_proof

end revDerivUpdate

0 comments on commit d1c5f3b

Please sign in to comment.