diff --git a/SciLean/Core/FunctionTransformations/RevCDeriv.lean b/SciLean/Core/FunctionTransformations/RevCDeriv.lean index 69fd49a3..4f1332e5 100644 --- a/SciLean/Core/FunctionTransformations/RevCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevCDeriv.lean @@ -15,8 +15,12 @@ variable (K : Type _) [IsROrC K] {X : Type _} [SemiInnerProductSpace K X] {Y : Type _} [SemiInnerProductSpace K Y] + {Y₁ : Type _} [SemiInnerProductSpace K Y₁] + {Y₂ : Type _} [SemiInnerProductSpace K Y₂] {Z : Type _} [SemiInnerProductSpace K Z] + {W : Type _} [SemiInnerProductSpace K W] {ι : Type _} [EnumType ι] + {κ : Type _} [EnumType κ] {E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)] @@ -213,6 +217,126 @@ by funext _; ftrans; ftrans sorry_proof +-- few more specialized rules for function types + +theorem revCDeriv.pi_comp_rule_simple + (f : Y → Z) (g : X → ι → Y) + (hf : HasAdjDiff K f) + (hg : ∀ j, HasAdjDiff K (g · j)) + : (revCDeriv K fun x i => f (g x i)) + = + fun x => + let ydg := revCDeriv K g x + let zdf := revCDeriv K f + (fun i => (zdf (ydg.1 i)).1, + fun dz => + let dy := fun i => (zdf (ydg.1 i)).2 (dz i) + ydg.2 dy) := +by + have ⟨_,_⟩ := hf + have _ := fun i => (hg i).1 + have _ := fun i => (hg i).2 + unfold revCDeriv + funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + simp + sorry_proof + + +theorem revCDeriv.pi_comp_rule + (f : Y → ι → Z) (g : X → ι → Y) + (hf : ∀ i, HasAdjDiff K (f · i)) + (hg : ∀ j, HasAdjDiff K (g · j)) + : (revCDeriv K fun x i => f (g x i) i) + = + fun x => + let ydg := revCDeriv K g x + let zdf := revCDeriv K (fun (y : ι → Y) i => f (y i) i) ydg.1 + (zdf.1, + fun dz => + let dy := zdf.2 dz + ydg.2 dy) := +by + have _ := fun i => (hf i).1 + have _ := fun i => (hf i).2 + have _ := fun i => (hg i).1 + have _ := fun i => (hg i).2 + unfold revCDeriv + funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + simp + sorry_proof + +theorem revCDeriv.pi_let_rule + (f : X → Y → ι → Z) (g : X → ι → Y) + (hf : ∀ i, HasAdjDiff K (fun xy : X×Y => f xy.1 xy.2 i)) + (hg : ∀ j, HasAdjDiff K (g · j)) + : (revCDeriv K fun x i => let y := g x i; f x y i) + = + fun x => + let ydg := revCDeriv K g x + let zdf := revCDeriv K (fun (xy : X×(ι →Y)) i => f xy.1 (xy.2 i) i) (x,ydg.1) + (zdf.1, + fun dz => + let dxy := zdf.2 dz + dxy.1 + ydg.2 dxy.2) := +by + have _ := fun i => (hf i).1 + have _ := fun i => (hf i).2 + have _ := fun i => (hg i).1 + have _ := fun i => (hg i).2 + unfold revCDeriv + funext _; ftrans -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + simp + sorry_proof + + +theorem revCDeriv.pi_prod_rule + (g₁ : W → ι → Y₁) (g₂ : W → ι → Y₂) + (hg₁ : ∀ i, HasAdjDiff K (g₁ · i)) (hg₂ : ∀ i, HasAdjDiff K (g₂ · i)) + : (revCDeriv K fun w i => (g₁ w i, g₂ w i)) + = + fun w => + let ydg₁ := revCDeriv K g₁ w + let ydg₂ := revCDeriv K g₂ w + (fun i => (ydg₁.1 i, ydg₂.1 i), + fun dy => + ydg₁.2 (fun i => (dy i).1) + ydg₂.2 (fun i => (dy i).2)) := +by + have _ := fun i => (hg₁ i).1 + have _ := fun i => (hg₁ i).2 + have _ := fun i => (hg₂ i).1 + have _ := fun i => (hg₂ i).2 + unfold revCDeriv + funext _; ftrans; -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + simp + sorry_proof + +theorem revCDeriv.pi_uncurry_rule {ι κ} [EnumType ι] [EnumType κ] + (f : X → ι → κ → Y) (hf : ∀ i j, HasAdjDiff K (f · i j)) + : (revCDeriv K fun x i j => f x i j) + = + fun x => + let ydf := revCDeriv K (fun x' (ij : ι×κ) => f x' ij.1 ij.2) x + (fun i j => ydf.1 (i,j), + fun dy => ydf.2 (fun ij : ι×κ => dy ij.1 ij.2)) := +by + have _ := fun i j => (hf i j).1 + have _ := fun i j => (hf i j).2 + unfold revCDeriv + funext _; ftrans; -- ftrans - semiAdjoint.pi_rule fails because of some universe issues + simp + sorry_proof + + +theorem revCDeriv.pi_curry_rule {ι κ} [EnumType ι] [EnumType κ] + (f : X → ι → κ → Y) (hf : ∀ i j, HasAdjDiff K (f · i j)) + : (revCDeriv K fun x (ij : ι×κ) => f x ij.1 ij.2) + = + fun x => + let ydf := revCDeriv K (fun x' i j => f x' i j) x + (fun ij => ydf.1 ij.1 ij.2, + fun dy => ydf.2 (fun i j => dy (i,j))) := +by + sorry_proof -- Register `revCDeriv` as function transformation ------------------------------