diff --git a/SciLean/Core/FunctionPropositions.lean b/SciLean/Core/FunctionPropositions.lean index 761a706c..1a8856e9 100644 --- a/SciLean/Core/FunctionPropositions.lean +++ b/SciLean/Core/FunctionPropositions.lean @@ -8,3 +8,4 @@ import SciLean.Core.FunctionPropositions.HasSemiAdjoint import SciLean.Core.FunctionPropositions.IsContinuousLinearMap import SciLean.Core.FunctionPropositions.IsDifferentiable import SciLean.Core.FunctionPropositions.IsDifferentiableAt +import SciLean.Core.FunctionPropositions.IsLinearMap diff --git a/SciLean/Core/FunctionTransformations/RevDeriv.lean b/SciLean/Core/FunctionTransformations/RevDeriv.lean index 1c5225a9..d0dff629 100644 --- a/SciLean/Core/FunctionTransformations/RevDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevDeriv.lean @@ -1461,6 +1461,55 @@ by -- HPow.hPow ------------------------------------------------------------------- -------------------------------------------------------------------------------- +@[ftrans] +def HPow.hPow.arg_a0.revDeriv_rule + (f : X → K) (n : Nat) (hf : HasAdjDiff K f) + : revDeriv K (fun x => f x ^ n) + = + fun x => + let ydf := revDeriv K f x + (ydf.1 ^ n, fun dx' => ydf.2 ((n : K) * (conj ydf.1 ^ (n-1)) * dx')) := +by + have ⟨_,_⟩ := hf + funext x + unfold revDeriv; simp; funext dx; ftrans; ftrans; simp[smul_push,smul_smul]; ring_nf + +@[ftrans] +def HPow.hPow.arg_a0.revDerivUpdate_rule + (f : X → K) (n : Nat) (hf : HasAdjDiff K f) + : revDerivUpdate K (fun x => f x ^ n) + = + fun x => + let ydf := revDerivUpdate K f x + (ydf.1 ^ n, + fun dy dx => ydf.2 (n * (conj ydf.1 ^ (n-1)) * dy) dx) := +by + unfold revDerivUpdate + funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc] + +@[ftrans] +def HPow.hPow.arg_a0.revDerivProj_rule + (f : X → K) (n : Nat) (hf : HasAdjDiff K f) + : revDerivProj K (fun x => f x ^ n) + = + fun x => + let ydf := revDeriv K f x + (ydf.1 ^ n, fun _ dx' => ydf.2 ((n : K) * (conj ydf.1 ^ (n-1)) * dx')) := +by + unfold revDerivProj; ftrans; simp[StructLike.oneHot,StructLike.make] + +@[ftrans] +def HPow.hPow.arg_a0.revDerivProjUpdate_rule + (f : X → K) (n : Nat) (hf : HasAdjDiff K f) + : revDerivProjUpdate K (fun x => f x ^ n) + = + fun x => + let ydf := revDerivUpdate K f x + (ydf.1 ^ n, + fun _ dy dx => ydf.2 (n * (conj ydf.1 ^ (n-1)) * dy) dx) := +by + unfold revDerivProjUpdate; ftrans; simp[StructLike.oneHot,StructLike.make,revDerivUpdate] + -- EnumType.sum ---------------------------------------------------------------- --------------------------------------------------------------------------------