Skip to content

Commit

Permalink
revDeriv rules for SMul.sMul
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2023
1 parent f99275d commit 5bf0b2c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
71 changes: 71 additions & 0 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1312,3 +1312,74 @@ theorem HMul.hMul.arg_a0a1.revDerivProjUpdate_rule
by
unfold revDerivProjUpdate
ftrans; simp[revDerivUpdate,add_assoc]



-- SMul.smul -------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans]
theorem HSMul.hSMul.arg_a0a1.revDeriv_rule
{Y : Type} [SemiHilbert K Y]
(f : X → K) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDeriv K fun x => f x • g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 • zdg.1, fun dx' => ydf.2 (inner zdg.1 dx') (conj ydf.1 • zdg.2 dx')) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDeriv; unfold revDerivUpdate; ftrans; ftrans; simp[revDeriv]


@[ftrans]
theorem HSMul.hSMul.arg_a0a1.revDerivUpdate_rule
{Y : Type} [SemiHilbert K Y]
(f : X → K) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDerivUpdate K fun x => f x • g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 • zdg.1, fun dy dx => ydf.2 (inner zdg.1 dy) (zdg.2 (conj ydf.1•dy) dx)) :=
by
unfold revDerivUpdate;
funext x; ftrans; simp[mul_assoc,add_assoc,revDerivUpdate,revDeriv,smul_push]

@[ftrans]
theorem HSMul.hSMul.arg_a0a1.revDerivProj_rule
{Y Yi : Type} {YI : Yi → Type} [StructLike Y Yi YI] [EnumType Yi]
[SemiHilbert K Y] [∀ i, SemiHilbert K (YI i)] [SemiInnerProductSpaceStruct K Y Yi YI]
(f : X → K) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDerivProj K fun x => f x • g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivProj K g x
(ydf.1 • zdg.1, fun i (dy : YI i) => ydf.2 (inner (StructLike.proj zdg.1 i) dy) (zdg.2 i (conj ydf.1•dy))) :=
by
unfold revDerivProj
ftrans; simp[revDerivUpdate,smul_push,revDeriv]

@[ftrans]
theorem HSMul.hSMul.arg_a0a1.revDerivProjUpdate_rule
{Y Yi : Type} {YI : Yi → Type} [StructLike Y Yi YI] [EnumType Yi]
[SemiHilbert K Y] [∀ i, SemiHilbert K (YI i)] [SemiInnerProductSpaceStruct K Y Yi YI]
(f : X → K) (g : X → Y)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDerivProjUpdate K fun x => f x • g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivProjUpdate K g x
(ydf.1 • zdg.1, fun i (dy : YI i) dx => ydf.2 (inner (StructLike.proj zdg.1 i) dy) (zdg.2 i (conj ydf.1•dy) dx)) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivUpdate,add_assoc]


14 changes: 14 additions & 0 deletions SciLean/Data/StructLike/Algebra.lean
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ attribute [simp] VecStruct.proj_add VecStruct.proj_smul VecStruct.proj_zero
theorem oneHot.arg_xi.neg_pull [StructLike X I XI] [DecidableEq I] [∀ i, Vec K (XI i)] [Vec K X] [VecStruct K X I XI] (i : I) (xi : XI i)
: StructLike.oneHot (X:=X) i (- xi) = - StructLike.oneHot (X:=X) i xi := sorry_proof


@[smul_push]
theorem oneHot.arg_xi.smul_push [StructLike X I XI] [DecidableEq I] [∀ i, Vec K (XI i)] [Vec K X] [VecStruct K X I XI] (i : I) (xi : XI i) (k : K)
: k • StructLike.oneHot (X:=X) i xi = StructLike.oneHot (X:=X) i (k•xi) := sorry_proof


@[simp]
theorem oneHot_zero [StructLike X I XI] [DecidableEq I] [∀ i, Vec K (XI i)] [Vec K X] [VecStruct K X I XI] (i : I)
: StructLike.oneHot (X:=X) i 0 = (0 : X) := sorry_proof
Expand All @@ -47,6 +53,14 @@ class SemiInnerProductSpaceStruct (K X I XI) [StructLike X I XI] [IsROrC K] [Enu
inner_proj : ∀ (x x' : X), ⟪x,x'⟫[K] = ∑ (i : I), ⟪proj x i, proj x' i⟫[K]
testFun_proj : ∀ (x : X), TestFunction x ↔ (∀ i, TestFunction (proj x i))

@[simp]
theorem inner_oneHot_eq_inner_proj [StructLike X I XI] [EnumType I] [∀ i, SemiInnerProductSpace K (XI i)] [SemiInnerProductSpace K X] [SemiInnerProductSpaceStruct K X I XI] (i : I) (xi : XI i) (x : X)
: ⟪x, StructLike.oneHot i xi⟫[K] = ⟪StructLike.proj x i, xi⟫[K] := sorry_proof

@[simp]
theorem inner_oneHot_eq_inner_proj' [StructLike X I XI] [EnumType I] [∀ i, SemiInnerProductSpace K (XI i)] [SemiInnerProductSpace K X] [SemiInnerProductSpaceStruct K X I XI] (i : I) (xi : XI i) (x : X)
: ⟪StructLike.oneHot i xi, x⟫[K] = ⟪xi, StructLike.proj x i⟫[K] := sorry_proof

instance (priority:=low) {X} [Vec K X] : VecStruct K X Unit (fun _ => X) where
proj_add := by simp[StructLike.proj]
proj_smul := by simp[StructLike.proj]
Expand Down

0 comments on commit 5bf0b2c

Please sign in to comment.