Skip to content

Commit

Permalink
new ftrans revDerivProj and revDerivProjUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 20, 2023
1 parent cd47051 commit b7892a0
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 0 deletions.
130 changes: 130 additions & 0 deletions SciLean/Core/FunctionTransformations/RevDerivProj.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import SciLean.Core.FunctionTransformations.RevCDeriv
import SciLean.Core.FunctionTransformations.RevDerivUpdate
import SciLean.Data.TypeWithProj

set_option linter.unusedVariables false

namespace SciLean

variable
(K : Type _) [IsROrC K]
{X : Type _} [SemiInnerProductSpace K X]
{Y : Type _} [SemiInnerProductSpace K Y]
{Z : Type _} [SemiInnerProductSpace K Z]
{W : Type _} [SemiInnerProductSpace K W]
{ι : Type _} [EnumType ι]

{E : Type _} [SemiInnerProductSpace K E]
{EIdx : Type _}
{EVal : EIdx → Type _} [∀ i, SemiInnerProductSpace K (EVal i)]
[TypeWithProj E EIdx EVal]
{F : Type _} [SemiInnerProductSpace K F]
{FIdx : Type _}
{FVal : FIdx → Type _} [∀ i, SemiInnerProductSpace K (FVal i)]
[TypeWithProj F FIdx FVal]

instance (i : EIdx ⊕ FIdx) : Vec K (Prod.TypeFun EVal FVal i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

instance (i : EIdx ⊕ FIdx) : SemiInnerProductSpace K (Prod.TypeFun EVal FVal i) :=
match i with
| .inl _ => by infer_instance
| .inr _ => by infer_instance

noncomputable
def revDerivProj
(f : X → E) (x : X) : E×((i : EIdx)→EVal i→X) :=
(f x, fun i de =>
have := Classical.propDecidable
(revCDeriv K f x).2 (TypeWithProj.intro fun i' => if h:i=i' then h▸de else 0))

noncomputable
def revDerivProjUpdate
(f : X → E) (x : X) : E×((i : EIdx)→EVal i→K→X→X) :=
(f x, fun i de k dx =>
have := Classical.propDecidable
(revDerivUpdate K f x).2 (TypeWithProj.intro fun i' => if h:i=i' then h▸de else 0) k x)


--------------------------------------------------------------------------------


theorem revDerivProj.id_rule
: revDerivProj K (fun x : E => x)
=
fun x =>
(x,
fun i de =>
have := Classical.propDecidable
TypeWithProj.intro fun i' => if h : i=i' then h▸de else 0):=
by
simp[revDerivProj]
ftrans

theorem revDerivProjUpdate.id_rule
: revDerivProjUpdate K (fun x : E => x)
=
fun x =>
(x,
fun i de k dx =>
TypeWithProj.modify i (fun ei => ei + k•de) dx) :=
by
simp[revDerivProjUpdate]
ftrans
sorry_proof


theorem revDerivProj.comp_rule
(f : Y → E) (g : X → Y)
: revDerivProj K (fun x => f (g x))
=
fun x =>
let ydg' := revCDeriv K g x
let zdf' := revDerivProj K f ydg'.1
(zdf'.1,
fun i de =>
ydg'.2 (zdf'.2 i de)) :=
by
sorry_proof

theorem revDerivProjUpdate.comp_rule
(f : Y → E) (g : X → Y)
: revDerivProjUpdate K (fun x => f (g x))
=
fun x =>
let ydg' := revDerivUpdate K g x
let zdf' := revDerivProj K f ydg'.1
(zdf'.1,
fun i de k dx =>
ydg'.2 (zdf'.2 i de) k dx) :=
by
sorry_proof



theorem Prod.fst.arg_self.revDeriv_rule
(f : W → X×Y) (hf : HasAdjDiff K f)
: revCDeriv K (fun w => (f w).1)
=
fun w =>
let xydf' := revDerivProj K f w
(xydf'.1.1, fun dx => xydf'.2 (.inl ()) dx) :=
by
sorry_proof


theorem Prod.fst.arg_self.revDerivProj_rule
{XIdx : Type _}
{XVal : XIdx → Type _} [∀ i, SemiInnerProductSpace K (XVal i)]
[TypeWithProj X XIdx XVal]
(f : W → X×Y) (hf : HasAdjDiff K f)
: revDerivProj K (fun w => (f w).1)
=
fun w =>
let xydf' := revDerivProj K f w
(xydf'.1.1,
fun i dx => xydf'.2 (.inl i) dx) :=
by
sorry_proof
47 changes: 47 additions & 0 deletions SciLean/Data/TypeWithProj.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import Mathlib.Init.Function
import SciLean.Util.SorryProof

namespace SciLean

open Function

class TypeWithProj (F : Sort _) (I : outParam (Sort _)) (E : outParam <| I → Sort _) where
proj : F → (i : I) → (E i)
intro : ((i : I) → (E i)) → F
modify : (i : I) → (E i → E i) → (F → F)
left_inv : LeftInverse proj intro
right_inv : RightInverse proj intro
-- TODO: theorem about modify



--------------------------------------------------------------------------------
-- Prod ------------------------------------------------------------------------
--------------------------------------------------------------------------------

abbrev _root_.Prod.TypeFun {αIdx βIdx : Type _} (αType : αIdx → Type _) (βType : βIdx → Type _) (i : Sum αIdx βIdx) : Type _ :=
match i with
| .inl a => αType a
| .inr b => βType b

instance (priority:=low) : TypeWithProj α Unit (fun _ => α) where
proj := fun x _ => x
intro := fun f => f ()
modify := fun _ f x => f x
left_inv := sorry_proof
right_inv := sorry_proof

instance [TypeWithProj α αIdx αType] [TypeWithProj β βIdx βType]
: TypeWithProj (α×β) (Sum αIdx βIdx) (Prod.TypeFun αType βType) where
proj := fun (x,y) i =>
match i with
| .inl a => TypeWithProj.proj x a
| .inr b => TypeWithProj.proj y b
intro := fun f => (TypeWithProj.intro (fun a => f (.inl a)),
TypeWithProj.intro (fun b => f (.inr b)))
modify := fun i f (x,y) =>
match i with
| .inl a => (TypeWithProj.modify a f x, y)
| .inr b => (x, TypeWithProj.modify b f y)
left_inv := sorry_proof
right_inv := sorry_proof

0 comments on commit b7892a0

Please sign in to comment.