diff --git a/SciLean/Core/FunctionTransformations/RevDerivProj.lean b/SciLean/Core/FunctionTransformations/RevDerivProj.lean
index eb3797df..b287249e 100644
--- a/SciLean/Core/FunctionTransformations/RevDerivProj.lean
+++ b/SciLean/Core/FunctionTransformations/RevDerivProj.lean
@@ -1,6 +1,6 @@
 import SciLean.Core.FunctionTransformations.RevCDeriv
 import SciLean.Core.FunctionTransformations.RevDerivUpdate
-import SciLean.Data.TypeWithProj
+import SciLean.Data.StructLike
 
 set_option linter.unusedVariables false
 
@@ -13,44 +13,33 @@ variable
   {Z : Type _} [SemiInnerProductSpace K Z]
   {W : Type _} [SemiInnerProductSpace K W]
   {ι : Type _} [EnumType ι]
-
-  {E : Type _} [SemiInnerProductSpace K E]
-  {EIdx : Type _}
-  {EVal : EIdx → Type _} [∀ i, SemiInnerProductSpace K (EVal i)]
-  [TypeWithProj E EIdx EVal]
-  {F : Type _} [SemiInnerProductSpace K F]
-  {FIdx : Type _}
-  {FVal : FIdx → Type _} [∀ i, SemiInnerProductSpace K (FVal i)]
-  [TypeWithProj F FIdx FVal]
-
-instance (i : EIdx ⊕ FIdx) : Vec K (Prod.TypeFun EVal FVal i) :=
-  match i with
-  | .inl _ => by infer_instance
-  | .inr _ => by infer_instance
-
-instance (i : EIdx ⊕ FIdx) : SemiInnerProductSpace K (Prod.TypeFun EVal FVal i) :=
-  match i with
-  | .inl _ => by infer_instance
-  | .inr _ => by infer_instance
+  {E I : Type _} {EI : I → Type _}
+  [StructLike E I EI]
+  [SemiInnerProductSpace K E] [∀ i, SemiInnerProductSpace K (EI i)]
+  {F J : Type _} {FJ : J → Type _}
+  [StructLike F J FJ]
+  [SemiInnerProductSpace K F] [∀ j, SemiInnerProductSpace K (FJ j)]
 
 noncomputable def revDerivProj
-  (f : X → E) (x : X) : E×((i : EIdx)→EVal i→X) :=
-  (f x, fun i de =>
+  (f : X → E) (x : X) : E×((i : I)→EI i→X) :=
+  let ydf' := revCDeriv K f x  -- value and reverse map of `f` at `x`, shared by both components
+  (ydf'.1, fun i de =>
     have := Classical.propDecidable
-    (revCDeriv K f x).2 (TypeWithProj.intro fun i' => if h:i=i' then h▸de else 0))
+    ydf'.2 (StructLike.make fun i' => if h:i=i' then h▸de else 0))  -- feed `de` in component `i`, `0` elsewhere
 
 noncomputable def revDerivProjUpdate
-  (f : X → E) (x : X) : E×((i : EIdx)→EVal i→K→X→X) :=
-  (f x, fun i de k dx =>
-    have := Classical.propDecidable
-    (revDerivUpdate K f x).2 (TypeWithProj.intro fun i' => if h:i=i' then h▸de else 0) k x)
+  (f : X → E) (x : X) : E×((i : I)→EI i→X→X) :=
+  let ydf' := revDerivUpdate K f x  -- same construction as `revDerivProj`, on top of `revDerivUpdate`
+  (ydf'.1, fun i de dx =>
+    ydf'.2 (have := Classical.propDecidable; StructLike.make fun i' => if h:i=i' then h▸de else 0) 1 dx)  -- scalar argument fixed to `1`
 
 
 --------------------------------------------------------------------------------
 
+variable (E)
 theorem revDerivProj.id_rule
   : revDerivProj K (fun x : E => x)
     =
@@ -58,26 +47,81 @@ theorem revDerivProj.id_rule
       (x,
        fun i de =>
          have := Classical.propDecidable
-         TypeWithProj.intro fun i' => if h : i=i' then h▸de else 0):=
+         StructLike.make fun i' => if h : i=i' then h▸de else 0):=
 by
   simp[revDerivProj]
   ftrans
 
+
 theorem revDerivProjUpdate.id_rule
   : revDerivProjUpdate K (fun x : E => x)
     =
     fun x =>
       (x,
-       fun i de k dx =>
-         TypeWithProj.modify i (fun ei => ei + k•de) dx) :=
+       fun i de dx =>
+         StructLike.modify i (fun ei => ei + de) dx) :=
 by
   simp[revDerivProjUpdate]
   ftrans
   sorry_proof
+variable {E}
+
+variable (Y)
+theorem revDerivProj.const_rule (x : E)
+  : revDerivProj K (fun _ : Y => x)
+    =
+    fun _ =>
+      (x,
+       fun i (de : EI i) => 0) :=
+by
+  simp[revDerivProj]
+  ftrans
+/-- Constant function: the update map returns `dx` unchanged. -/
+theorem revDerivProjUpdate.const_rule (x : E)
+  : revDerivProjUpdate K (fun _ : Y => x)
+    =
+    fun _ =>
+      (x,
+       fun i de dx => dx) :=
+by
+  simp[revDerivProjUpdate];
+  ftrans
+variable {Y}
+
+/-- Rule for evaluation `fun f => f i` at a fixed index `i : ι`. -/
+theorem revDerivProj.proj_rule [DecidableEq I] (i : ι)
+  : revDerivProj K (fun (f : ι → E) => f i)
+    =
+    fun f =>
+      (f i, fun j dxj i' =>
+        if i=i' then
+          StructLike.make fun j' =>
+            if h : j=j' then
+              (h▸dxj)
+            else
+              0
+        else
+          0) :=
+by
+  sorry_proof
+
+/-- Update form of the evaluation rule: only component `j` of entry `i` of `df` is modified. -/
+theorem revDerivProjUpdate.proj_rule [DecidableEq I] (i : ι)
+  : revDerivProjUpdate K (fun (f : ι → E) => f i)
+    =
+    fun f =>
+      (f i, fun j dxj df i' =>
+        if i=i' then
+          StructLike.modify j (fun xj => xj + dxj) (df i')
+        else
+          df i') :=
+by
+  sorry_proof
 
 
 theorem revDerivProj.comp_rule
   (f : Y → E) (g : X → Y)
+  (hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
   : revDerivProj K (fun x => f (g x))
     =
     fun x =>
@@ -89,21 +133,380 @@ theorem revDerivProj.comp_rule
 by
   sorry_proof
 
+
 theorem revDerivProjUpdate.comp_rule
   (f : Y → E) (g : X → Y)
+  (hf : HasAdjDiff K f) (hg : HasAdjDiff K g)
   : revDerivProjUpdate K (fun x => f (g x))
     =
     fun x =>
       let ydg' := revDerivUpdate K g x
       let zdf' := revDerivProj K f ydg'.1
       (zdf'.1,
-       fun i de k dx =>
-         ydg'.2 (zdf'.2 i de) k dx) :=
+       fun i de dx =>
+         ydg'.2 (zdf'.2 i de) 1 dx) :=
+by
+  sorry_proof
+
+/-- Rule for `let y := g x; f x y`, stated through the uncurried `fun (x,y) => f x y`. -/
+theorem revDerivProj.let_rule
+  (f : X → Y → E) (g : X → Y)
+  (hf : HasAdjDiff K (fun (x,y) => f x y)) (hg : HasAdjDiff K g)
+  : revDerivProj K (fun x => let y := g x; f x y)
+    =
+    fun x =>
+      let ydg' := revDerivUpdate K g x
+      let zdf' := revDerivProj K (fun (x,y) => f x y) (x,ydg'.1)
+      (zdf'.1,
+       fun i dei =>
+         let dxy := zdf'.2 i dei
+         ydg'.2 dxy.2 1 dxy.1) :=
+by
+  sorry_proof
+
+/-- Update form of the `let` rule; the inner update is started from `(dx,0)`. -/
+theorem revDerivProjUpdate.let_rule
+  (f : X → Y → E) (g : X → Y)
+  (hf : HasAdjDiff K (fun (x,y) => f x y)) (hg : HasAdjDiff K g)
+  : revDerivProjUpdate K (fun x => let y := g x; f x y)
+    =
+    fun x =>
+      let ydg' := revDerivUpdate K g x
+      let zdf' := revDerivProjUpdate K (fun (x,y) => f x y) (x,ydg'.1)
+      (zdf'.1,
+       fun i dei dx =>
+         let dxy := zdf'.2 i dei (dx,0)
+         ydg'.2 dxy.2 1 dxy.1) :=
+by
+  sorry_proof
+
+/-- Eta-expanding the second argument does not change `revDerivProj` (holds by `rfl`). -/
+theorem revDerivProj.pi_rule
+  (f : X → ι → E) (hf : ∀ i, HasAdjDiff K (f · i))
+  : (revDerivProj K fun (x : X) (i : ι) => f x i)
+    =
+    revDerivProj K fun x => f x := by rfl
+
+/-- Eta-expanding the second argument does not change `revDerivProjUpdate` (holds by `rfl`). -/
+theorem revDerivProjUpdate.pi_rule
+  (f : X → ι → E) (hf : ∀ i, HasAdjDiff K (f · i))
+  : (revDerivProjUpdate K fun (x : X) (i : ι) => f x i)
+    =
+    revDerivProjUpdate K fun x => f x := by rfl
+
+
+
+--------------------------------------------------------------------------------
+
+-- Register `revDerivProj` as function transformation --------------------------
+--------------------------------------------------------------------------------
+
+namespace revDerivProj
+-- Discharge side goal `e` with `fprop`, threading the simp cache; fall back to `assumption`.
+open Lean Meta Qq in
+def discharger (e : Expr) : SimpM (Option Expr) := do
+  withTraceNode `revDerivProj_discharger (fun _ => return s!"discharge {← ppExpr e}") do
+  let cache := (← get).cache
+  let config : FProp.Config := {}
+  let state : FProp.State := { cache := cache }
+  let (proof?, state) ← FProp.fprop e |>.run config |>.run state
+  modify (fun simpState => { simpState with cache := state.cache })
+  if proof?.isSome then
+    return proof?
+  else
+    -- if `fprop` fails try assumption
+    let tac := FTrans.tacticToDischarge (Syntax.mkLit ``Lean.Parser.Tactic.assumption "assumption")
+    let proof? ← tac e
+    return proof?
+
+-- `ftrans` extension for `revDerivProj`: locates the transformed function and maps each rule kind to its theorem.
+open Lean Meta FTrans in
+def ftransExt : FTransExt where
+  ftransName := ``revDerivProj
+
+  getFTransFun? e :=
+    if e.isAppOf ``revDerivProj then
+      -- the transformed function is read from argument 10 of the application
+      if let .some f := e.getArg? 10 then
+        some f
+      else
+        none
+    else
+      none
+
+  replaceFTransFun e f :=
+    if e.isAppOf ``revDerivProj then
+      e.setArg 10 f  -- must be the same argument index that `getFTransFun?` reads
+    else
+      e
+
+  idRule e X := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``id_rule #[K, X], origin := .decl ``id_rule, rfl := false} ]
+      discharger e
+
+  constRule e X y := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``const_rule #[K, X, y], origin := .decl ``const_rule, rfl := false} ]
+      discharger e
+
+  projRule e X i := do  -- NOTE(review): `proj_rule` shows `(i : ι)` as its only explicit binder here; confirm `X` is accepted
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``proj_rule #[K, X, i], origin := .decl ``proj_rule, rfl := false} ]
+      discharger e
+
+  compRule e f g := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``comp_rule #[K, f, g], origin := .decl ``comp_rule, rfl := false} ]
+      discharger e
+
+  letRule e f g := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``let_rule #[K, f, g], origin := .decl ``let_rule, rfl := false} ]
+      discharger e
+
+  piRule e f := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``pi_rule #[K, f], origin := .decl ``pi_rule, rfl := false} ]
+      discharger e
+
+  discharger := discharger
+
+
+-- register revDerivProj
+open Lean in
+#eval show CoreM Unit from do
+  modifyEnv (λ env => FTrans.ftransExt.addEntry env (``revDerivProj, ftransExt))
+
+end revDerivProj
+
+
+namespace revDerivProjUpdate
+-- Discharge side goal `e` with `fprop`, threading the simp cache; fall back to `assumption`.
+open Lean Meta Qq in
+def discharger (e : Expr) : SimpM (Option Expr) := do
+  withTraceNode `revDerivProjUpdate_discharger (fun _ => return s!"discharge {← ppExpr e}") do
+  let cache := (← get).cache
+  let config : FProp.Config := {}
+  let state : FProp.State := { cache := cache }
+  let (proof?, state) ← FProp.fprop e |>.run config |>.run state
+  modify (fun simpState => { simpState with cache := state.cache })
+  if proof?.isSome then
+    return proof?
+  else
+    -- if `fprop` fails try assumption
+    let tac := FTrans.tacticToDischarge (Syntax.mkLit ``Lean.Parser.Tactic.assumption "assumption")
+    let proof? ← tac e
+    return proof?
+
+-- `ftrans` extension for `revDerivProjUpdate`: locates the transformed function and maps each rule kind to its theorem.
+open Lean Meta FTrans in
+def ftransExt : FTransExt where
+  ftransName := ``revDerivProjUpdate
+
+  getFTransFun? e :=
+    if e.isAppOf ``revDerivProjUpdate then
+      -- the transformed function is read from argument 10 of the application
+      if let .some f := e.getArg? 10 then
+        some f
+      else
+        none
+    else
+      none
+
+  replaceFTransFun e f :=
+    if e.isAppOf ``revDerivProjUpdate then
+      e.setArg 10 f  -- must be the same argument index that `getFTransFun?` reads
+    else
+      e
+
+  idRule e X := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``id_rule #[K, X], origin := .decl ``id_rule, rfl := false} ]
+      discharger e
+
+  constRule e X y := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``const_rule #[K, X, y], origin := .decl ``const_rule, rfl := false} ]
+      discharger e
+
+  projRule e X i := do  -- NOTE(review): `proj_rule` shows `(i : ι)` as its only explicit binder here; confirm `X` is accepted
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``proj_rule #[K, X, i], origin := .decl ``proj_rule, rfl := false} ]
+      discharger e
+
+  compRule e f g := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``comp_rule #[K, f, g], origin := .decl ``comp_rule, rfl := false} ]
+      discharger e
+
+  letRule e f g := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``let_rule #[K, f, g], origin := .decl ``let_rule, rfl := false} ]
+      discharger e
+
+  piRule e f := do
+    let .some K := e.getArg? 0 | return none
+    tryTheorems
+      #[ { proof := ← mkAppM ``pi_rule #[K, f], origin := .decl ``pi_rule, rfl := false} ]
+      discharger e
+
+  discharger := discharger
+
+
+-- register revDerivProjUpdate
+open Lean in
+#eval show CoreM Unit from do
+  modifyEnv (λ env => FTrans.ftransExt.addEntry env (``revDerivProjUpdate, ftransExt))
+
+end revDerivProjUpdate
+
+
+
+--------------------------------------------------------------------------------
+
+
+
+--------------------------------------------------------------------------------
+end SciLean
+open SciLean
+
+variable
+  {K : Type _} [IsROrC K]
+  {X Xi : Type} {XI : Xi → Type} [StructLike X Xi XI] [DecidableEq Xi]
+  {Y Yi : Type} {YI : Yi → Type} [StructLike Y Yi YI] [DecidableEq Yi]
+  {Z Zi : Type} {ZI : Zi → Type} [StructLike Z Zi ZI] [DecidableEq Zi]
+  [SemiInnerProductSpace K X] [∀ i, SemiInnerProductSpace K (XI i)]
+  [SemiInnerProductSpace K Y] [∀ i, SemiInnerProductSpace K (YI i)]
+  [SemiInnerProductSpace K Z] [∀ i, SemiInnerProductSpace K (ZI i)]
+  {W : Type _} [SemiInnerProductSpace K W]
+  {ι : Type _} [EnumType ι]
+
+@[simp]
+theorem StruckLike.make_zero  -- NOTE(review): same statement as the `StruckLike.make_zero` added to StructLike/Basic.lean; confirm both are wanted
+  {X I : Type} {XI : I → Type} [StructLike X I XI]
+  [Zero X] [∀ i, Zero (XI i)]
+  : StructLike.make (E:=X) (fun _ => 0)
+    =
+    0 :=
+by
+  sorry_proof
+
+@[simp]
+theorem revCDeriv_snd_zero
+  (f : X → Y) (x : X)
+  : (revCDeriv K f x).2 0 = 0 := sorry_proof
+
+@[simp]
+theorem revDerivUpdate_snd_zero
+  (f : X → Y) (x : X)
+  : (revDerivUpdate K f x).2 0 k dy = dy := sorry_proof
+
+
+-- Prod.mk ---------------------------------------------------------------------
+--------------------------------------------------------------------------------
+
+@[ftrans]
+theorem Prod.mk.arg_fstsnd.revDerivProj_rule
+  (g : X → Y) (f : X → Z)
+  (hg : HasAdjDiff K g) (hf : HasAdjDiff K f)
+  : revDerivProj K (fun x => (g x, f x))
+    =
+    fun x =>
+      let ydg := revDerivProj K g x
+      let zdf := revDerivProj K f x
+      ((ydg.1,zdf.1),
+       fun i dyz =>
+         match i with
+         | .inl j => ydg.2 j dyz
+         | .inr j => zdf.2 j dyz) :=
+by
+  unfold revDerivProj
+  funext x; ftrans; simp
+  funext i dyz
+  induction i <;>
+    { simp[StructLike.make]
+      apply congr_arg
+      congr; funext i; congr; funext h
+      subst h; rfl
+    }
+
+@[ftrans]
+theorem Prod.mk.arg_fstsnd.revDerivProjUpdate_rule
+  (g : X → Y) (f : X → Z)
+  (hg : HasAdjDiff K g) (hf : HasAdjDiff K f)
+  : revDerivProjUpdate K (fun x => (g x, f x))
+    =
+    fun x =>
+      let ydg := revDerivProjUpdate K g x
+      let zdf := revDerivProjUpdate K f x
+      ((ydg.1,zdf.1),
+       fun i dyz dx =>
+         match i with
+         | .inl j => ydg.2 j dyz dx
+         | .inr j => zdf.2 j dyz dx) :=
+by
+  unfold revDerivProjUpdate
+  funext x; ftrans; simp
+  funext i de dx
+  induction i <;>
+    { simp[StructLike.make]
+      sorry_proof
+    }
+
+
+------------------------
+-- NOTE(review): `Prod.fst.arg_self.revDerivProj_rule` is declared again further down in this file; confirm the two names do not clash.
+@[ftrans]
+theorem Prod.fst.arg_self.revDerivProj_rule
+  (f : W → X×Y) (hf : HasAdjDiff K f)
+  : revDerivProj K (fun x => (f x).1)
+    =
+    fun w =>
+      let xydf := revDerivProj K f w
+      (xydf.1.1,
+       fun i dxy => xydf.2 (.inl i) dxy) :=
+by
+  unfold revDerivProj
+  funext x; ftrans; simp
+  funext e dxy
+  simp[StructLike.make]
+  apply congr_arg
+  congr; funext i; congr; funext h; subst h; rfl
+
+
+@[ftrans]
+theorem Prod.fst.arg_self.revDerivProjUpdate_rule
+  (f : W → X×Y) (hf : HasAdjDiff K f)
+  : revDerivProjUpdate K (fun x => (f x).1)
+    =
+    fun w =>
+      let xydf := revDerivProjUpdate K f w
+      (xydf.1.1,
+       fun i dxy dw => xydf.2 (.inl i) dxy dw) :=
 by
+  unfold revDerivProjUpdate
+  funext x; ftrans; simp
+  funext e dxy dw
+  simp[StructLike.make]
+  -- apply congr_arg -- this fails for some reason :(
+  -- congr; funext i; congr; funext h; subst h; rfl
   sorry_proof
 
+
 theorem Prod.fst.arg_self.revDeriv_rule
   (f : W → X×Y) (hf : HasAdjDiff K f)
   : revCDeriv K (fun w => (f w).1)
@@ -116,9 +519,8 @@ by
 
 
 theorem Prod.fst.arg_self.revDerivProj_rule
-  {XIdx : Type _}
-  {XVal : XIdx → Type _} [∀ i, SemiInnerProductSpace K (XVal i)]
-  [TypeWithProj X XIdx XVal]
+  {X I : Type _} {XI : I → Type _}
+  [StructLike X I XI] [SemiInnerProductSpace K X] [∀ i, SemiInnerProductSpace K (XI i)]
   (f : W → X×Y) (hf : HasAdjDiff K f)
   : revDerivProj K (fun w => (f w).1)
     =
diff --git a/SciLean/Data/StructLike.lean b/SciLean/Data/StructLike.lean
new file mode 100644
index 00000000..c55f15db
--- /dev/null
+++ b/SciLean/Data/StructLike.lean
@@ -0,0 +1,2 @@
+import SciLean.Data.StructLike.Basic
+import SciLean.Data.StructLike.Algebra
diff --git a/SciLean/Data/StructLike/Algebra.lean b/SciLean/Data/StructLike/Algebra.lean
new file mode 100644
index 00000000..dddf8210
--- /dev/null
+++ b/SciLean/Data/StructLike/Algebra.lean
@@ -0,0 +1,37 @@
+import SciLean.Core.Objects.SemiInnerProductSpace
+import SciLean.Core.Objects.FinVec
+import SciLean.Data.StructLike.Basic
+
+set_option linter.unusedVariables false
+
+namespace SciLean
+
+variable
+  (K : Type _) [IsROrC K]
+  {X : Type _} [SemiInnerProductSpace K X]
+  {Y : Type _} [SemiInnerProductSpace K Y]
+  {Z : Type _} [SemiInnerProductSpace K Z]
+  {W : Type _} [SemiInnerProductSpace K W]
+  {ι κ : Type _} [EnumType ι] [EnumType κ]
+  {E I : Type _} {EI : I → Type _}
+  [StructLike E I EI]
+  {F J : Type _} {FJ : J → Type _}
+  [StructLike F J FJ]
+
+
+instance [∀ i, Vec K (EI i)] [∀ j, Vec K (FJ j)] (i : I ⊕ J) : Vec K (Prod.TypeFun EI FJ i) :=
+  match i with
+  | .inl _ => by infer_instance
+  | .inr _ => by infer_instance
+
+instance [∀ i, SemiInnerProductSpace K (EI i)] [∀ j, SemiInnerProductSpace K (FJ j)] (i : I ⊕ J)
+  : SemiInnerProductSpace K (Prod.TypeFun EI FJ i) :=
+  match i with
+  | .inl _ => by infer_instance
+  | .inr _ => by infer_instance
+
+instance [∀ i, FinVec ι K (EI i)] [∀ j, FinVec ι K (FJ j)] (i : I ⊕ J)
+  : FinVec ι K (Prod.TypeFun EI FJ i) :=
+  match i with
+  | .inl _ => by infer_instance
+  | .inr _ => by infer_instance
diff --git a/SciLean/Data/StructLike/Basic.lean b/SciLean/Data/StructLike/Basic.lean
new file mode 100644
index 00000000..141c1dc6
--- /dev/null
+++ b/SciLean/Data/StructLike/Basic.lean
@@ -0,0 +1,121 @@
+import Mathlib.Init.Function
+import Mathlib.Algebra.Group.Basic
+import SciLean.Util.SorryProof
+
+namespace SciLean
+
+open Function
+
+/-- `StructLike E I EI` presents `E` as a structure with components `EI i` indexed by `i : I`:
+`proj` reads a component, `make` builds an element from all components, and `modify` applies a
+function to one component. `left_inv` and `right_inv` state that `proj` and `make` are mutually
+inverse. -/
+class StructLike (E : Sort _) (I : outParam (Sort _)) (EI : outParam <| I → Sort _) where
+  proj : E → (i : I) → (EI i)
+  make : ((i : I) → (EI i)) → E
+  modify : (i : I) → (EI i → EI i) → (E → E)
+  left_inv : LeftInverse proj make
+  right_inv : RightInverse proj make
+  -- TODO: theorem about modify
+
+
+/-- Every type is `StructLike` with `Unit` as index set.
+-/
+instance (priority:=low) : StructLike α Unit (fun _ => α) where
+  proj := fun x _ => x
+  make := fun f => f ()
+  modify := fun _ f x => f x
+  left_inv := sorry_proof
+  right_inv := sorry_proof
+
+--------------------------------------------------------------------------------
+-- Pi --------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+
+/-- Dependent functions `∀ i, E i` are `StructLike` over the sigma type `(i : I) × J i`. -/
+instance (priority:=low)
+  (I : Type _) (E : I → Type _)
+  (J : I → Type _) (EJ : (i : I) → (J i) → Type _)
+  [∀ (i : I), StructLike (E i) (J i) (EJ i)] [DecidableEq I]
+  : StructLike (∀ i, E i) ((i : I) × (J i)) (fun ⟨i,j⟩ => EJ i j) where
+  proj := fun f ⟨i,j⟩ => StructLike.proj (f i) j
+  make := fun f i => StructLike.make fun j => f ⟨i,j⟩
+  modify := fun ⟨i,j⟩ f x i' =>
+    if h : i'=i then
+      StructLike.modify (h▸j) (h▸f) (x i')
+    else
+      (x i')
+  left_inv := sorry_proof
+  right_inv := sorry_proof
+
+/-- Functions `J → E` are `StructLike` over pairs `(j,i) : J×I`, using the structure of `E`. -/
+instance
+  (E I J : Type _) (EI : I → Type _)
+  [StructLike E I EI] [DecidableEq J]
+  : StructLike (J → E) (J×I) (fun (j,i) => EI i) where
+  proj := fun f (j,i) => StructLike.proj (f j) i
+  make := fun f j => StructLike.make fun i => f (j,i)
+  modify := fun (j,i) f x j' =>
+    if h : j=j' then
+      StructLike.modify i f (x j)
+    else
+      (x j')
+
+  left_inv := sorry_proof
+  right_inv := sorry_proof
+
+
+--------------------------------------------------------------------------------
+-- Prod ------------------------------------------------------------------------
+--------------------------------------------------------------------------------
+
+/-- Component type family of a product: `EI` on the left summand, `FJ` on the right one. -/
+abbrev _root_.Prod.TypeFun {I J: Type _} (EI : I → Type _) (FJ : J → Type _) (i : Sum I J) : Type _ :=
+  match i with
+  | .inl a => EI a
+  | .inr b => FJ b
+
+/-- A product `E×F` is `StructLike` over `Sum I J`: `.inl` addresses `E`, `.inr` addresses `F`. -/
+instance [StructLike E I EI] [StructLike F J FJ]
+  : StructLike (E×F) (Sum I J) (Prod.TypeFun EI FJ) where
+  proj := fun (x,y) i =>
+    match i with
+    | .inl a => StructLike.proj x a
+    | .inr b => StructLike.proj y b
+  make := fun f => (StructLike.make (fun a => f (.inl a)),
+                    StructLike.make (fun b => f (.inr b)))
+  modify := fun i f (x,y) =>
+    match i with
+    | .inl a => (StructLike.modify a f x, y)
+    | .inr b => (x, StructLike.modify b f y)
+  left_inv := sorry_proof
+  right_inv := sorry_proof
+
+
+
+--------------------------------------------------------------------------------
+-- TODO: Add some lawfulness w.r.t. +,•,0
+-- NOTE(review): the lemmas below are named `StruckLike.*` (sic); presumably `StructLike` was meant. Names kept unchanged.
+
+/-- `make` of the all-zero component family is `0`; assumed (`sorry_proof`) pending the lawfulness TODO. -/
+@[simp]
+theorem StruckLike.make_zero
+  {X I : Type} {XI : I → Type} [StructLike X I XI]
+  [Zero X] [∀ i, Zero (XI i)]
+  : StructLike.make (E:=X) (fun _ => 0)
+    =
+    0 :=
+by
+  sorry_proof
+
+
+/-- `make` is additive in the component family; assumed (`sorry_proof`) pending the lawfulness TODO. -/
+@[simp]
+theorem StruckLike.make_add
+  {X I : Type} {XI : I → Type} [StructLike X I XI]
+  [Add X] [∀ i, Add (XI i)] (f g : (i : I) → XI i)
+  : StructLike.make (E:=X) (fun i => f i + g i)
+    =
+    StructLike.make (E:=X) f + StructLike.make (E:=X) g :=
+by
+  sorry_proof
diff --git a/SciLean/Data/TypeWithProj.lean b/SciLean/Data/TypeWithProj.lean
index 4e453ec3..75f62d57 100644
--- a/SciLean/Data/TypeWithProj.lean
+++ b/SciLean/Data/TypeWithProj.lean
@@ -14,7 +14,6 @@ class TypeWithProj (F : Sort _) (I : outParam (Sort _)) (E : outParam <| I → S
   -- TODO: theorem about modify
 
 
-
 --------------------------------------------------------------------------------
 -- Prod ------------------------------------------------------------------------
 --------------------------------------------------------------------------------