Skip to content

Commit

Permalink
revDeriv for HMul.hMul
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2023
1 parent 231672c commit f99275d
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1247,3 +1247,68 @@ theorem Neg.neg.arg_a0.revDerivProjUpdate_rule
(-ydf.1, fun i dy dx => ydf.2 i (-dy) dx) :=
by
unfold revDerivProjUpdate; ftrans


-- HMul.hmul -------------------------------------------------------------------
--------------------------------------------------------------------------------
open ComplexConjugate

@[ftrans]
theorem HMul.hMul.arg_a0a1.revDeriv_rule
(f g : X → K)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDeriv K fun x => f x * g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 * zdg.1, fun dx' => (ydf.2 (conj zdg.1 * dx') (zdg.2 (conj ydf.1 * dx')))) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; unfold revDeriv; simp; ftrans; ftrans;
simp [smul_push]

@[ftrans]
theorem HMul.hMul.arg_a0a1.revDerivUpdate_rule
(f g : X → K)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDerivUpdate K fun x => f x * g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 * zdg.1, fun dx' dx => (ydf.2 (conj zdg.1 * dx') (zdg.2 (conj ydf.1 * dx') dx))) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; unfold revDeriv; simp; ftrans; ftrans;
simp [smul_push,add_assoc]

@[ftrans]
theorem HMul.hMul.arg_a0a1.revDerivProj_rule
(f g : X → K)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDerivProj K fun x => f x * g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 * zdg.1, fun _ dy => ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy))) :=
by
unfold revDerivProj
ftrans; simp[StructLike.oneHot, StructLike.make]

@[ftrans]
theorem HMul.hMul.arg_a0a1.revDerivProjUpdate_rule
(f g : X → K)
(hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: (revDerivProjUpdate K fun x => f x * g x)
=
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 * zdg.1, fun _ dy dx => ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy) dx)) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivUpdate,add_assoc]

0 comments on commit f99275d

Please sign in to comment.