diff --git a/SciLean.lean b/SciLean.lean index 3927dd3d..51a4c9a2 100644 --- a/SciLean.lean +++ b/SciLean.lean @@ -16,9 +16,10 @@ import SciLean.Analysis.Calculus.FwdFDeriv -- import SciLean.Analysis.Calculus.HasParamDerivWithDisc.HasParamFwdFDerivWithDisc -- import SciLean.Analysis.Calculus.HasParamDerivWithDisc.HasParamRevFDerivWithDisc import SciLean.Analysis.Calculus.Jacobian -import SciLean.Analysis.Calculus.Monad.FwdCDerivMonad +import SciLean.Analysis.Calculus.Monad.DifferentiableMonad +import SciLean.Analysis.Calculus.Monad.FwdFDerivMonad import SciLean.Analysis.Calculus.Monad.Id -import SciLean.Analysis.Calculus.Monad.RevCDerivMonad +import SciLean.Analysis.Calculus.Monad.RevFDerivMonad import SciLean.Analysis.Calculus.Monad.StateT import SciLean.Analysis.Calculus.Notation.Deriv import SciLean.Analysis.Calculus.Notation.FwdDeriv diff --git a/SciLean/Analysis/AdjointSpace/Basic.lean b/SciLean/Analysis/AdjointSpace/Basic.lean index 4813c542..932231e5 100644 --- a/SciLean/Analysis/AdjointSpace/Basic.lean +++ b/SciLean/Analysis/AdjointSpace/Basic.lean @@ -263,6 +263,18 @@ instance : AdjointSpace 𝕜 𝕜 where add_left := by simp[add_mul] smul_left := by simp[mul_assoc] +instance : Inner 𝕜 Unit where + inner _ _ := 0 + +instance : AdjointSpace 𝕜 Unit where + inner_top_equiv_norm := by + apply Exists.intro 1 + apply Exists.intro 1 + simp[Inner.inner] + conj_symm := by simp[Inner.inner] + add_left := by simp[Inner.inner] + smul_left := by simp[Inner.inner] + instance : AdjointSpace 𝕜 (X×Y) where inner := fun (x,y) (x',y') => ⟪x,x'⟫_𝕜 + ⟪y,y'⟫_𝕜 inner_top_equiv_norm := by diff --git a/SciLean/Analysis/Calculus/Monad/DifferentiableMonad.lean b/SciLean/Analysis/Calculus/Monad/DifferentiableMonad.lean new file mode 100644 index 00000000..97dee40b --- /dev/null +++ b/SciLean/Analysis/Calculus/Monad/DifferentiableMonad.lean @@ -0,0 +1,208 @@ +import SciLean.Analysis.Calculus.FwdFDeriv + +namespace SciLean + + +/-- `DifferentiableMonad K m` states that the monad `m` has the 
notion of differentiability. +The rough idea is that if the monad `m` stores some state `S` then a function `(f : X → m Y)` +should also be differentiable w.r.t. the state `S`. + +This class provides the proposition `DifferentiableM K f`, which is a monadic generalization of +differentiability. + +For `StateM S` the `DifferentiableM` is: +``` + DifferentiableM K f + = + Differentiable K (fun (x,s) => f x s) +``` +-/ +class DifferentiableMonad (K : Type) [RCLike K] (m : Type → Type) [Monad m] where + /-- Differentiability of monadic functions. + + For state monad, `m = StateM S`, this predicate says that the function is also differentiable + w.r.t. the state variable. + ``` + DifferentiableM K f + = + Differentiable K (fun (x,s) => f x s) + ``` + -/ + DifferentiableM {X Y : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + (f : X → m Y) : Prop + + /-- A differentiable pure function is monadic differentiable. -/ + DifferentiableM_pure {X Y : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + (f : X → Y) (hf : Differentiable K f) : + DifferentiableM (fun x : X => pure (f x)) + + /-- Composition of monadic differentiable functions is monadic differentiable. -/ + DifferentiableM_bind {X Y Z : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + [NormedAddCommGroup Z] [NormedSpace K Z] + (f : Y → m Z) (g : X → m Y) + (hf : DifferentiableM f) (hg : DifferentiableM g) : + DifferentiableM (fun x => g x >>= f) + + /-- Theorem allowing us to differentiate let bindings. + + Note: The role of this is still not completely clear to us. Is this really independent of the + previous two requirements? 
-/ + DifferentiableM_pair {X Y : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + (f : X → m Y) (hf : DifferentiableM f) : + DifferentiableM (fun x => do let y ← f x; pure (x,y)) + + +export DifferentiableMonad (DifferentiableM) + +attribute [fun_prop] DifferentiableM + +set_option deprecated.oldSectionVars true + +variable + (K : Type) [RCLike K] + {m : Type → Type} [Monad m] [DifferentiableMonad K m] + [LawfulMonad m] + {X : Type} [NormedAddCommGroup X] [NormedSpace K X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [NormedSpace K Z] + +open DifferentiableMonad + +/-- Monadic differentiable value. For example, in case of state monad the value `x : StateM S X` +is a function in `S` and it makes sense to ask about differentiability. -/ +def DifferentiableValM (x : m X) : Prop := + DifferentiableM K (fun _ : Unit => x) + + +-------------------------------------------------------------------------------- +-- DifferentiableM ----------------------------------------------------------- +-------------------------------------------------------------------------------- +namespace DifferentiableM + +@[fun_prop] +theorem pure_rule + : DifferentiableM (m:=m) K (fun x : X => pure x) := +by + apply DifferentiableM_pure + fun_prop + +@[fun_prop] +theorem const_rule (y : m Y) (hy : DifferentiableValM K y) + : DifferentiableM K (fun _ : X => y) := +by + have h : (fun _ : X => y) + = + fun _ : X => pure () >>= fun _ => y := by simp + rw[h] + apply DifferentiableM_bind + apply hy + apply DifferentiableM_pure + fun_prop + +@[fun_prop] +theorem comp_rule + (f : Y → m Z) (g : X → Y) + (hf : DifferentiableM K f) (hg : Differentiable K g) + : DifferentiableM K (fun x => f (g x)) := +by + rw[show ((fun x => f (g x)) + = + fun x => pure (g x) >>= f) by simp] + apply DifferentiableM_bind _ _ hf + apply DifferentiableM_pure g hg + +end DifferentiableM + +end SciLean + + + 
+-------------------------------------------------------------------------------- + +section CoreFunctionProperties + +open SciLean + +set_option deprecated.oldSectionVars true + +variable + (K : Type) [RCLike K] + {m } [Monad m] [DifferentiableMonad K m] + [LawfulMonad m] + {X : Type} [NormedAddCommGroup X] [NormedSpace K X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [NormedSpace K Z] + {E : ι → Type} [∀ i, Vec K (E i)] + + +-------------------------------------------------------------------------------- +-- Pure.pure ------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_prop] +theorem Pure.pure.arg_a0.DifferentiableM_rule + (a0 : X → Y) + (ha0 : Differentiable K a0) + : DifferentiableM K (fun x => Pure.pure (f:=m) (a0 x)) := +by + apply DifferentiableMonad.DifferentiableM_pure a0 ha0 + +@[simp, simp_core] +theorem Pure.pure.arg.DifferentiableValM_rule (x : X) + : DifferentiableValM K (pure (f:=m) x) := +by + unfold DifferentiableValM + apply DifferentiableMonad.DifferentiableM_pure + fun_prop + + +-------------------------------------------------------------------------------- +-- Bind.bind ------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_prop] +theorem Bind.bind.arg_a0a1.DifferentiableM_rule + (a0 : X → m Y) (a1 : X → Y → m Z) + (ha0 : DifferentiableM K a0) (ha1 : DifferentiableM K (fun (xy : X×Y) => a1 xy.1 xy.2)) + : DifferentiableM K (fun x => Bind.bind (a0 x) (a1 x)) := +by + let g := fun x => do + let y ← a0 x + pure (x,y) + let f := fun xy : X×Y => a1 xy.1 xy.2 + + rw[show (fun x => Bind.bind (a0 x) (a1 x)) + = + fun x => g x >>= f by simp[f,g]] + + have hg : DifferentiableM K (fun x => do let y ← a0 x; pure (x,y)) := + by apply DifferentiableMonad.DifferentiableM_pair a0 ha0 + have hf : 
DifferentiableM K f := by simp[f]; fun_prop + + apply DifferentiableMonad.DifferentiableM_bind _ _ hf hg + + +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_prop] +theorem ite.arg_te.DifferentiableM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + (ht : DifferentiableM K t) (he : DifferentiableM K e) + : DifferentiableM K (fun x => ite c (t x) (e x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] + + +@[fun_prop] +theorem dite.arg_te.DifferentiableM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + (ht : ∀ h, DifferentiableM K (t h)) (he : ∀ h, DifferentiableM K (e h)) + : DifferentiableM K (fun x => dite c (fun h => t h x) (fun h => e h x)) := +by + induction dec + case isTrue h => simp[ht,h] + case isFalse h => simp[he,h] diff --git a/SciLean/Analysis/Calculus/Monad/FwdCDerivMonad.lean b/SciLean/Analysis/Calculus/Monad/FwdCDerivMonad.lean deleted file mode 100644 index f35c246b..00000000 --- a/SciLean/Analysis/Calculus/Monad/FwdCDerivMonad.lean +++ /dev/null @@ -1,347 +0,0 @@ -import SciLean.Analysis.Calculus.FwdCDeriv - -namespace SciLean - -set_option linter.unusedVariables false in -class FwdCDerivMonad (K : Type) [RCLike K] (m : Type → Type) (m' : outParam $ Type → Type) [Monad m] [Monad m'] where - fwdCDerivM {X : Type} {Y : Type} [Vec K X] [Vec K Y] : ∀ (f : X → m Y) (x dx : X), m' (Y × Y) - - CDifferentiableM {X : Type} {Y : Type} [Vec K X] [Vec K Y] - : ∀ (f : X → m Y), Prop - - fwdCDerivM_pure {X Y : Type} [Vec K X] [Vec K Y] (f : X → Y) (hf : CDifferentiable K f) - : fwdCDerivM (fun x => pure (f:=m) (f x)) = fun x dx => pure (f:=m') (fwdCDeriv K f x dx) - fwdCDerivM_bind - {X Y Z : Type} [Vec K X] [Vec K Y] [Vec K Z] - (f : Y → m Z) (g : X → m Y) (hf : CDifferentiableM f) (hg : CDifferentiableM g) - : fwdCDerivM (fun x => g x >>= f) - = - fun x dx => do - let 
ydy ← fwdCDerivM g x dx - fwdCDerivM f ydy.1 ydy.2 - fwdCDerivM_pair {X : Type} {Y : Type} [Vec K X] [Vec K Y] -- is this really necessary? - (f : X → m Y) (hf : CDifferentiableM f) - : fwdCDerivM (fun x => do let y ← f x; pure (x,y)) - = - (fun x dx => do - let ydy ← fwdCDerivM f x dx - pure ((x,ydy.1),(dx,ydy.2))) - - - CDifferentiableM_pure {X : Type} {Y : Type} [Vec K X] [Vec K Y] - (f : X → Y) (hf : CDifferentiable K f) - : CDifferentiableM (fun x : X => pure (f x)) - CDifferentiableM_bind {X Y Z: Type} [Vec K X] [Vec K Y] [Vec K Z] - (f : Y → m Z) (g : X → m Y) - (hf : CDifferentiableM f) (hg : CDifferentiableM g) - : CDifferentiableM (fun x => g x >>= f) - CDifferentiableM_pair {X : Type} {Y : Type} [Vec K X] [Vec K Y] -- is this really necessary? - (f : X → m Y) (hf : CDifferentiableM f) - : CDifferentiableM (fun x => do let y ← f x; pure (x,y)) - -export FwdCDerivMonad (fwdCDerivM CDifferentiableM) - -attribute [fun_prop] CDifferentiableM -attribute [fun_trans] fwdCDerivM - -set_option deprecated.oldSectionVars true - -variable - (K : Type _) [RCLike K] - {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] [FwdCDerivMonad K m m'] - [LawfulMonad m] [LawfulMonad m'] - {X : Type} [Vec K X] - {Y : Type} [Vec K Y] - {Z : Type} [Vec K Z] - {E : ι → Type} [∀ i, Vec K (E i)] - -open FwdCDerivMonad - - -def fwdCDerivValM (x : m X) : m' (X × X) := do - fwdCDerivM K (fun _ : Unit => x) () () - -def CDifferentiableValM (x : m X) : Prop := - CDifferentiableM K (fun _ : Unit => x) - - --------------------------------------------------------------------------------- --- CDifferentiableM ----------------------------------------------------------- --------------------------------------------------------------------------------- -namespace CDifferentiableM - -@[fun_prop] -theorem pure_rule - : CDifferentiableM (m:=m) K (fun x : X => pure x) := -by - apply CDifferentiableM_pure - fun_prop - -@[fun_prop] -theorem const_rule (y : m Y) (hy : CDifferentiableValM 
K y) - : CDifferentiableM K (fun _ : X => y) := -by - have h : (fun _ : X => y) - = - fun _ : X => pure () >>= fun _ => y := by simp - rw[h] - apply CDifferentiableM_bind - apply hy - apply CDifferentiableM_pure - fun_prop - -@[fun_prop] -theorem comp_rule - (f : Y → m Z) (g : X → Y) - (hf : CDifferentiableM K f) (hg : CDifferentiable K g) - : CDifferentiableM K (fun x => f (g x)) := -by - rw[show ((fun x => f (g x)) - = - fun x => pure (g x) >>= f) by simp] - apply CDifferentiableM_bind _ _ hf - apply CDifferentiableM_pure g hg - -end CDifferentiableM - - --------------------------------------------------------------------------------- --- fwdCDerivM ------------------------------------------------------------------- --------------------------------------------------------------------------------- -namespace fwdCDerivM - -@[fun_trans] -theorem pure_rule - : fwdCDerivM (m:=m) K (fun x : X => pure x) = fun x dx => pure (x, dx) := -by - rw [fwdCDerivM_pure _ (by fun_prop)] - fun_trans - -@[fun_trans] -theorem const_rule (y : m Y) (hy : CDifferentiableValM K y) - : fwdCDerivM K (fun _ : X => y) = fun _ _ => fwdCDerivValM K y := -by - have h : (fun _ : X => y) - = - fun _ : X => pure () >>= fun _ => y := by simp - rw[h] - rw[fwdCDerivM_bind] - rw[fwdCDerivM_pure] - fun_trans - simp [fwdCDerivValM] - fun_prop - apply hy - apply CDifferentiableM_pure; fun_prop - -@[fun_trans] -theorem comp_rule - (f : Y → m Z) (g : X → Y) - (hf : CDifferentiableM K f) (hg : CDifferentiable K g) - : fwdCDerivM K (fun x => f (g x)) - = - fun x dx => - let ydy := fwdCDeriv K g x dx - fwdCDerivM K f ydy.1 ydy.2 := -by - conv => - lhs - rw[show ((fun x => f (g x)) - = - fun x => pure (g x) >>= f) by simp] - rw[FwdCDerivMonad.fwdCDerivM_bind f (fun x => pure (g x)) - hf (FwdCDerivMonad.CDifferentiableM_pure _ hg)] - simp[FwdCDerivMonad.fwdCDerivM_pure g hg] - -@[fun_trans] -theorem let_rule - (f : X → Y → m Z) (g : X → Y) - (hf : CDifferentiableM K (fun xy : X×Y => f xy.1 xy.2)) (hg : 
CDifferentiable K g) - : fwdCDerivM K (fun x => let y := g x; f x y) - = - fun x dx => - let ydy := fwdCDeriv K g x dx - fwdCDerivM K (fun xy : X×Y => f xy.1 xy.2) (x,ydy.1) (dx,ydy.2) := -by - let f' := (fun xy : X×Y => f xy.1 xy.2) - let g' := (fun x => (x,g x)) - have hg' : CDifferentiable K g' := by rw[show g' = (fun x => (x,g x)) by rfl]; fun_prop - conv => - lhs - rw[show ((fun x => f x (g x)) - = - fun x => pure (g' x) >>= f') by simp] - rw[fwdCDerivM_bind f' (fun x => pure (g' x)) hf (CDifferentiableM_pure g' hg')] - simp[fwdCDerivM_pure (K:=K) g' hg'] - -- fun_trans - -- simp - sorry_proof - -end fwdCDerivM - -end SciLean - - - --------------------------------------------------------------------------------- - -section CoreFunctionProperties - -open SciLean - -set_option deprecated.oldSectionVars true - -variable - (K : Type _) [RCLike K] - {m m'} [Monad m] [Monad m'] [FwdCDerivMonad K m m'] - [LawfulMonad m] [LawfulMonad m'] - {X : Type} [Vec K X] - {Y : Type} [Vec K Y] - {Z : Type} [Vec K Z] - {E : ι → Type} [∀ i, Vec K (E i)] - - --------------------------------------------------------------------------------- --- Pure.pure ------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_prop] -theorem Pure.pure.arg_a0.CDifferentiableM_rule - (a0 : X → Y) - (ha0 : CDifferentiable K a0) - : CDifferentiableM K (fun x => Pure.pure (f:=m) (a0 x)) := -by - apply FwdCDerivMonad.CDifferentiableM_pure a0 ha0 - - -@[fun_trans] -theorem Pure.pure.arg_a0.fwdCDerivM_rule - (a0 : X → Y) - (ha0 : CDifferentiable K a0) - : fwdCDerivM K (fun x => pure (f:=m) (a0 x)) - = - fun x dx => pure (fwdCDeriv K a0 x dx) := -by - apply FwdCDerivMonad.fwdCDerivM_pure a0 ha0 - -@[simp, simp_core] -theorem Pure.pure.arg.CDifferentiableValM_rule (x : X) - : CDifferentiableValM K (pure (f:=m) x) := -by - unfold CDifferentiableValM - apply FwdCDerivMonad.CDifferentiableM_pure - fun_prop - 
-@[simp, simp_core] -theorem Pure.pure.arg.fwdCDerivValM_rule (x : X) - : fwdCDerivValM K (pure (f:=m) x) - = - pure (x,0) := -by - unfold fwdCDerivValM; rw[FwdCDerivMonad.fwdCDerivM_pure]; fun_trans; fun_prop - - --------------------------------------------------------------------------------- --- Bind.bind ------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_prop] -theorem Bind.bind.arg_a0a1.CDifferentiableM_rule - (a0 : X → m Y) (a1 : X → Y → m Z) - (ha0 : CDifferentiableM K a0) (ha1 : CDifferentiableM K (fun (xy : X×Y) => a1 xy.1 xy.2)) - : CDifferentiableM K (fun x => Bind.bind (a0 x) (a1 x)) := -by - let g := fun x => do - let y ← a0 x - pure (x,y) - let f := fun xy : X×Y => a1 xy.1 xy.2 - - rw[show (fun x => Bind.bind (a0 x) (a1 x)) - = - fun x => g x >>= f by simp[f,g]] - - have hg : CDifferentiableM K (fun x => do let y ← a0 x; pure (x,y)) := - by apply FwdCDerivMonad.CDifferentiableM_pair a0 ha0 - have hf : CDifferentiableM K f := by simp[f]; fun_prop - - apply FwdCDerivMonad.CDifferentiableM_bind _ _ hf hg - - - -@[fun_trans] -theorem Bind.bind.arg_a0a1.fwdCDerivM_rule - (a0 : X → m Y) (a1 : X → Y → m Z) - (ha0 : CDifferentiableM K a0) (ha1 : CDifferentiableM K (fun (xy : X×Y) => a1 xy.1 xy.2)) - : (fwdCDerivM K (fun x => Bind.bind (a0 x) (a1 x))) - = - (fun x dx => do - let ydy ← fwdCDerivM K a0 x dx - fwdCDerivM K (fun (xy : X×Y) => a1 xy.1 xy.2) (x, ydy.1) (dx, ydy.2)) := -by - let g := fun x => do - let y ← a0 x - pure (x,y) - let f := fun xy : X×Y => a1 xy.1 xy.2 - - rw[show (fun x => Bind.bind (a0 x) (a1 x)) - = - fun x => g x >>= f by simp[f,g]] - - have hg : CDifferentiableM K (fun x => do let y ← a0 x; pure (x,y)) := - by apply FwdCDerivMonad.CDifferentiableM_pair a0 ha0 - have hf : CDifferentiableM K f := by simp [f]; fun_prop - - rw [FwdCDerivMonad.fwdCDerivM_bind _ _ hf hg] - simp [FwdCDerivMonad.fwdCDerivM_pair a0 ha0] - - --- d/ite 
----------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_prop] -theorem ite.arg_te.CDifferentiableM_rule - (c : Prop) [dec : Decidable c] (t e : X → m Y) - (ht : CDifferentiableM K t) (he : CDifferentiableM K e) - : CDifferentiableM K (fun x => ite c (t x) (e x)) := -by - induction dec - case isTrue h => simp[ht,h] - case isFalse h => simp[he,h] - - -@[fun_trans] -theorem ite.arg_te.fwdCDerivM_rule - (c : Prop) [dec : Decidable c] (t e : X → m Y) - : fwdCDerivM K (fun x => ite c (t x) (e x)) - = - fun y => - ite c (fwdCDerivM K t y) (fwdCDerivM K e y) := -by - induction dec - case isTrue h => ext y; simp[h] - case isFalse h => ext y; simp[h] - - -@[fun_prop] -theorem dite.arg_te.CDifferentiableM_rule - (c : Prop) [dec : Decidable c] - (t : c → X → m Y) (e : ¬c → X → m Y) - (ht : ∀ h, CDifferentiableM K (t h)) (he : ∀ h, CDifferentiableM K (e h)) - : CDifferentiableM K (fun x => dite c (fun h => t h x) (fun h => e h x)) := -by - induction dec - case isTrue h => simp[ht,h] - case isFalse h => simp[he,h] - - -@[fun_trans] -theorem dite.arg_te.fwdCDerivM_rule - (c : Prop) [dec : Decidable c] - (t : c → X → m Y) (e : ¬c → X → m Y) - : fwdCDerivM K (fun x => dite c (fun h => t h x) (fun h => e h x)) - = - fun y => - dite c (fun h => fwdCDerivM K (t h) y) (fun h => fwdCDerivM K (e h) y) := -by - induction dec - case isTrue h => ext y; simp[h] - case isFalse h => ext y; simp[h] diff --git a/SciLean/Analysis/Calculus/Monad/FwdFDerivMonad.lean b/SciLean/Analysis/Calculus/Monad/FwdFDerivMonad.lean new file mode 100644 index 00000000..806196d7 --- /dev/null +++ b/SciLean/Analysis/Calculus/Monad/FwdFDerivMonad.lean @@ -0,0 +1,275 @@ +import SciLean.Analysis.Calculus.FwdFDeriv +import SciLean.Analysis.Calculus.Monad.DifferentiableMonad + +namespace SciLean + + +/-- `FwdFDerivMonad K m m'` states that the monad `m'` allows us to compute forward mode derivative +of functions 
in the monad `m`. The rough idea is that if the monad `m` stores some state `S` then +the monad `m'` should store `S×S` corresponding to the state and its derivative. Concretely, for +`m = StateM S` we have `m' = StateM (S×S)`. + +Together with `DifferentiableMonad`, this provides two notions for a function `(f : X → m Y)`: + - `fwdFDerivM K f` is a generalization of the forward mode derivative of `f` + - `DifferentiableM K f` is a generalization of differentiability of `f` + +For `StateM S` the `fwdFDerivM` and `DifferentiableM` are: +``` + fwdFDerivM K f + = + fun x dx (s,ds) => + let ((y,s),(dy,ds)) := fwdFDeriv K (fun (x,s) => f x s) (x,s) (dx,ds) + ((y,dy),(s,ds)) + + DifferentiableM K f + = + Differentiable K (fun (x,s) => f x s) +``` +In short, `fwdFDerivM` also differentiates w.r.t. the state variable and `DifferentiableM` checks +that a function is also differentiable w.r.t. the state variable. + +The nice property of this general definition is that it generalizes to the `StateT` transformer. +Therefore we can nest state monads and still differentiate them. +-/ +class FwdFDerivMonad (K : Type) [RCLike K] (m : Type → Type) (m' : outParam $ Type → Type) [Monad m] [Monad m'] [DifferentiableMonad K m] where + /-- Forward mode derivative for monadic functions. + + For state monad, `m = StateM S`, this derivative also differentiates w.r.t. the state variable + ``` + fwdFDerivM K f + = + fun x dx (s,ds) => + let ((y,s),(dy,ds)) := fwdFDeriv K (fun (x,s) => f x s) (x,s) (dx,ds) + ((y,dy),(s,ds)) + ``` + -/ + fwdFDerivM {X Y : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + (f : X → m Y) (x dx : X) : m' (Y × Y) + + /-- Monadic derivative of a pure function is the normal derivative. 
-/ + fwdFDerivM_pure {X Y : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + (f : X → Y) (hf : Differentiable K f) : + fwdFDerivM (fun x => pure (f:=m) (f x)) = fun x dx => pure (f:=m') (fwdFDeriv K f x dx) + + /-- Monadic chain rule. -/ + fwdFDerivM_bind + {X Y Z : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + [NormedAddCommGroup Z] [NormedSpace K Z] + (f : Y → m Z) (g : X → m Y) (hf : DifferentiableM K f) (hg : DifferentiableM K g) : + fwdFDerivM (fun x => g x >>= f) + = + fun x dx => do + let ydy ← fwdFDerivM g x dx + fwdFDerivM f ydy.1 ydy.2 + + /-- Theorem allowing us to differentiate let bindings. + + Note: The role of this is still not completely clear to us. Is this really independent of the + previous two requirements? -/ + fwdFDerivM_pair {X Y : Type} [NormedAddCommGroup X] [NormedSpace K X] [NormedAddCommGroup Y] [NormedSpace K Y] + (f : X → m Y) (hf : DifferentiableM K f) : + fwdFDerivM (fun x => do let y ← f x; pure (x,y)) + = + (fun x dx => do + let ydy ← fwdFDerivM f x dx + pure ((x,ydy.1),(dx,ydy.2))) + + +export FwdFDerivMonad (fwdFDerivM) + +attribute [fun_trans] fwdFDerivM + +set_option deprecated.oldSectionVars true + +variable + (K : Type) [RCLike K] + {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] + [DifferentiableMonad K m] [FwdFDerivMonad K m m'] [LawfulMonad m] [LawfulMonad m'] + {X : Type} [NormedAddCommGroup X] [NormedSpace K X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [NormedSpace K Z] + +open FwdFDerivMonad + +/-- Monadic derivative of a value. For example, in case of state monad the value `x : StateM S X` +is a function in `S` and it makes sense to take derivative of it. 
-/ +def fwdFDerivValM (x : m X) : m' (X × X) := do + fwdFDerivM K (fun _ : Unit => x) () () + +-------------------------------------------------------------------------------- +-- fwdFDerivM ------------------------------------------------------------------- +-------------------------------------------------------------------------------- +namespace fwdFDerivM + +open DifferentiableMonad + +@[fun_trans] +theorem pure_rule + : fwdFDerivM (m:=m) K (fun x : X => pure x) = fun x dx => pure (x, dx) := +by + rw [fwdFDerivM_pure _ (by fun_prop)] + fun_trans + +@[fun_trans] +theorem const_rule (y : m Y) (hy : DifferentiableValM K y) + : fwdFDerivM K (fun _ : X => y) = fun _ _ => fwdFDerivValM K y := +by + have h : (fun _ : X => y) + = + fun _ : X => pure () >>= fun _ => y := by simp + rw[h] + rw[fwdFDerivM_bind] + rw[fwdFDerivM_pure] + fun_trans + simp [fwdFDerivValM] + fun_prop + apply hy + apply DifferentiableM_pure; fun_prop + +@[fun_trans] +theorem comp_rule + (f : Y → m Z) (g : X → Y) + (hf : DifferentiableM K f) (hg : Differentiable K g) + : fwdFDerivM K (fun x => f (g x)) + = + fun x dx => + let ydy := fwdFDeriv K g x dx + fwdFDerivM K f ydy.1 ydy.2 := +by + conv => + lhs + rw[show ((fun x => f (g x)) + = + fun x => pure (g x) >>= f) by simp] + rw[FwdFDerivMonad.fwdFDerivM_bind f (fun x => pure (g x)) + hf (DifferentiableM_pure _ hg)] + simp[FwdFDerivMonad.fwdFDerivM_pure g hg] + +@[fun_trans] +theorem let_rule + (f : X → Y → m Z) (g : X → Y) + (hf : DifferentiableM K (fun xy : X×Y => f xy.1 xy.2)) (hg : Differentiable K g) + : fwdFDerivM K (fun x => let y := g x; f x y) + = + fun x dx => + let ydy := fwdFDeriv K g x dx + fwdFDerivM K (fun xy : X×Y => f xy.1 xy.2) (x,ydy.1) (dx,ydy.2) := +by + let f' := (fun xy : X×Y => f xy.1 xy.2) + let g' := (fun x => (x,g x)) + have hg' : Differentiable K g' := by rw[show g' = (fun x => (x,g x)) by rfl]; fun_prop + conv => + lhs + rw[show ((fun x => f x (g x)) + = + fun x => pure (g' x) >>= f') by simp] + rw[fwdFDerivM_bind f' 
(fun x => pure (g' x)) hf (DifferentiableM_pure g' hg')] + simp[fwdFDerivM_pure (K:=K) g' hg'] + fun_trans + simp + +end fwdFDerivM + +end SciLean + + + +-------------------------------------------------------------------------------- + +section CoreFunctionProperties + +open SciLean DifferentiableMonad + +set_option deprecated.oldSectionVars true + +variable + (K : Type) [RCLike K] + {m m'} [Monad m] [Monad m'] + [DifferentiableMonad K m] [FwdFDerivMonad K m m'] [LawfulMonad m] [LawfulMonad m'] + {X : Type} [NormedAddCommGroup X] [NormedSpace K X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [NormedSpace K Z] + {E : ι → Type} [∀ i, Vec K (E i)] + + +-------------------------------------------------------------------------------- +-- Pure.pure ------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem Pure.pure.arg_a0.fwdFDerivM_rule + (a0 : X → Y) + (ha0 : Differentiable K a0) + : fwdFDerivM K (fun x => pure (f:=m) (a0 x)) + = + fun x dx => pure (fwdFDeriv K a0 x dx) := +by + apply FwdFDerivMonad.fwdFDerivM_pure a0 ha0 + +@[simp, simp_core] +theorem Pure.pure.arg.fwdFDerivValM_rule (x : X) + : fwdFDerivValM K (pure (f:=m) x) + = + pure (x,0) := +by + unfold fwdFDerivValM; rw[FwdFDerivMonad.fwdFDerivM_pure]; fun_trans; fun_prop + + +-------------------------------------------------------------------------------- +-- Bind.bind ------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem Bind.bind.arg_a0a1.fwdFDerivM_rule + (a0 : X → m Y) (a1 : X → Y → m Z) + (ha0 : DifferentiableM K a0) (ha1 : DifferentiableM K (fun (xy : X×Y) => a1 xy.1 xy.2)) + : (fwdFDerivM K (fun x => Bind.bind (a0 x) (a1 x))) + = + (fun x dx => do + let ydy ← fwdFDerivM K a0 x dx + fwdFDerivM K (fun (xy : X×Y) => a1 xy.1 
xy.2) (x, ydy.1) (dx, ydy.2)) := +by + let g := fun x => do + let y ← a0 x + pure (x,y) + let f := fun xy : X×Y => a1 xy.1 xy.2 + + rw[show (fun x => Bind.bind (a0 x) (a1 x)) + = + fun x => g x >>= f by simp[f,g]] + + have hg : DifferentiableM K (fun x => do let y ← a0 x; pure (x,y)) := + by apply DifferentiableM_pair a0 ha0 + have hf : DifferentiableM K f := by simp [f]; fun_prop + + rw [FwdFDerivMonad.fwdFDerivM_bind _ _ hf hg] + simp [FwdFDerivMonad.fwdFDerivM_pair a0 ha0] + + +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem ite.arg_te.fwdFDerivM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + : fwdFDerivM K (fun x => ite c (t x) (e x)) + = + fun y => + ite c (fwdFDerivM K t y) (fwdFDerivM K e y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] + +@[fun_trans] +theorem dite.arg_te.fwdFDerivM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + : fwdFDerivM K (fun x => dite c (fun h => t h x) (fun h => e h x)) + = + fun y => + dite c (fun h => fwdFDerivM K (t h) y) (fun h => fwdFDerivM K (e h) y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] diff --git a/SciLean/Analysis/Calculus/Monad/Id.lean b/SciLean/Analysis/Calculus/Monad/Id.lean index c7bfad94..f62d8f14 100644 --- a/SciLean/Analysis/Calculus/Monad/Id.lean +++ b/SciLean/Analysis/Calculus/Monad/Id.lean @@ -1,5 +1,5 @@ -import SciLean.Analysis.Calculus.Monad.FwdCDerivMonad -import SciLean.Analysis.Calculus.Monad.RevCDerivMonad +import SciLean.Analysis.Calculus.Monad.FwdFDerivMonad +import SciLean.Analysis.Calculus.Monad.RevFDerivMonad namespace SciLean @@ -35,126 +35,70 @@ instance : Coe (Id' X) X := ⟨fun x => x.run⟩ instance : Coe X (Id' X) := ⟨fun x => pure x⟩ variable - {K : Type _} [RCLike K] - -noncomputable -instance : FwdCDerivMonad K Id' Id' 
where - fwdCDerivM f := fun x dx => pure (fwdCDeriv K (fun x => (f x).run) x dx) - CDifferentiableM f := CDifferentiable K (fun x => (f x).run) - fwdCDerivM_pure f := by simp[pure] - fwdCDerivM_bind := by simp[Id',Bind.bind]; sorry_proof - fwdCDerivM_pair y := by intros; simp; sorry_proof - CDifferentiableM_pure := by simp[pure] - CDifferentiableM_bind := by intros; simp[bind]; sorry_proof - CDifferentiableM_pair y := by intros; simp[bind,pure]; fun_prop + {K : Type} [RCLike K] +instance : DifferentiableMonad K Id' where + DifferentiableM f := Differentiable K (fun x => (f x).run) + DifferentiableM_pure := by simp[pure] + DifferentiableM_bind := by intros; simp[bind]; sorry_proof + DifferentiableM_pair y := by intros; simp[bind,pure]; fun_prop noncomputable -instance : RevCDerivMonad K Id' Id' where - revCDerivM f := fun x => - let ydf := revCDeriv K (fun x => (f x).run) x - pure ((ydf.1, fun dy => pure (ydf.2 dy))) - HasAdjDiffM f := HasAdjDiff K (fun x => (f x).run) - revCDerivM_pure f := by intros; funext; simp[pure,revCDeriv] - revCDerivM_bind := by intros; simp; sorry_proof - revCDerivM_pair y := by intros; simp[Bind.bind]; funext x; sorry_proof - HasAdjDiffM_pure := by simp[pure] - HasAdjDiffM_bind := by intros; simp[bind]; sorry_proof - HasAdjDiffM_pair y := by intros; simp[bind, pure]; fun_prop +instance : FwdFDerivMonad K Id' Id' where + fwdFDerivM f := fun x dx => pure (fwdFDeriv K (fun x => (f x).run) x dx) + fwdFDerivM_pure f := by simp[pure] + fwdFDerivM_bind := by simp[Id',Bind.bind]; sorry_proof + fwdFDerivM_pair y := by intros; simp; sorry_proof +noncomputable +instance : RevFDerivMonad K Id' Id' where + revFDerivM f := fun x => + let ydf := revFDeriv K (fun x => (f x).run) x + pure (ydf.1, fun dy => pure (ydf.2 dy)) + revFDerivM_pure f := by simp[pure] + revFDerivM_bind := by simp[Id',Bind.bind]; sorry_proof + revFDerivM_pair y := by intros; simp; sorry_proof end SciLean open SciLean -section OnVec +section OnNormedSpace variable - {K : Type _} 
[RCLike K] - {X : Type} [Vec K X] - {Y : Type} [Vec K Y] - {Z : Type} [Vec K Z] - {E : ι → Type _} [∀ i, Vec K (E i)] + {K : Type} [RCLike K] + {X : Type} [NormedAddCommGroup X] [NormedSpace K X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [NormedSpace K Z] @[fun_prop] -theorem Id'.run.arg_x.CDifferentiable_rule - (a : X → Id' Y) (ha : CDifferentiableM K a) - : CDifferentiable K (fun x => Id'.run (a x)) := ha +theorem Id'.run.arg_x.Differentiable_rule + (a : X → Id' Y) (ha : DifferentiableM K a) : + Differentiable K (fun x => Id'.run (a x)) := ha @[fun_trans] -theorem Id'.run.arg_x.fwdCDeriv_rule (a : X → Id' Y) - : fwdCDeriv K (fun x => Id'.run (a x)) +theorem Id'.run.arg_x.fwdFDeriv_rule (a : X → Id' Y) : + fwdFDeriv K (fun x => Id'.run (a x)) = - fun x dx => (fwdCDerivM K a x dx).run := by rfl + fun x dx => (fwdFDerivM K a x dx).run := by rfl -end OnVec +end OnNormedSpace -section OnSemiInnerProductSpace +section OnAdjointSpace variable - {K : Type _} [RCLike K] - {X : Type} [SemiInnerProductSpace K X] - {Y : Type} [SemiInnerProductSpace K Y] - {Z : Type} [SemiInnerProductSpace K Z] - {E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)] - - -@[fun_prop] -theorem Id'.run.arg_x.HasAdjDiff_rule - (a : X → Id' Y) (ha : HasAdjDiffM K a) - : HasAdjDiff K (fun x => Id'.run (a x)) := ha - - -@[fun_trans] -theorem Id'.run.arg_x.revCDeriv_rule (a : X → Id' Y) - : revCDeriv K (fun x => Id'.run (a x)) - = - fun x => - let ydf := (revCDerivM K a x).run - (ydf.1, fun dy => (ydf.2 dy).run) := by rfl - - -@[fun_prop] -theorem Pure.pure.arg_a0.HasAdjDiff_rule - (a0 : X → Y) (ha0 : HasAdjDiff K a0) : - HasAdjDiffM K (fun x => Pure.pure (f:=Id') (a0 x)) := by - simp[Pure.pure,HasAdjDiffM]; fun_prop - - -@[fun_trans] -theorem Pure.pure.arg_a0.fwdCDeriv_rule - (a0 : X → Y) : - fwdCDerivM K (fun x => Pure.pure (f:=Id') (a0 x)) - = - fun x dx => - let ydy := fwdCDeriv K a0 x dx - pure ydy := by rfl - - -@[fun_prop] -theorem 
Bind.bind.arg_a0a1.HasAdjDiff_rule_on_Id' - (a0 : X → Y) (a1 : X → Y → Z) - (ha0 : HasAdjDiff K a0) (ha1 : HasAdjDiff K (fun (x,y) => a1 x y)) : - HasAdjDiffM K (fun x => Bind.bind (m:=Id') ⟨a0 x⟩ (fun y => ⟨a1 x y⟩)) := by - simp[Bind.bind,HasAdjDiffM]; fun_prop - + {K : Type} [RCLike K] + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] [CompleteSpace Z] @[fun_trans] -theorem Bind.bind.arg_a0a1.revCDerivM_rule_on_Id' - (a0 : X → Y) (a1 : X → Y → Z) - (ha0 : HasAdjDiff K a0) (ha1 : HasAdjDiff K (fun (x,y) => a1 x y)) : - (revCDerivM (m:=Id') K (fun x => Bind.bind ⟨a0 x⟩ (fun y => ⟨a1 x y⟩))) +theorem Id'.run.arg_x.revFDeriv_rule (a : X → Id' Y) : + revFDeriv K (fun x => Id'.run (a x)) = fun x => - let ydg' := revCDeriv K a0 x - let zdf' := revCDeriv K (fun (x,y) => a1 x y) (x,ydg'.1) - ⟨(zdf'.1, - fun dz' => - let dxy' := zdf'.2 dz' - let dx' := ydg'.2 dxy'.2 - ⟨dxy'.1 + dx'⟩)⟩ := by - simp[revCDerivM,Bind.bind]; fun_trans; simp[revCDeriv,revCDerivUpdate]; sorry_proof - - + let xda := (revFDerivM K a x).run + (xda.1, + fun dy => (xda.2 dy).run) := by rfl -end OnSemiInnerProductSpace +end OnAdjointSpace diff --git a/SciLean/Analysis/Calculus/Monad/RevCDerivMonad.lean b/SciLean/Analysis/Calculus/Monad/RevCDerivMonad.lean deleted file mode 100644 index fca3c186..00000000 --- a/SciLean/Analysis/Calculus/Monad/RevCDerivMonad.lean +++ /dev/null @@ -1,367 +0,0 @@ -import SciLean.Analysis.Calculus.RevCDeriv - -namespace SciLean - - -set_option linter.unusedVariables false in -class RevCDerivMonad (K : Type) [RCLike K] (m : Type → Type) (m' : outParam $ Type → Type) [Monad m] [Monad m'] where - revCDerivM {X : Type} {Y : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] : ∀ (f : X → m Y) (x : X), m (Y × (Y → m' X)) - - HasAdjDiffM {X : Type} {Y : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] - : ∀ 
(f : X → m Y), Prop - - revCDerivM_pure {X Y : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] (f : X → Y) (hf : HasAdjDiff K f) - : revCDerivM (fun x => pure (f:=m) (f x)) = fun x => let ydf := revCDeriv K f x; pure (ydf.1, fun dy => pure (ydf.2 dy)) - revCDerivM_bind - {X Y Z : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] [SemiInnerProductSpace K Z] - (f : Y → m Z) (g : X → m Y) (hf : HasAdjDiffM f) (hg : HasAdjDiffM g) - : revCDerivM (fun x => g x >>= f) - = - fun x => do - let ydg ← revCDerivM g x - let zdf ← revCDerivM f ydg.1 - pure (zdf.1, fun dz => zdf.2 dz >>= ydg.2) - revCDerivM_pair {X : Type} {Y : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] -- is this really necessary? - (f : X → m Y) (hf : HasAdjDiffM f) - : revCDerivM (fun x => do let y ← f x; pure (x,y)) - = - (fun x => do - let ydf ← revCDerivM f x - pure ((x,ydf.1), fun dxy : X×Y => do let dx ← ydf.2 dxy.2; pure (dxy.1 + dx))) - - - HasAdjDiffM_pure {X : Type} {Y : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] - (f : X → Y) (hf : HasAdjDiff K f) - : HasAdjDiffM (fun x : X => pure (f x)) - HasAdjDiffM_bind {X Y Z: Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] [SemiInnerProductSpace K Z] - (f : Y → m Z) (g : X → m Y) - (hf : HasAdjDiffM f) (hg : HasAdjDiffM g) - : HasAdjDiffM (fun x => g x >>= f) - HasAdjDiffM_pair {X : Type} {Y : Type} [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] -- is this really necessary? 
- (f : X → m Y) (hf : HasAdjDiffM f) - : HasAdjDiffM (fun x => do let y ← f x; pure (x,y)) - - -export RevCDerivMonad (revCDerivM HasAdjDiffM) - -attribute [fun_trans] revCDerivM -attribute [fun_prop] HasAdjDiffM - -set_option deprecated.oldSectionVars true - -variable - (K : Type _) [RCLike K] - {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] [RevCDerivMonad K m m'] - [LawfulMonad m] [LawfulMonad m'] - {X : Type} [SemiInnerProductSpace K X] - {Y : Type} [SemiInnerProductSpace K Y] - {Z : Type} [SemiInnerProductSpace K Z] - {E : ι → Type} [∀ i, SemiInnerProductSpace K (E i)] - -open RevCDerivMonad - -def revCDerivValM (x : m X) : m (X × (X → m' Unit)) := do - revCDerivM K (fun _ : Unit => x) () - -def HasAdjDiffValM (x : m X) : Prop := - HasAdjDiffM K (fun _ : Unit => x) - - --------------------------------------------------------------------------------- --- HasAdjDiffM ----------------------------------------------------------- --------------------------------------------------------------------------------- -namespace HasAdjDiffM - --- id_rule does not make sense - -@[fun_prop] -theorem const_rule (y : m Y) (hy : HasAdjDiffValM K y) - : HasAdjDiffM K (fun _ : X => y) := -by - have h : (fun _ : X => y) - = - fun _ : X => pure () >>= fun _ => y := by simp - rw[h] - apply HasAdjDiffM_bind - apply hy - apply HasAdjDiffM_pure - fun_prop - -@[fun_prop] -theorem comp_rule - (f : Y → m Z) (g : X → Y) - (hf : HasAdjDiffM K f) (hg : HasAdjDiff K g) - : HasAdjDiffM K (fun x => f (g x)) := -by - rw[show ((fun x => f (g x)) - = - fun x => pure (g x) >>= f) by simp] - apply HasAdjDiffM_bind _ _ hf - apply HasAdjDiffM_pure g hg - - - -end HasAdjDiffM - --------------------------------------------------------------------------------- --- revCDerivM ------------------------------------------------------------------- --------------------------------------------------------------------------------- -namespace revCDerivM - --- id_rule does not make sense - - 
-@[fun_trans] -theorem const_rule (y : m Y) (hy : HasAdjDiffValM K y) - : revCDerivM K (fun _ : X => y) - = - (fun _ => do - let ydy ← revCDerivValM K y - pure (ydy.1, - fun dy' => do - let _ ← ydy.2 dy' - pure 0)) := -by - have h : (fun _ : X => y) - = - fun _ : X => pure () >>= fun _ => y := by simp - rw[h] - rw[revCDerivM_bind] - rw[revCDerivM_pure] - fun_trans - simp [revCDerivValM] - fun_prop - apply hy - apply HasAdjDiffM_pure; fun_prop - -@[fun_trans] -theorem comp_rule - (f : Y → m Z) (g : X → Y) - (hf : HasAdjDiffM K f) (hg : HasAdjDiff K g) - : revCDerivM K (fun x => f (g x)) - = - (fun x => do - let ydg := revCDeriv K g x - let zdf ← revCDerivM K f ydg.1 - pure (zdf.1, - fun dz => do - let dy ← zdf.2 dz - pure (ydg.2 dy))) := -by - conv => - lhs - rw[show ((fun x => f (g x)) - = - fun x => pure (g x) >>= f) by simp] - rw[revCDerivM_bind f (fun x => pure (g x)) - hf (HasAdjDiffM_pure _ hg)] - simp[revCDerivM_pure g hg] - rfl - -@[fun_trans] -theorem let_rule - (f : X → Y → m Z) (g : X → Y) - (hf : HasAdjDiffM K (fun xy : X×Y => f xy.1 xy.2)) (hg : HasAdjDiff K g) - : revCDerivM K (fun x => let y := g x; f x y) - = - (fun x => do - let ydg := revCDeriv K g x - let zdf ← revCDerivM K (fun xy : X×Y => f xy.1 xy.2) (x,ydg.1) - pure (zdf.1, - fun dz => do - let dxy ← zdf.2 dz - let dx := ydg.2 dxy.2 - pure (dxy.1 + dx))) := -by - let f' := (fun xy : X×Y => f xy.1 xy.2) - let g' := (fun x => (x,g x)) - have hg' : HasAdjDiff K g' := by rw[show g' = (fun x => (x,g x)) by rfl]; fun_prop - conv => - lhs - rw[show ((fun x => f x (g x)) - = - fun x => pure (g' x) >>= f') by simp] - rw[revCDerivM_bind f' (fun x => pure (g' x)) hf (HasAdjDiffM_pure g' hg')] - simp[revCDerivM_pure (K:=K) g' hg'] - -- fun_trans; simp - sorry_proof - -end revCDerivM - - -end SciLean - - --------------------------------------------------------------------------------- - -section CoreFunctionProperties - -open SciLean - -set_option deprecated.oldSectionVars true - -variable - (K : Type _) 
[RCLike K] - {m m'} [Monad m] [Monad m'] [RevCDerivMonad K m m'] - [LawfulMonad m] [LawfulMonad m'] - {X : Type} [SemiInnerProductSpace K X] - {Y : Type} [SemiInnerProductSpace K Y] - {Z : Type} [SemiInnerProductSpace K Z] - {E : ι → Type} [∀ i, SemiInnerProductSpace K (E i)] - - --------------------------------------------------------------------------------- --- Pure.pure ------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_prop] -theorem Pure.pure.arg_a0.HasAdjDiffM_rule - (a0 : X → Y) - (ha0 : HasAdjDiff K a0) - : HasAdjDiffM K (fun x => Pure.pure (f:=m) (a0 x)) := -by - apply RevCDerivMonad.HasAdjDiffM_pure a0 ha0 - - -@[fun_trans] -theorem Pure.pure.arg_a0.revCDerivM_rule - (a0 : X → Y) - (ha0 : HasAdjDiff K a0) - : revCDerivM K (fun x => pure (f:=m) (a0 x)) - = - (fun x => do - let ydf := revCDeriv K a0 x - pure (ydf.1, fun dy => pure (ydf.2 dy))):= -by - apply RevCDerivMonad.revCDerivM_pure a0 ha0 - - -@[simp, simp_core] -theorem Pure.pure.HasAdjDiffValM_rule (x : X) - : HasAdjDiffValM K (pure (f:=m) x) := -by - unfold HasAdjDiffValM - apply RevCDerivMonad.HasAdjDiffM_pure - fun_prop - - -@[simp, simp_core] -theorem Pure.pure.arg.revCDerivValM_rule (x : X) - : revCDerivValM K (pure (f:=m) x) - = - pure (x,fun dy => pure 0) := -by - unfold revCDerivValM; rw[RevCDerivMonad.revCDerivM_pure]; fun_trans; fun_prop - - --------------------------------------------------------------------------------- --- Bind.bind ------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_prop] -theorem Bind.bind.arg_a0a1.HasAdjDiffM_rule - (a0 : X → m Y) (a1 : X → Y → m Z) - (ha0 : HasAdjDiffM K a0) (ha1 : HasAdjDiffM K (fun (xy : X×Y) => a1 xy.1 xy.2)) - : HasAdjDiffM K (fun x => Bind.bind (a0 x) (a1 x)) := -by - let g := fun x => do - let y ← a0 x - pure (x,y) - let f := fun 
xy : X×Y => a1 xy.1 xy.2 - - rw[show (fun x => Bind.bind (a0 x) (a1 x)) - = - fun x => g x >>= f by simp[f,g]] - - have hg : HasAdjDiffM K (fun x => do let y ← a0 x; pure (x,y)) := - by apply RevCDerivMonad.HasAdjDiffM_pair a0 ha0 - have hf : HasAdjDiffM K f := by simp[f]; fun_prop - - apply RevCDerivMonad.HasAdjDiffM_bind _ _ hf hg - - - -@[fun_trans] -theorem Bind.bind.arg_a0a1.revCDerivM_rule - (a0 : X → m Y) (a1 : X → Y → m Z) - (ha0 : HasAdjDiffM K a0) (ha1 : HasAdjDiffM K (fun (xy : X×Y) => a1 xy.1 xy.2)) - : (revCDerivM K (fun x => Bind.bind (a0 x) (a1 x))) - = - (fun x => do - let ydg ← revCDerivM K a0 x - let zdf ← revCDerivM K (fun (xy : X×Y) => a1 xy.1 xy.2) (x,ydg.1) - pure (zdf.1, - fun dz => do - let dxy ← zdf.2 dz - let dx ← ydg.2 dxy.2 - pure (dxy.1 + dx))) := -by - let g := fun x => do - let y ← a0 x - pure (x,y) - let f := fun xy : X×Y => a1 xy.1 xy.2 - - rw[show (fun x => Bind.bind (a0 x) (a1 x)) - = - fun x => g x >>= f by simp[f,g]] - - have hg : HasAdjDiffM K (fun x => do let y ← a0 x; pure (x,y)) := - by apply RevCDerivMonad.HasAdjDiffM_pair a0 ha0 - have hf : HasAdjDiffM K f := by simp[f]; fun_prop - - rw [RevCDerivMonad.revCDerivM_bind _ _ hf hg] - simp [RevCDerivMonad.revCDerivM_pair a0 ha0] - - --------------------------------------------------------------------------------- --- d/ite ----------------------------------------------------------------------- --------------------------------------------------------------------------------- - -@[fun_prop] -theorem ite.arg_te.HasAdjDiffM_rule - (c : Prop) [dec : Decidable c] (t e : X → m Y) - (ht : HasAdjDiffM K t) (he : HasAdjDiffM K e) - : HasAdjDiffM K (fun x => ite c (t x) (e x)) := -by - induction dec - case isTrue h => simp[ht,h] - case isFalse h => simp[he,h] - - -@[fun_trans] -theorem ite.arg_te.revCDerivM_rule - (c : Prop) [dec : Decidable c] (t e : X → m Y) - : revCDerivM K (fun x => ite c (t x) (e x)) - = - fun y => - ite c (revCDerivM K t y) (revCDerivM K e y) := -by - induction dec 
- case isTrue h => ext y; simp[h] - case isFalse h => ext y; simp[h] - - -@[fun_prop] -theorem dite.arg_te.HasAdjDiffM_rule - (c : Prop) [dec : Decidable c] - (t : c → X → m Y) (e : ¬c → X → m Y) - (ht : ∀ h, HasAdjDiffM K (t h)) (he : ∀ h, HasAdjDiffM K (e h)) - : HasAdjDiffM K (fun x => dite c (fun h => t h x) (fun h => e h x)) := -by - induction dec - case isTrue h => simp[ht,h] - case isFalse h => simp[he,h] - - -@[fun_trans] -theorem dite.arg_te.revCDerivM_rule - (c : Prop) [dec : Decidable c] - (t : c → X → m Y) (e : ¬c → X → m Y) - : revCDerivM K (fun x => dite c (fun h => t h x) (fun h => e h x)) - = - fun y => - dite c (fun h => revCDerivM K (t h) y) (fun h => revCDerivM K (e h) y) := -by - induction dec - case isTrue h => ext y; simp[h] - case isFalse h => ext y; simp[h] diff --git a/SciLean/Analysis/Calculus/Monad/RevFDerivMonad.lean b/SciLean/Analysis/Calculus/Monad/RevFDerivMonad.lean new file mode 100644 index 00000000..3fdd7a75 --- /dev/null +++ b/SciLean/Analysis/Calculus/Monad/RevFDerivMonad.lean @@ -0,0 +1,285 @@ +import SciLean.Analysis.Calculus.RevFDeriv +import SciLean.Analysis.Calculus.Monad.DifferentiableMonad + +namespace SciLean + +/-- `RevFDerivMonad K m m'` states that the monad `m'` allows us to compute the reverse pass of the + reverse derivative of functions in the monad `m`. The rough idea is that if the monad `m` reads +some state `S` then the monad `m'` should write into `S`. The state monad reads and writes, so for +`m = StateM S` we have `m' = StateM S`. 
+ +This class provides two main functions; for a monadic function `(f : X → m Y)`: + - `revFDerivM K f` is a generalization of reverse mode derivative of `f` + - `DifferentiableM K f` is a generalization of differentiability of `f` + +For `StateM S` the `revFDerivM` and `DifferentiableM` are: +``` + revFDerivM K f + = + fun x s => + let ((y,s),df) := revFDeriv K (fun (x,s) => f x s) (x,s) + ((y,s), fun dy ds => df (dy,ds)) + + DifferentiableM K f + = + Differentiable K (fun (x,s) => f x s) +``` +In short, `revFDerivM` also differentiates w.r.t. the state variable and `DifferentiableM` checks +that a function is also differentiable w.r.t. the state variable. + +-/ +class RevFDerivMonad (K : Type) [RCLike K] (m : Type → Type) (m' : outParam $ Type → Type) [Monad m] [Monad m'] [DifferentiableMonad K m] where + + revFDerivM + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + (f : X → m Y) (x : X) : m (Y × (Y → m' X)) + + revFDerivM_pure + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + (f : X → Y) (hf : Differentiable K f) : + revFDerivM (fun x => pure (f:=m) (f x)) = fun x => let ydf := revFDeriv K f x; pure (ydf.1, fun dy => pure (ydf.2 dy)) + + revFDerivM_bind + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] [CompleteSpace Z] + (f : Y → m Z) (g : X → m Y) (hf : DifferentiableM K f) (hg : DifferentiableM K g) + : revFDerivM (fun x => g x >>= f) + = + fun x => do + let ydg ← revFDerivM g x + let zdf ← revFDerivM f ydg.1 + pure (zdf.1, fun dz => zdf.2 dz >>= ydg.2) + + revFDerivM_pair + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] 
[CompleteSpace Y] + (f : X → m Y) (hf : DifferentiableM K f) + : revFDerivM (fun x => do let y ← f x; pure (x,y)) + = + (fun x => do + let ydf ← revFDerivM f x + pure ((x,ydf.1), fun dxy : X×Y => do let dx ← ydf.2 dxy.2; pure (dxy.1 + dx))) + + +export RevFDerivMonad (revFDerivM) + +attribute [fun_trans] revFDerivM + +set_option deprecated.oldSectionVars true + +variable + (K : Type _) [RCLike K] + {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] + [DifferentiableMonad K m] [RevFDerivMonad K m m'] + [LawfulMonad m] [LawfulMonad m'] + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] [CompleteSpace Z] + +open RevFDerivMonad DifferentiableMonad + +def revFDerivValM (x : m X) : m (X × (X → m' Unit)) := do + revFDerivM K (fun _ : Unit => x) () + + + +-------------------------------------------------------------------------------- +-- revFDerivM ------------------------------------------------------------------- +-------------------------------------------------------------------------------- +namespace revFDerivM + +-- id_rule does not make sense + + +@[fun_trans] +theorem const_rule (y : m Y) (hy : DifferentiableValM K y) + : revFDerivM K (fun _ : X => y) + = + (fun _ => do + let ydy ← revFDerivValM K y + pure (ydy.1, + fun dy' => do + let _ ← ydy.2 dy' + pure 0)) := +by + have h : (fun _ : X => y) + = + fun _ : X => pure () >>= fun _ => y := by simp + rw[h] + rw[revFDerivM_bind] + rw[revFDerivM_pure] + fun_trans + simp [revFDerivValM] + fun_prop + apply hy + apply DifferentiableM_pure; fun_prop + +@[fun_trans] +theorem comp_rule + (f : Y → m Z) (g : X → Y) + (hf : DifferentiableM K f) (hg : Differentiable K g) + : revFDerivM K (fun x => f (g x)) + = + (fun x => do + let ydg := revFDeriv K g x + let zdf ← revFDerivM K f ydg.1 + pure (zdf.1, + fun dz => do + let dy ← zdf.2 dz + pure (ydg.2 dy))) := +by 
+ conv => + lhs + rw[show ((fun x => f (g x)) + = + fun x => pure (g x) >>= f) by simp] + rw[revFDerivM_bind f (fun x => pure (g x)) + hf (DifferentiableM_pure _ hg)] + simp[revFDerivM_pure g hg] + +@[fun_trans] +theorem let_rule + (f : X → Y → m Z) (g : X → Y) + (hf : DifferentiableM K (fun xy : X×Y => f xy.1 xy.2)) (hg : Differentiable K g) + : revFDerivM K (fun x => let y := g x; f x y) + = + (fun x => do + let ydg := revFDeriv K g x + let zdf ← revFDerivM K (fun xy : X×Y => f xy.1 xy.2) (x,ydg.1) + pure (zdf.1, + fun dz => do + let dxy ← zdf.2 dz + let dx := ydg.2 dxy.2 + pure (dxy.1 + dx))) := +by + let f' := (fun xy : X×Y => f xy.1 xy.2) + let g' := (fun x => (x,g x)) + have hg' : Differentiable K g' := by rw[show g' = (fun x => (x,g x)) by rfl]; fun_prop + conv => + lhs + rw[show ((fun x => f x (g x)) + = + fun x => pure (g' x) >>= f') by simp] + rw[revFDerivM_bind f' (fun x => pure (g' x)) hf (DifferentiableM_pure g' hg')] + simp[revFDerivM_pure (K:=K) g' hg'] + -- fun_trans; simp + sorry_proof + +end revFDerivM + + +end SciLean + + +-------------------------------------------------------------------------------- + +section CoreFunctionProperties + +open SciLean DifferentiableMonad + +set_option deprecated.oldSectionVars true + +variable + (K : Type _) [RCLike K] + {m m'} [Monad m] [Monad m'] [DifferentiableMonad K m] [RevFDerivMonad K m m'] + [LawfulMonad m] [LawfulMonad m'] + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] [CompleteSpace Z] + {E : ι → Type} [∀ i, SemiInnerProductSpace K (E i)] + + +-------------------------------------------------------------------------------- +-- Pure.pure ------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem Pure.pure.arg_a0.revFDerivM_rule + (a0 : X → 
Y) + (ha0 : Differentiable K a0) + : revFDerivM K (fun x => pure (f:=m) (a0 x)) + = + (fun x => do + let ydf := revFDeriv K a0 x + pure (ydf.1, fun dy => pure (ydf.2 dy))):= +by + apply RevFDerivMonad.revFDerivM_pure a0 ha0 + +@[simp, simp_core] +theorem Pure.pure.arg.revFDerivValM_rule (x : X) + : revFDerivValM K (pure (f:=m) x) + = + pure (x,fun dy => pure 0) := +by + unfold revFDerivValM; rw[RevFDerivMonad.revFDerivM_pure]; fun_trans; fun_prop + + +-------------------------------------------------------------------------------- +-- Bind.bind ------------------------------------------------------------------- +-------------------------------------------------------------------------------- + +@[fun_trans] +theorem Bind.bind.arg_a0a1.revFDerivM_rule + (a0 : X → m Y) (a1 : X → Y → m Z) + (ha0 : DifferentiableM K a0) (ha1 : DifferentiableM K (fun (xy : X×Y) => a1 xy.1 xy.2)) + : (revFDerivM K (fun x => Bind.bind (a0 x) (a1 x))) + = + (fun x => do + let ydg ← revFDerivM K a0 x + let zdf ← revFDerivM K (fun (xy : X×Y) => a1 xy.1 xy.2) (x,ydg.1) + pure (zdf.1, + fun dz => do + let dxy ← zdf.2 dz + let dx ← ydg.2 dxy.2 + pure (dxy.1 + dx))) := +by + let g := fun x => do + let y ← a0 x + pure (x,y) + let f := fun xy : X×Y => a1 xy.1 xy.2 + + rw[show (fun x => Bind.bind (a0 x) (a1 x)) + = + fun x => g x >>= f by simp[f,g]] + + have hg : DifferentiableM K (fun x => do let y ← a0 x; pure (x,y)) := + by apply DifferentiableM_pair a0 ha0 + have hf : DifferentiableM K f := by simp[f]; fun_prop + + rw [RevFDerivMonad.revFDerivM_bind _ _ hf hg] + simp [RevFDerivMonad.revFDerivM_pair a0 ha0] + + +-------------------------------------------------------------------------------- +-- d/ite ----------------------------------------------------------------------- +-------------------------------------------------------------------------------- + + +@[fun_trans] +theorem ite.arg_te.revFDerivM_rule + (c : Prop) [dec : Decidable c] (t e : X → m Y) + : revFDerivM K (fun x => ite c (t x) (e 
x)) + = + fun y => + ite c (revFDerivM K t y) (revFDerivM K e y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] + + +@[fun_trans] +theorem dite.arg_te.revFDerivM_rule + (c : Prop) [dec : Decidable c] + (t : c → X → m Y) (e : ¬c → X → m Y) + : revFDerivM K (fun x => dite c (fun h => t h x) (fun h => e h x)) + = + fun y => + dite c (fun h => revFDerivM K (t h) y) (fun h => revFDerivM K (e h) y) := +by + induction dec + case isTrue h => ext y; simp[h] + case isFalse h => ext y; simp[h] diff --git a/SciLean/Analysis/Calculus/Monad/StateT.lean b/SciLean/Analysis/Calculus/Monad/StateT.lean index 07f4ec5b..398b6d8a 100644 --- a/SciLean/Analysis/Calculus/Monad/StateT.lean +++ b/SciLean/Analysis/Calculus/Monad/StateT.lean @@ -1,167 +1,125 @@ -import SciLean.Analysis.Calculus.Monad.FwdCDerivMonad -import SciLean.Analysis.Calculus.Monad.RevCDerivMonad +import SciLean.Analysis.Calculus.Monad.FwdFDerivMonad +import SciLean.Analysis.Calculus.Monad.RevFDerivMonad namespace SciLean -section FwdCDerivMonad +section FwdFDerivMonad variable - {K : Type _} [RCLike K] - {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] [FwdCDerivMonad K m m'] + {K : Type} [RCLike K] + {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] + [DifferentiableMonad K m] [FwdFDerivMonad K m m'] [LawfulMonad m] [LawfulMonad m'] +noncomputable +instance (S : Type) [NormedAddCommGroup S] [NormedSpace K S] : + DifferentiableMonad K (StateT S m) where + + DifferentiableM f := DifferentiableM K (fun (xs : _×S) => f xs.1 xs.2) + + DifferentiableM_pure f hf := + by + simp [Pure.pure,StateT.pure] + fun_prop + + DifferentiableM_bind f g hf hg := + by + simp; simp at hf; simp at hg + simp[bind, StateT.bind, StateT.bind.match_1] + fun_prop + + DifferentiableM_pair f hf := + by + simp; simp at hf + simp[bind, StateT.bind, StateT.bind.match_1, pure, StateT.pure] + fun_prop + noncomputable -instance (S : Type _) [Vec K S] : FwdCDerivMonad K 
(StateT S m) (StateT (S×S) m') where - fwdCDerivM f := fun x dx sds => do +instance (S : Type) [NormedAddCommGroup S] [NormedSpace K S] : + FwdFDerivMonad K (StateT S m) (StateT (S×S) m') where + fwdFDerivM f := fun x dx sds => do -- ((y,s'),(dy,ds')) - let r ← fwdCDerivM K (fun (xs : _×S) => f xs.1 xs.2) (x,sds.1) (dx,sds.2) + let r ← fwdFDerivM K (fun (xs : _×S) => f xs.1 xs.2) (x,sds.1) (dx,sds.2) -- ((y,dy),(s',ds')) pure ((r.1.1,r.2.1),(r.1.2, r.2.2)) - CDifferentiableM f := CDifferentiableM K (fun (xs : _×S) => f xs.1 xs.2) - fwdCDerivM_pure f h := + fwdFDerivM_pure f h := by funext - simp[pure, StateT.pure, fwdCDeriv] + simp[pure, StateT.pure, fwdFDeriv] fun_trans - simp [fwdCDeriv] + simp [fwdFDeriv] - fwdCDerivM_bind f g hf hg := + fwdFDerivM_bind f g hf hg := by funext x dx sds - simp at hf; simp at hg - simp[fwdCDeriv, bind, StateT.bind, StateT.bind.match_1] + simp [DifferentiableM] at hf; simp [DifferentiableM] at hg + simp[fwdFDeriv, bind, StateT.bind, StateT.bind.match_1] fun_trans - fwdCDerivM_pair f hf := + fwdFDerivM_pair f hf := by funext x dx sds - simp at hf + simp [DifferentiableM] at hf simp[bind, StateT.bind, StateT.bind.match_1, pure, StateT.pure] fun_trans only simp - CDifferentiableM_pure f hf := - by - simp [Pure.pure,StateT.pure] - fun_prop - - CDifferentiableM_bind f g hf hg := - by - simp; simp at hf; simp at hg - simp[bind, StateT.bind, StateT.bind.match_1] - fun_prop - - CDifferentiableM_pair f hf := - by - simp; simp at hf - simp[bind, StateT.bind, StateT.bind.match_1, pure, StateT.pure] - fun_prop - - variable - {S : Type _} [Vec K S] - {X : Type _} [Vec K X] - {Y : Type _} [Vec K Y] - {Z : Type _} [Vec K Z] + {S : Type} [NormedAddCommGroup S] [NormedSpace K S] + {X : Type} [NormedAddCommGroup X] [NormedSpace K X] + {Y : Type} [NormedAddCommGroup Y] [NormedSpace K Y] + {Z : Type} [NormedAddCommGroup Z] [NormedSpace K Z] -- getThe ---------------------------------------------------------------------- 
-------------------------------------------------------------------------------- - @[simp, simp_core] -theorem _root_.getThe.arg.CDifferentiableValM_rule - : CDifferentiableValM K (m:=StateT S m) (getThe S) := +theorem _root_.getThe.arg.DifferentiableValM_rule + : DifferentiableValM K (m:=StateT S m) (getThe S) := by - simp[getThe, MonadStateOf.get, StateT.get,CDifferentiableValM,CDifferentiableM] + simp[getThe, MonadStateOf.get, StateT.get,DifferentiableValM,DifferentiableM] fun_prop @[simp, simp_core] -theorem _root_.getThe.arg.fwdCDerivValM_rule - : fwdCDerivValM K (m:=StateT S m) (getThe S) +theorem _root_.getThe.arg.fwdFDerivValM_rule + : fwdFDerivValM K (m:=StateT S m) (getThe S) = getThe (S×S) := by funext - simp[getThe, MonadStateOf.get, StateT.get,fwdCDerivValM, fwdCDerivM, pure, StateT.pure] + simp[getThe, MonadStateOf.get, StateT.get,fwdFDerivValM, fwdFDerivM, pure, StateT.pure] fun_trans --- MonadState.get -------------------------------------------------------------- --------------------------------------------------------------------------------- - - -@[simp,simp_core] -theorem _root_.MonadState.get.arg.CDifferentiableValM_rule - : CDifferentiableValM K (m:=StateT S m) (get) := -by - simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,CDifferentiableValM,CDifferentiableM] - fun_prop - -@[simp, simp_core] -theorem _root_.MonadState.get.arg.fwdCDerivValM_rule - : fwdCDerivValM K (m:=StateT S m) (get) - = - get := -by - funext - simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,fwdCDerivValM, fwdCDerivM] - fun_trans - --- -- setThe ---------------------------------------------------------------------- --- -------------------------------------------------------------------------------- - --- @[fun_prop] --- theorem _root_.setThe.arg_s.CDifferentiableM_rule --- (s : X → S) (ha0 : CDifferentiable K s) --- : CDifferentiableM K (m:=StateT S m) (fun x => setThe S (s x)) := --- by --- simp[setThe, set, StateT.set, CDifferentiableValM, 
CDifferentiableM] --- fun_prop - - --- @[fun_trans] --- theorem _root_.setThe.arg_s.fwdCDerivM_rule --- (s : X → S) (hs : CDifferentiable K s) --- : fwdCDerivM K (m:=StateT S m) (fun x => setThe S (s x)) --- = --- (fun x dx => do --- let sds := fwdCDeriv K s x dx --- setThe _ sds --- pure ((),())) := --- by --- simp[setThe, set, StateT.set,fwdCDerivM,bind,Bind.bind, StateT.bind] --- fun_trans; congr - - -- MonadStateOf.set ------------------------------------------------------------ -------------------------------------------------------------------------------- @[fun_prop] -theorem _root_.MonadStateOf.set.arg_a0.CDifferentiableM_rule - (s : X → S) (ha0 : CDifferentiable K s) - : CDifferentiableM K (m:=StateT S m) (fun x => set (s x)) := +theorem _root_.MonadStateOf.set.arg_a0.DifferentiableM_rule + (s : X → S) (ha0 : Differentiable K s) + : DifferentiableM K (m:=StateT S m) (fun x => set (s x)) := by - simp[set, StateT.set, CDifferentiableValM, CDifferentiableM] + simp[set, StateT.set, DifferentiableValM, DifferentiableM] fun_prop @[fun_trans] -theorem _root_.MonadStateOf.set.arg_a0.fwdCDerivM_rule - (s : X → S) (ha0 : CDifferentiable K s) - : fwdCDerivM K (m:=StateT S m) (fun x => set (s x)) +theorem _root_.MonadStateOf.set.arg_a0.fwdFDerivM_rule + (s : X → S) (ha0 : Differentiable K s) + : fwdFDerivM K (m:=StateT S m) (fun x => set (s x)) = (fun x dx => do - let sds := fwdCDeriv K s x dx + let sds := fwdFDeriv K s x dx set sds pure ((),())) := by funext - simp[set, StateT.set,fwdCDerivM, bind,Bind.bind, StateT.bind] + simp[set, StateT.set,fwdFDerivM, bind,Bind.bind, StateT.bind] fun_trans; congr @@ -169,27 +127,27 @@ by -------------------------------------------------------------------------------- @[fun_prop] -theorem _root_.modifyThe.arg_f.CDifferentiableM_rule - (f : X → S → S) (ha0 : CDifferentiable K (fun xs : X×S => f xs.1 xs.2)) - : CDifferentiableM K (m:=StateT S m) (fun x => modifyThe S (f x)) := +theorem _root_.modifyThe.arg_f.DifferentiableM_rule + 
(f : X → S → S) (ha0 : Differentiable K (fun xs : X×S => f xs.1 xs.2)) + : DifferentiableM K (m:=StateT S m) (fun x => modifyThe S (f x)) := by - simp[modifyThe, MonadStateOf.modifyGet, StateT.modifyGet, CDifferentiableValM, CDifferentiableM] + simp[modifyThe, MonadStateOf.modifyGet, StateT.modifyGet, DifferentiableValM, DifferentiableM] fun_prop @[fun_trans] -theorem _root_.modifyThe.arg_f.fwdCDerivM_rule - (f : X → S → S) (ha0 : CDifferentiable K (fun xs : X×S => f xs.1 xs.2)) - : fwdCDerivM K (m:=StateT S m) (fun x => modifyThe S (f x)) +theorem _root_.modifyThe.arg_f.fwdFDerivM_rule + (f : X → S → S) (ha0 : Differentiable K (fun xs : X×S => f xs.1 xs.2)) + : fwdFDerivM K (m:=StateT S m) (fun x => modifyThe S (f x)) = (fun x dx => do modifyThe (S×S) (fun sds => - let sds := fwdCDeriv K (fun xs : X×S => f xs.1 xs.2) (x,sds.1) (dx,sds.2) + let sds := fwdFDeriv K (fun xs : X×S => f xs.1 xs.2) (x,sds.1) (dx,sds.2) sds) pure ((),())) := by funext - simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,fwdCDerivM,bind,Bind.bind, StateT.bind] + simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,fwdFDerivM,bind,Bind.bind, StateT.bind] fun_trans; congr @@ -197,190 +155,119 @@ by -------------------------------------------------------------------------------- @[fun_prop] -theorem _root_.modify.arg_f.CDifferentiableM_rule - (f : X → S → S) (ha0 : CDifferentiable K (fun xs : X×S => f xs.1 xs.2)) - : CDifferentiableM K (m:=StateT S m) (fun x => modify (f x)) := +theorem _root_.modify.arg_f.DifferentiableM_rule + (f : X → S → S) (ha0 : Differentiable K (fun xs : X×S => f xs.1 xs.2)) + : DifferentiableM K (m:=StateT S m) (fun x => modify (f x)) := by - simp[modify, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet, CDifferentiableValM, CDifferentiableM] + simp[modify, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet, DifferentiableValM, DifferentiableM] fun_prop @[fun_trans] -theorem _root_.modify.arg_f.fwdCDerivM_rule - (f : X → S → S) (ha0 
: CDifferentiable K (fun xs : X×S => f xs.1 xs.2)) - : fwdCDerivM K (m:=StateT S m) (fun x => modify (f x)) +theorem _root_.modify.arg_f.fwdFDerivM_rule + (f : X → S → S) (ha0 : Differentiable K (fun xs : X×S => f xs.1 xs.2)) + : fwdFDerivM K (m:=StateT S m) (fun x => modify (f x)) = (fun x dx => do modify (fun sds => - let sds := fwdCDeriv K (fun xs : X×S => f xs.1 xs.2) (x,sds.1) (dx,sds.2) + let sds := fwdFDeriv K (fun xs : X×S => f xs.1 xs.2) (x,sds.1) (dx,sds.2) sds) pure ((),())) := by funext - simp[modify, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,fwdCDerivM,bind,Bind.bind, StateT.bind] + simp[modify, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,fwdFDerivM,bind,Bind.bind, StateT.bind] fun_trans; congr -end FwdCDerivMonad +end FwdFDerivMonad + -section RevCDerivMonad +section RevFDerivMonad variable {K : Type _} [RCLike K] - {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] [RevCDerivMonad K m m'] + {m : Type → Type} {m' : outParam $ Type → Type} [Monad m] [Monad m'] + [DifferentiableMonad K m] [RevFDerivMonad K m m'] [LawfulMonad m] [LawfulMonad m'] noncomputable -instance (S : Type _) [SemiInnerProductSpace K S] : RevCDerivMonad K (StateT S m) (StateT S m') where - revCDerivM f := fun x s => do - let ysdf ← revCDerivM K (fun (xs : _×S) => f xs.1 xs.2) (x,s) +instance (S : Type) [NormedAddCommGroup S] [AdjointSpace K S] [CompleteSpace S] : + RevFDerivMonad K (StateT S m) (StateT S m') where + revFDerivM f := fun x s => do + let ysdf ← revFDerivM K (fun (xs : _×S) => f xs.1 xs.2) (x,s) pure ((ysdf.1.1, fun dx ds => ysdf.2 (dx,ds)), ysdf.1.2) - HasAdjDiffM f := HasAdjDiffM K (fun (xs : _×S) => f xs.1 xs.2) - - revCDerivM_pure f h := + revFDerivM_pure f h := by funext - simp[pure, StateT.pure, revCDeriv] + simp[pure, StateT.pure, revFDeriv] fun_trans - simp [revCDeriv]; rfl + simp [revFDeriv]; rfl - revCDerivM_bind f g hf hg := + revFDerivM_bind f g hf hg := by funext x s - simp at hf; simp at hg - simp[revCDeriv, bind, 
StateT.bind, StateT.bind.match_1, StateT.pure, pure] + simp [DifferentiableM] at hf; simp [DifferentiableM] at hg + simp[revFDeriv, bind, StateT.bind, StateT.bind.match_1, StateT.pure, pure] fun_trans rfl - revCDerivM_pair f hf := + revFDerivM_pair f hf := by funext x s - simp at hf + simp [DifferentiableM] at hf simp[bind, StateT.bind, StateT.bind.match_1, pure, StateT.pure] fun_trans only simp congr; funext ysdf; congr; funext dx ds; congr; funext (dx,ds); simp; rfl - HasAdjDiffM_pure f hf := - by - simp [pure,StateT.pure] - fun_prop - - HasAdjDiffM_bind f g hf hg := - by - simp; simp at hf; simp at hg - simp[bind, StateT.bind, StateT.bind.match_1] - fun_prop - - HasAdjDiffM_pair f hf := - by - simp; simp at hf - simp[bind, StateT.bind, StateT.bind.match_1, pure, StateT.pure] - fun_prop variable - {S : Type _} [SemiInnerProductSpace K S] - {X : Type _} [SemiInnerProductSpace K X] - {Y : Type _} [SemiInnerProductSpace K Y] - {Z : Type _} [SemiInnerProductSpace K Z] + {S : Type} [NormedAddCommGroup S] [AdjointSpace K S] [CompleteSpace S] + {X : Type} [NormedAddCommGroup X] [AdjointSpace K X] [CompleteSpace X] + {Y : Type} [NormedAddCommGroup Y] [AdjointSpace K Y] [CompleteSpace Y] + {Z : Type} [NormedAddCommGroup Z] [AdjointSpace K Z] [CompleteSpace Z] -- getThe ---------------------------------------------------------------------- -------------------------------------------------------------------------------- - -@[simp, simp_core] -theorem _root_.getThe.arg.HasAdjDiffValM_rule - : HasAdjDiffValM K (m:=StateT S m) (getThe S) := -by - simp[getThe, MonadStateOf.get, StateT.get,HasAdjDiffValM,HasAdjDiffM] - fun_prop - - @[simp, simp_core] -theorem _root_.getThe.arg.revCDerivValM_rule - : revCDerivValM K (m:=StateT S m) (getThe S) +theorem _root_.getThe.arg.revFDerivValM_rule + : revFDerivValM K (m:=StateT S m) (getThe S) = (do pure ((← getThe S), fun ds => modifyThe S (fun ds' => ds + ds'))) := by funext - simp[getThe, MonadStateOf.get, StateT.get,revCDerivValM, 
revCDerivM, pure, StateT.pure, bind, StateT.bind, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet] + simp[getThe, MonadStateOf.get, StateT.get,revFDerivValM, revFDerivM, pure, StateT.pure, bind, StateT.bind, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet] fun_trans; rfl -- MonadState.get -------------------------------------------------------------- -------------------------------------------------------------------------------- - @[simp, simp_core] -theorem _root_.MonadState.get.arg.HasAdjDiffValM_rule - : HasAdjDiffValM K (m:=StateT S m) (get) := -by - simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,HasAdjDiffValM,HasAdjDiffM] - fun_prop - - -@[simp, simp_core] -theorem _root_.MonadState.get.arg.revCDerivValM_rule - : revCDerivValM K (m:=StateT S m) (get) +theorem _root_.MonadState.get.arg.revFDerivValM_rule + : revFDerivValM K (m:=StateT S m) (get) = (do pure ((← get), fun ds => modify (fun ds' => ds + ds'))) := by funext - simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,revCDerivValM, revCDerivM, pure, StateT.pure, bind, StateT.bind, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet, modifyGet] + simp[MonadState.get, getThe, MonadStateOf.get, StateT.get,revFDerivValM, revFDerivM, pure, StateT.pure, bind, StateT.bind, set, StateT.set, modifyThe, modify, MonadStateOf.modifyGet, StateT.modifyGet, modifyGet] fun_trans; rfl --- -- setThe ---------------------------------------------------------------------- --- -------------------------------------------------------------------------------- - --- @[fun_prop] --- theorem _root_.setThe.arg_s.HasAdjDiffM_rule --- (s : X → S) (ha0 : HasAdjDiff K s) --- : HasAdjDiffM K (m:=StateT S m) (fun x => setThe S (s x)) := --- by --- simp[setThe, set, StateT.set, HasAdjDiffValM, HasAdjDiffM] --- fun_prop - - --- @[fun_trans] --- theorem _root_.setThe.arg_s.revCDerivM_rule --- (s : X → S) (hs : HasAdjDiff K s) --- : revCDerivM 
K (m:=StateT S m) (fun x => setThe S (s x)) --- = --- (fun x => do --- let sds := revCDeriv K s x --- pure (← setThe S sds.1, --- fun _ => do --- let dx := sds.2 (← getThe S) --- setThe S 0 --- pure dx)) := --- by --- simp[setThe, set, StateT.set, revCDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure] --- fun_trans - - -- MonadStateOf.set ------------------------------------------------------------ -------------------------------------------------------------------------------- -@[fun_prop] -theorem _root_.MonadStateOf.set.arg_a0.HasAdjDiffM_rule - (s : X → S) (ha0 : HasAdjDiff K s) - : HasAdjDiffM K (m:=StateT S m) (fun x => set (s x)) := -by - simp[set, StateT.set, HasAdjDiffValM, HasAdjDiffM] - fun_prop - - @[fun_trans] -theorem _root_.MonadStateOf.set.arg_a0.revCDerivM_rule - (s : X → S) (ha0 : HasAdjDiff K s) - : revCDerivM K (m:=StateT S m) (fun x => set (s x)) +theorem _root_.MonadStateOf.set.arg_a0.revFDerivM_rule + (s : X → S) (ha0 : Differentiable K s) + : revFDerivM K (m:=StateT S m) (fun x => set (s x)) = (fun x => do - let sds := revCDeriv K s x + let sds := revFDeriv K s x pure (← set sds.1, fun _ => do let dx := sds.2 (← get) @@ -388,7 +275,7 @@ theorem _root_.MonadStateOf.set.arg_a0.revCDerivM_rule pure dx)) := by funext - simp[set, StateT.set, revCDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure, get] + simp[set, StateT.set, revFDerivM, getThe, MonadStateOf.get, StateT.get, bind, StateT.bind, pure, StateT.pure, get] fun_trans; congr; funext; simp[StateT.get, StateT.bind,StateT.set,StateT.pure] -- -- modifyThe ---------------------------------------------------------------------- @@ -404,12 +291,12 @@ by -- @[fun_trans] --- theorem _root_.modifyThe.arg_f.revCDerivM_rule +-- theorem _root_.modifyThe.arg_f.revFDerivM_rule -- (f : X → S → S) (ha0 : HasAdjDiff K (fun xs : X×S => f xs.1 xs.2)) --- : revCDerivM K (m:=StateT S m) (fun x => modifyThe S (f x)) +-- : revFDerivM K (m:=StateT S 
m) (fun x => modifyThe S (f x)) -- = -- (fun x => do --- let sdf := revCDeriv K (fun xs : X×S => f xs.1 xs.2) (x, ← getThe S) +-- let sdf := revFDeriv K (fun xs : X×S => f xs.1 xs.2) (x, ← getThe S) -- setThe S sdf.1 -- pure ((), -- fun _ => do @@ -417,29 +304,20 @@ by -- setThe S dxs.2 -- pure dxs.1)) := -- by --- simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,revCDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, setThe, set, StateT.set] +-- simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,revFDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, setThe, set, StateT.set] -- fun_trans; congr -- modify ---------------------------------------------------------------------- -------------------------------------------------------------------------------- -@[fun_prop] -theorem _root_.modify.arg_f.HasAdjDiffM_rule - (f : X → S → S) (ha0 : HasAdjDiff K (fun xs : X×S => f xs.1 xs.2)) - : HasAdjDiffM K (m:=StateT S m) (fun x => modify (f x)) := -by - simp[modify, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet, HasAdjDiffValM, HasAdjDiffM] - fun_prop - - @[fun_trans] -theorem _root_.modify.arg_f.revCDerivM_rule - (f : X → S → S) (ha0 : HasAdjDiff K (fun xs : X×S => f xs.1 xs.2)) - : revCDerivM K (m:=StateT S m) (fun x => modify (f x)) +theorem _root_.modify.arg_f.revFDerivM_rule + (f : X → S → S) (ha0 : Differentiable K (fun xs : X×S => f xs.1 xs.2)) + : revFDerivM K (m:=StateT S m) (fun x => modify (f x)) = (fun x => do - let sdf := revCDeriv K (fun xs : X×S => f xs.1 xs.2) (x, ← get) + let sdf := revFDeriv K (fun xs : X×S => f xs.1 xs.2) (x, ← get) set sdf.1 pure ((), fun _ => do @@ -448,7 +326,7 @@ theorem _root_.modify.arg_f.revCDerivM_rule pure dxs.1)) := by funext - simp[modifyThe, modifyGet, MonadStateOf.modifyGet, StateT.modifyGet,revCDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, set, StateT.set, get, pure, StateT.pure, modify] + simp[modifyThe, modifyGet, 
MonadStateOf.modifyGet, StateT.modifyGet,revFDerivM, bind, StateT.bind, getThe, MonadStateOf.get, StateT.get, set, StateT.set, get, pure, StateT.pure, modify] fun_trans; congr; funext; simp[StateT.bind,StateT.pure,StateT.get,StateT.set] -end RevCDerivMonad +end RevFDerivMonad