Skip to content

Commit

Permalink
revDeriv rules for HPow.hPow
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2023
1 parent d535f9c commit ce8dc90
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
1 change: 1 addition & 0 deletions SciLean/Core/FunctionPropositions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ import SciLean.Core.FunctionPropositions.HasSemiAdjoint
import SciLean.Core.FunctionPropositions.IsContinuousLinearMap
import SciLean.Core.FunctionPropositions.IsDifferentiable
import SciLean.Core.FunctionPropositions.IsDifferentiableAt
import SciLean.Core.FunctionPropositions.IsLinearMap
49 changes: 49 additions & 0 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,55 @@ by
-- HPow.hPow -------------------------------------------------------------------
--------------------------------------------------------------------------------

@[ftrans]
def HPow.hPow.arg_a0.revDeriv_rule
(f : X → K) (n : Nat) (hf : HasAdjDiff K f)
: revDeriv K (fun x => f x ^ n)
=
fun x =>
let ydf := revDeriv K f x
(ydf.1 ^ n, fun dx' => ydf.2 ((n : K) * (conj ydf.1 ^ (n-1)) * dx')) :=
by
have ⟨_,_⟩ := hf
funext x
unfold revDeriv; simp; funext dx; ftrans; ftrans; simp[smul_push,smul_smul]; ring_nf

@[ftrans]
def HPow.hPow.arg_a0.revDerivUpdate_rule
(f : X → K) (n : Nat) (hf : HasAdjDiff K f)
: revDerivUpdate K (fun x => f x ^ n)
=
fun x =>
let ydf := revDerivUpdate K f x
(ydf.1 ^ n,
fun dy dx => ydf.2 (n * (conj ydf.1 ^ (n-1)) * dy) dx) :=
by
unfold revDerivUpdate
funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc]

@[ftrans]
def HPow.hPow.arg_a0.revDerivProj_rule
(f : X → K) (n : Nat) (hf : HasAdjDiff K f)
: revDerivProj K (fun x => f x ^ n)
=
fun x =>
let ydf := revDeriv K f x
(ydf.1 ^ n, fun _ dx' => ydf.2 ((n : K) * (conj ydf.1 ^ (n-1)) * dx')) :=
by
unfold revDerivProj; ftrans; simp[StructLike.oneHot,StructLike.make]

@[ftrans]
def HPow.hPow.arg_a0.revDerivProjUpdate_rule
(f : X → K) (n : Nat) (hf : HasAdjDiff K f)
: revDerivProjUpdate K (fun x => f x ^ n)
=
fun x =>
let ydf := revDerivUpdate K f x
(ydf.1 ^ n,
fun _ dy dx => ydf.2 (n * (conj ydf.1 ^ (n-1)) * dy) dx) :=
by
unfold revDerivProjUpdate; ftrans; simp[StructLike.oneHot,StructLike.make,revDerivUpdate]


-- EnumType.sum ----------------------------------------------------------------
--------------------------------------------------------------------------------
Expand Down

0 comments on commit ce8dc90

Please sign in to comment.