From a41b32902677a000dc80c5cc5bdea4eeb3e9fa3f Mon Sep 17 00:00:00 2001 From: lecopivo Date: Mon, 13 Nov 2023 10:58:50 -0500 Subject: [PATCH] ite and dite derivation rules --- .../Core/FunctionPropositions/HasAdjDiff.lean | 25 +++++++++ .../IsDifferentiable.lean | 24 +++++++++ .../Core/FunctionTransformations/CDeriv.lean | 31 +++++++++++ .../FunctionTransformations/FwdCDeriv.lean | 30 +++++++++++ .../FunctionTransformations/RevCDeriv.lean | 29 ++++++++++ SciLean/Core/Monads/FwdDerivMonad.lean | 53 ++++++++++++++++++ SciLean/Core/Monads/RevDerivMonad.lean | 54 +++++++++++++++++++ 7 files changed, 246 insertions(+) diff --git a/SciLean/Core/FunctionPropositions/HasAdjDiff.lean b/SciLean/Core/FunctionPropositions/HasAdjDiff.lean index d460e2dc..08961139 100644 --- a/SciLean/Core/FunctionPropositions/HasAdjDiff.lean +++ b/SciLean/Core/FunctionPropositions/HasAdjDiff.lean @@ -369,6 +369,31 @@ by have := fun i => (hf i).1 constructor; fprop; ftrans; fprop +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fprop] +theorem ite.arg_te.HasAdjDiff_rule + (c : Prop) [dec : Decidable c] (t e : X → Y) + (ht : HasAdjDiff K t) (he : HasAdjDiff K e) + : HasAdjDiff K (fun x => ite c (t x) (e x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + +@[fprop] +theorem dite.arg_te.HasAdjDiff_rule + (c : Prop) [dec : Decidable c] + (t : c → X → Y) (e : ¬c → X → Y) + (ht : ∀ h, HasAdjDiff K (t h)) (he : ∀ h, HasAdjDiff K (e h)) + : HasAdjDiff K (fun x => dite c (t · x) (e · x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + + -------------------------------------------------------------------------------- diff --git a/SciLean/Core/FunctionPropositions/IsDifferentiable.lean b/SciLean/Core/FunctionPropositions/IsDifferentiable.lean index f68cc3ed..e693c02c 100644 --- a/SciLean/Core/FunctionPropositions/IsDifferentiable.lean +++ b/SciLean/Core/FunctionPropositions/IsDifferentiable.lean @@ -328,6 +328,30 @@ theorem SciLean.EnumType.sum.arg_f.IsDifferentiable_rule by sorry_proof +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fprop] +theorem ite.arg_te.IsDifferentiable_rule + (c : Prop) [dec : Decidable c] (t e : X → Y) + (ht : IsDifferentiable K t) (he : IsDifferentiable K e) + : IsDifferentiable K (fun x => ite c (t x) (e x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + +@[fprop] +theorem dite.arg_te.IsDifferentiable_rule + (c : Prop) [dec : Decidable c] + (t : c → X → Y) (e : ¬c → X → Y) + (ht : ∀ h, IsDifferentiable K (t h)) (he : ∀ h, IsDifferentiable K (e h)) + : IsDifferentiable K (fun x => dite c (t · x) (e · x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + -------------------------------------------------------------------------------- diff --git a/SciLean/Core/FunctionTransformations/CDeriv.lean b/SciLean/Core/FunctionTransformations/CDeriv.lean index 1a75c16d..35b83ee8 100644 --- a/SciLean/Core/FunctionTransformations/CDeriv.lean +++ b/SciLean/Core/FunctionTransformations/CDeriv.lean @@ -632,6 +632,37 @@ by funext x; apply SciLean.EnumType.sum.arg_f.cderiv_rule_at f x (fun i => hf i x) +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[ftrans] +theorem ite.arg_te.cderiv_rule + (c : Prop) [dec : Decidable c] (t e : X → Y) + : cderiv K (fun x => ite c (t x) (e x)) + = + fun y => + ite c (cderiv K t y) (cderiv K e y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] + +@[ftrans] +theorem dite.arg_te.cderiv_rule + (c : Prop) [dec : Decidable c] + (t : c → X → Y) (e : ¬c → X → Y) + : cderiv K (fun x => dite c (t · x) (e · x)) + = + fun y => + dite c (fun p => cderiv K (t p) y) + (fun p => cderiv K (e p) y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] + + + -------------------------------------------------------------------------------- section InnerProductSpace diff --git a/SciLean/Core/FunctionTransformations/FwdCDeriv.lean b/SciLean/Core/FunctionTransformations/FwdCDeriv.lean index 56385cce..958b182d 100644 --- a/SciLean/Core/FunctionTransformations/FwdCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/FwdCDeriv.lean @@ -556,6 +556,36 @@ by unfold fwdCDeriv; ftrans +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[ftrans] +theorem ite.arg_te.fwdCDeriv_rule + (c : Prop) [dec : Decidable c] (t e : X → Y) + : fwdCDeriv K (fun x => ite c (t x) (e x)) + = + fun y => + ite c (fwdCDeriv K t y) (fwdCDeriv K e y) := +by + induction dec + case isTrue h => ext y; simp[h]; simp[h] + case isFalse h => ext y; simp[h]; simp[h] + +@[ftrans] +theorem dite.arg_te.fwdCDeriv_rule + (c : Prop) [dec : Decidable c] + (t : c → X → Y) (e : ¬c → X → Y) + : fwdCDeriv K (fun x => dite c (t · x) (e · x)) + = + fun y => + dite c (fun p => fwdCDeriv K (t p) y) + (fun p => fwdCDeriv K (e p) y) := +by + induction dec + case isTrue h => ext y; simp[h]; simp[h] + case isFalse h => ext y; simp[h]; simp[h] + + -------------------------------------------------------------------------------- section InnerProductSpace diff --git a/SciLean/Core/FunctionTransformations/RevCDeriv.lean b/SciLean/Core/FunctionTransformations/RevCDeriv.lean index 7bcd9db0..dd335ced 100644 --- a/SciLean/Core/FunctionTransformations/RevCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevCDeriv.lean @@ -1148,6 +1148,35 @@ by sorry_proof +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[ftrans] +theorem ite.arg_te.revCDeriv_rule + (c : Prop) [dec : Decidable c] (t e : X → Y) + : revCDeriv K (fun x => ite c (t x) (e x)) + = + fun y => + ite c (revCDeriv K t y) (revCDeriv K e y) := +by + induction dec + case isTrue h => ext y <;> simp[h] + case isFalse h => ext y <;> simp[h] + +@[ftrans] +theorem dite.arg_te.revCDeriv_rule + (c : Prop) [dec : Decidable c] + (t : c → X → Y) (e : ¬c → X → Y) + : revCDeriv K (fun x => dite c (t · x) (e · x)) + = + fun y => + dite c (fun p => revCDeriv K (t p) y) + (fun p => revCDeriv K (e p) y) := +by + induction dec + case isTrue h => ext y <;> simp[h] + case isFalse h => ext y <;> simp[h] + -------------------------------------------------------------------------------- diff --git a/SciLean/Core/Monads/FwdDerivMonad.lean b/SciLean/Core/Monads/FwdDerivMonad.lean index 6ed189f5..c8ec1a90 100644 --- a/SciLean/Core/Monads/FwdDerivMonad.lean +++ b/SciLean/Core/Monads/FwdDerivMonad.lean @@ -574,3 +574,56 @@ by rw [FwdDerivMonad.fwdDerivM_bind _ _ hf hg] simp [FwdDerivMonad.fwdDerivM_pair a0 ha0] + + +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fprop] +theorem ite.arg_te.IsDifferentiableM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + (ht : IsDifferentiableM K t) (he : IsDifferentiableM K e) + : IsDifferentiableM K (fun x => ite c (t x) (e x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + + +@[ftrans] +theorem ite.arg_te.fwdDerivM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + : fwdDerivM K (fun x => ite c (t x) (e x)) + = + fun y => + ite c (fwdDerivM K t y) (fwdDerivM K e y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] + + +@[fprop] +theorem dite.arg_te.IsDifferentiableM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + (ht : ∀ h, IsDifferentiableM K (t h)) (he : ∀ h, IsDifferentiableM K (e h)) + : IsDifferentiableM K (fun x => dite c (fun h => t h x) (fun h => e h x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + + +@[ftrans] +theorem dite.arg_te.fwdDerivM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + : fwdDerivM K (fun x => dite c (fun h => t h x) (fun h => e h x)) + = + fun y => + dite c (fun h => fwdDerivM K (t h) y) (fun h => fwdDerivM K (e h) y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] diff --git a/SciLean/Core/Monads/RevDerivMonad.lean b/SciLean/Core/Monads/RevDerivMonad.lean index aa720d29..5d227b6c 100644 --- a/SciLean/Core/Monads/RevDerivMonad.lean +++ b/SciLean/Core/Monads/RevDerivMonad.lean @@ -600,3 +600,57 @@ by rw [RevDerivMonad.revDerivM_bind _ _ hf hg] simp [RevDerivMonad.revDerivM_pair a0 ha0] + + + +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fprop] +theorem ite.arg_te.HasAdjDiffM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + (ht : HasAdjDiffM K t) (he : HasAdjDiffM K e) + : HasAdjDiffM K (fun x => ite c (t x) (e x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + + +@[ftrans] +theorem ite.arg_te.revDerivM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + : revDerivM K (fun x => ite c (t x) (e x)) + = + fun y => + ite c (revDerivM K t y) (revDerivM K e y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] + + +@[fprop] +theorem dite.arg_te.HasAdjDiffM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + (ht : ∀ h, HasAdjDiffM K (t h)) (he : ∀ h, HasAdjDiffM K (e h)) + : HasAdjDiffM K (fun x => dite c (fun h => t h x) (fun h => e h x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + + +@[ftrans] +theorem dite.arg_te.revDerivM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + : revDerivM K (fun x => dite c (fun h => t h x) (fun h => e h x)) + = + fun y => + dite c (fun h => revDerivM K (t h) y) (fun h => revDerivM K (e h) y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h]