Skip to content

Commit

Permalink
alternative composition rules for revCDeriv
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Oct 3, 2023
1 parent 4ed9501 commit 83eacad
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 16 deletions.
12 changes: 12 additions & 0 deletions SciLean/Core/FunctionPropositions/HasSemiAdjoint.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def semiAdjoint (f : X → Y) :=
| isTrue h => Classical.choose h
| isFalse _ => 0

-- Basic identities ------------------------------------------------------------
--------------------------------------------------------------------------------

@[simp]
theorem semiAdjoint_zero
(f : X → Y)
: semiAdjoint K f 0 = 0 := by sorry_proof


def semi_inner_ext (x x' : X)
: (∀ φ, TestFunction φ → ⟪x, φ⟫[K] = ⟪x', φ⟫[K])
Expand All @@ -56,6 +65,9 @@ by
rw[← Classical.choose_spec hf φ y hφ]


-- Lambda calculus rules -------------------------------------------------------
--------------------------------------------------------------------------------

namespace HasSemiAdjoint


Expand Down
81 changes: 65 additions & 16 deletions SciLean/Core/FunctionTransformations/RevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ noncomputable
def scalarGradient
(f : X → K) (x : X) : X := (revCDeriv K f x).2 1


@[ftrans]
theorem semiAdjoint.arg_a3.cderiv_rule
(f : X → Y) (a0 : W → Y) (ha0 : IsDifferentiable K a0)
: cderiv K (fun w => semiAdjoint K f (a0 w))
=
fun w dw =>
let dy := cderiv K a0 w dw
semiAdjoint K f dy :=
by
-- derivative of linear map is the map itself
-- but this needs a bit more careful reasoning because we do not assume
-- (hf : HasSemiAdjoint K f) and realy that `semiAdjoint K f = 0` if `f` does
-- not have adjoint
sorry_proof


namespace revCDeriv


Expand Down Expand Up @@ -122,6 +139,21 @@ by
unfold revCDeriv
funext _; ftrans; ftrans; simp

theorem comp_rule'
(f : Y → Z) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: revCDeriv K (fun x : X => f (g x))
=
fun x =>
let ydg := revCDeriv K g x
let zdf := revCDeriv K (fun x' => f (ydg.1 + semiAdjoint K ydg.2 (x' - x))) x
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revCDeriv
funext _; simp; ftrans


theorem let_rule
(f : X → Y → Z) (g : X → Y)
Expand All @@ -143,6 +175,23 @@ by
funext _; ftrans; ftrans; simp


theorem let_rule'
(f : X → Y → Z) (g : X → Y)
(hf : HasAdjDiff K (fun (x,y) => f x y)) (hg : HasAdjDiff K g)
: revCDeriv K (fun x : X => f x (g x))
=
fun x =>
let ydg := revCDeriv K g x
let zdf := revCDeriv K (fun x' => f x' (ydg.1 + semiAdjoint K ydg.2 (x' - x))) x
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revCDeriv
funext _; simp; ftrans



theorem pi_rule
(f : X → (i : ι) → E i) (hf : ∀ i, HasAdjDiff K (f · i))
: (revCDeriv K fun (x : X) (i : ι) => f x i)
Expand Down Expand Up @@ -177,6 +226,22 @@ by
unfold revCDeriv; ftrans; simp
rw[cderiv.arg_dx.semiAdjoint_rule_at K f (cderiv K g x) (g x) (by fprop) (by fprop)]

theorem comp_rule_at'
(f : Y → Z) (g : X → Y) (x : X)
(hf : HasAdjDiffAt K f (g x)) (hg : HasAdjDiffAt K g x)
: revCDeriv K (fun x : X => f (g x)) x
=
let ydg := revCDeriv K g x
let zdf := revCDeriv K (fun x' => f (ydg.1 + semiAdjoint K ydg.2 (x' - x))) x
zdf :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revCDeriv; simp; ftrans; simp
rw[cderiv.arg_dx.semiAdjoint_rule_at K f (cderiv K g x) (g x) (by fprop) (by fprop)]
rw[cderiv.comp_rule_at K f (fun x' => g x + cderiv K g x (x' - x)) x (by simp; fprop) (by sorry_proof)]



theorem let_rule_at
(f : X → Y → Z) (g : X → Y) (x : X)
Expand Down Expand Up @@ -1130,22 +1195,6 @@ end InnerProductSpace
-- semiAdjoint -----------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans]
theorem SciLean.semiAdjoint.arg_a3.cderiv_rule
(f : X → Y) (a0 : W → Y) (ha0 : IsDifferentiable K a0)
: cderiv K (fun w => semiAdjoint K f (a0 w))
=
fun w dw =>
let dy := cderiv K a0 w dw
semiAdjoint K f dy :=
by
-- derivative of linear map is the map itself
-- but this needs a bit more careful reasoning because we do not assume
-- (hf : HasSemiAdjoint K f) and realy that `semiAdjoint K f = 0` if `f` does
-- not have adjoint
sorry_proof


-- this should not apply for `a0 = (fun x => x)`
-- @[ftrans]
theorem SciLean.cderiv.arg_a3.semiAdjoint_rule
Expand Down

0 comments on commit 83eacad

Please sign in to comment.