Skip to content

Commit

Permalink
clean up ForIn file
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 20, 2023
1 parent c2f27cd commit dde2201
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 281 deletions.
279 changes: 3 additions & 276 deletions SciLean/Core/Monads/ForIn.lean
Original file line number Diff line number Diff line change
@@ -1,122 +1,14 @@
import SciLean.Core.Monads.FwdDerivMonad
import SciLean.Core.Monads.Id
import SciLean.Core.Monads.MProd
import SciLean.Core.Monads.ForInStep

import SciLean.Data.DataArray

set_option linter.unusedVariables false

namespace SciLean

variable
{K : Type _} [IsROrC K]

-- This is not true but lets assume it for now
instance [Vec K X] : Vec K (ForInStep X) := sorry

-- This is not true but lets assume it for now
instance [SemiInnerProductSpace K X] : SemiInnerProductSpace K (ForInStep X) := sorry



-- TODO: transport vec structure from Prod
instance [Vec K X] [Vec K Y] : Vec K (MProd X Y) := sorry
instance [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] : SemiInnerProductSpace K (MProd X Y) := sorry


@[fprop]
theorem _root_.MProd.mk.arg_fstsnd.HasAdjDiff_rule
[SemiInnerProductSpace K W] [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y]
(f : W → X) (g : W → Y) (hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: HasAdjDiff K (fun w => MProd.mk (f w) (g w)) := by sorry_proof


@[ftrans]
theorem _root_.MProd.mk.arg_fstsnd.revCDeriv_rule
[SemiInnerProductSpace K W] [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y]
(f : W → X) (g : W → Y) (hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: revCDeriv K (fun w => MProd.mk (f w) (g w))
=
fun w =>
let xdf' := revCDeriv K f w
let ydg' := revCDeriv K g w
(MProd.mk xdf'.1 ydg'.1,
fun dxy =>
xdf'.2 dxy.1 + ydg'.2 dxy.2) :=
by
sorry_proof


@[fprop]
theorem _root_.MProd.fst.arg_self.HasAdjDiff_rule
[SemiInnerProductSpace K W] [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y]
(f : W → MProd X Y) (hf : HasAdjDiff K f)
: HasAdjDiff K (fun w => (f w).1) := by sorry_proof

@[ftrans]
theorem _root_.MProd.fst.arg_self.revCDeriv_rule
[SemiInnerProductSpace K W] [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y]
(f : W → MProd X Y) (hf : HasAdjDiff K f)
: revCDeriv K (fun w => (f w).1)
=
fun w =>
let xydxy := revCDeriv K f w
(xydxy.1.1, fun dw => xydxy.2 (MProd.mk dw 0)) := by sorry_proof

@[fprop]
theorem _root_.MProd.snd.arg_self.HasAdjDiff_rule
[SemiInnerProductSpace K W] [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y]
(f : W → MProd X Y) (hf : HasAdjDiff K f)
: HasAdjDiff K (fun w => (f w).2) := by sorry_proof

@[ftrans]
theorem _root_.MProd.snd.arg_self.revCDeriv_rule
[SemiInnerProductSpace K W] [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y]
(f : W → MProd X Y) (hf : HasAdjDiff K f)
: revCDeriv K (fun w => (f w).2)
=
fun w =>
let xydxy := revCDeriv K f w
(xydxy.1.2, fun dw => xydxy.2 (MProd.mk 0 dw)) := by sorry_proof


end SciLean
open SciLean

def ForInStep.val : ForInStep α → α
| .yield a => a
| .done a => a


@[simp, ftrans_simp]
theorem ForInStep.val_yield (a : α) : ForInStep.val (.yield a) = a := by rfl

@[simp, ftrans_simp]
theorem ForInStep.val_done (a : α) : ForInStep.val (.done a) = a := by rfl


/-- Turns a pair of values each with yield/done annotation into a pair with
a single yield/done annotation based on the first element. -/
def ForInStep.return2 (x : ForInStep α × ForInStep β) : ForInStep (α × β) :=
match x.1, x.2 with
| .yield x₁, .yield x₂ => .yield (x₁, x₂)
| .yield x₁, .done x₂ => .yield (x₁, x₂)
| .done x₁, .yield x₂ => .done (x₁, x₂)
| .done x₁, .done x₂ => .done (x₁, x₂)

def ForInStep.return2Inv (x : ForInStep (α × β)) : ForInStep α × ForInStep β :=
match x with
| .yield (x,y) => (.yield x, .yield y)
| .done (x,y) => (.done x, .done y)


@[simp]
theorem ForInStep.return2_return2Inv_yield {α β} (x : α × β)
: ForInStep.return2 (ForInStep.return2Inv (.yield x)) = .yield x := by rfl

@[simp]
theorem ForInStep.return2_return2Inv_done {α β} (x : α × β)
: ForInStep.return2 (ForInStep.return2Inv (.done x)) = .done x := by rfl


section OnVec

variable
Expand Down Expand Up @@ -173,65 +65,6 @@ by
ftrans

--------------------------------------------------------------------------------
-- ForInStep.yield -------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem ForInStep.yield.arg_a0.IsDifferentiable_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: IsDifferentiable K fun x => ForInStep.yield (a0 x) := by sorry_proof

-- this is a hack as with Id monad sometimes you do not need `pure` which trips `fprop` up
@[fprop]
theorem ForInStep.yield.arg_a0.IsDifferentiableM_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: IsDifferentiableM (m:=Id) K fun x => ForInStep.yield (a0 x) := by sorry_proof

@[ftrans]
theorem ForInStep.yield.arg_a0.cderiv_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: cderiv K (fun x => ForInStep.yield (a0 x))
=
fun x dx => ForInStep.yield (cderiv K a0 x dx) := by sorry_proof

@[ftrans]
theorem ForInStep.yield.arg_a0.fwdCDeriv_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: fwdCDeriv K (fun x => ForInStep.yield (a0 x))
=
fun x dx => ForInStep.return2Inv (ForInStep.yield (fwdCDeriv K a0 x dx))
:= by sorry_proof

@[fprop]
theorem ForInStep.done.arg_a0.IsDifferentiable_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: IsDifferentiable K fun x => ForInStep.done (a0 x) := by sorry_proof


--------------------------------------------------------------------------------
-- ForInStep.done -------------------------------------------------------------
--------------------------------------------------------------------------------

-- this is a hack as with Id monad sometimes you do not need `pure` which trips `fprop` up
@[fprop]
theorem ForInStep.done.arg_a0.IsDifferentiableM_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: IsDifferentiableM (m:=Id) K fun x => ForInStep.done (a0 x) := by sorry_proof

@[ftrans]
theorem ForInStep.done.arg_a0.cderiv_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: cderiv K (fun x => ForInStep.done (a0 x))
=
fun x dx => ForInStep.done (cderiv K a0 x dx) := by sorry_proof

@[ftrans]
theorem ForInStep.done.arg_a0.fwdCDeriv_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: fwdCDeriv K (fun x => ForInStep.done (a0 x))
=
fun x dx => ForInStep.return2Inv (ForInStep.done (fwdCDeriv K a0 x dx))
:= by sorry_proof


end OnVec
Expand Down Expand Up @@ -435,112 +268,6 @@ by
sorry_proof



--------------------------------------------------------------------------------
-- ForInStep.yield -------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem ForInStep.yield.arg_a0.HasAdjDiff_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: HasAdjDiff K fun x => ForInStep.yield (a0 x) := by sorry_proof

@[fprop]
theorem ForInStep.yield.arg_a0.HasAdjDiffM_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: HasAdjDiffM (m:=Id) K fun x => ForInStep.yield (a0 x) := by sorry_proof

@[ftrans]
theorem ForInStep.yield.arg_a0.revCDeriv_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: revCDeriv K (fun x => ForInStep.yield (a0 x))
=
fun x =>
let ydf := revCDeriv K a0 x
(.yield ydf.1, fun y => ydf.2 y.val)
:= by sorry_proof

@[ftrans]
theorem ForInStep.yield.arg_a0.revDerivM_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: revDerivM (m:=Id) K (fun x => ForInStep.yield (a0 x))
=
fun x =>
let ydf := revCDeriv K a0 x
(.yield ydf.1, fun y => ydf.2 y.val)
:= by sorry_proof


--------------------------------------------------------------------------------
-- ForInStep.done --------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem ForInStep.done.arg_a0.HasAdjDiff_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: HasAdjDiff K fun x => ForInStep.done (a0 x) := by sorry_proof

@[fprop]
theorem ForInStep.done.arg_a0.HasAdjDiffM_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: HasAdjDiffM (m:=Id) K fun x => ForInStep.done (a0 x) := by sorry_proof

@[ftrans]
theorem ForInStep.done.arg_a0.revCDeriv_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: revCDeriv K (fun x => ForInStep.done (a0 x))
=
fun x =>
let ydf := revCDeriv K a0 x
(.done ydf.1, fun y => ydf.2 y.val)
:= by sorry_proof

@[ftrans]
theorem ForInStep.done.arg_a0.revDerivM_rule
(a0 : X → Y) (ha0 : HasAdjDiff K a0)
: revDerivM (m:=Id) K (fun x => ForInStep.done (a0 x))
=
fun x =>
let ydf := revCDeriv K a0 x
(.done ydf.1, fun y => ydf.2 y.val)
:= by sorry_proof


--------------------------------------------------------------------------------
-- ForInStep.val --------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem ForInStep.val.arg_a0.HasAdjDiff_rule
(a0 : X → ForInStep Y) (ha0 : HasAdjDiff K a0)
: HasAdjDiff K fun x => ForInStep.val (a0 x) := by sorry_proof

@[fprop]
theorem ForInStep.val.arg_a0.HasAdjDiffM_rule
(a0 : X → ForInStep Y) (ha0 : HasAdjDiff K a0)
: HasAdjDiffM (m:=Id) K fun x => ForInStep.val (a0 x) := by sorry_proof

@[ftrans]
theorem ForInStep.val.arg_a0.revCDeriv_rule
(a0 : X → ForInStep Y) (ha0 : HasAdjDiff K a0)
: revCDeriv K (fun x => ForInStep.val (a0 x))
=
fun x =>
let ydf := revCDeriv K a0 x
(ydf.1.val, fun y => ydf.2 (.yield y))
:= by sorry_proof

@[ftrans]
theorem ForInStep.val.arg_a0.revDerivM_rule
(a0 : X → ForInStep Y) (ha0 : HasAdjDiff K a0)
: revDerivM (m:=Id) K (fun x => ForInStep.val (a0 x))
=
fun x =>
let ydf := revCDeriv K a0 x
(ydf.1.val, fun y => ydf.2 (.yield y))
:= by sorry_proof


end OnSemiInnerProductSpace


Expand Down
35 changes: 30 additions & 5 deletions SciLean/Core/Monads/MProd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ theorem MProd.mk.arg_fstsnd.IsDifferentiable_rule
(f : W → X) (g : W → Y) (hf : IsDifferentiable K f) (hg : IsDifferentiable K g)
: IsDifferentiable K (fun w => MProd.mk (f w) (g w)) := by sorry_proof

@[ftrans]
theorem MProd.mk.arg_fstsnd.cderiv_rule
(f : W → X) (g : W → Y) (hf : IsDifferentiable K f) (hg : IsDifferentiable K g)
: cderiv K (fun w => MProd.mk (f w) (g w))
=
fun w dw =>
let dx := cderiv K f w dw
let dy := cderiv K g w dw
⟨dx,dy⟩ :=
by
sorry_proof

@[ftrans]
theorem MProd.mk.arg_fstsnd.fwdCDeriv_rule
Expand All @@ -37,8 +48,7 @@ theorem MProd.mk.arg_fstsnd.fwdCDeriv_rule
let ydy := fwdCDeriv K g w dw
(⟨xdx.1,ydy.1⟩, ⟨xdx.2,ydy.2⟩) :=
by
sorry_proof

unfold fwdCDeriv; ftrans

@[fprop]
theorem MProd.fst.arg_self.IsDifferentiable_rule
Expand Down Expand Up @@ -80,12 +90,28 @@ variable
{W : Type _} [SemiInnerProductSpace K W]

-- TODO: transport structure from Prod
instance [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] : SemiInnerProductSpace K (MProd X Y) := sorry
instance [Inner K X] [Inner K Y] : Inner K (MProd X Y) :=
fun ⟨x,y⟩ ⟨x',y'⟩ => Inner.inner x x' + Inner.inner y y'⟩

instance [TestFunctions X] [TestFunctions Y] : TestFunctions (MProd X Y) where
TestFunction := fun ⟨x,y⟩ => TestFunction x ∧ TestFunction y

instance [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] : SemiInnerProductSpace K (MProd X Y) := SemiInnerProductSpace.mkSorryProofs

@[fprop]
theorem MProd.mk.arg_fstsnd.HasSemiAdjoint_rule
(f : W → X) (g : W → Y) (hf : HasSemiAdjoint K f) (hg : HasSemiAdjoint K g)
: HasSemiAdjoint K (fun w => MProd.mk (f w) (g w)) := by sorry_proof


@[fprop]
theorem MProd.mk.arg_fstsnd.HasAdjDiff_rule
(f : W → X) (g : W → Y) (hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
: HasAdjDiff K (fun w => MProd.mk (f w) (g w)) := by sorry_proof
: HasAdjDiff K (fun w => MProd.mk (f w) (g w)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
constructor; fprop; ftrans; fprop


@[ftrans]
Expand All @@ -102,7 +128,6 @@ theorem MProd.mk.arg_fstsnd.revCDeriv_rule
by
sorry_proof


@[fprop]
theorem MProd.fst.arg_self.HasAdjDiff_rule
(f : W → MProd X Y) (hf : HasAdjDiff K f)
Expand Down

0 comments on commit dde2201

Please sign in to comment.