From f8f19407d9abc103c8bd27a29de6b07bab839a32 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Sun, 17 Sep 2023 18:53:31 +0200 Subject: [PATCH] lots of progress on pi rules for ftrans --- .../FunctionPropositions/Diffeomorphism.lean | 2 +- .../Core/FunctionTransformations/CDeriv.lean | 8 +- .../Core/FunctionTransformations/InvFun.lean | 4 +- .../FunctionTransformations/RevCDeriv.lean | 84 ++++++-- SciLean/Core/Monads/FwdDerivMonad.lean | 4 +- SciLean/Core/Monads/RevDerivMonad.lean | 3 +- SciLean/Core/Notation/Autodiff.lean | 3 +- SciLean/Core/Simp/Sum.lean | 100 ++++----- SciLean/Data/ArrayType/Properties.lean | 17 ++ SciLean/Data/Curry.lean | 26 ++- SciLean/Tactic/FTrans/Basic.lean | 125 +++++++---- SciLean/Tactic/FTrans/Init.lean | 200 ++++++++++++++---- SciLean/Tactic/StructuralInverse.lean | 8 +- SciLean/Tactic/StructureDecomposition.lean | 97 ++++++--- test/basic_gradients.lean | 101 ++++++++- 15 files changed, 578 insertions(+), 204 deletions(-) diff --git a/SciLean/Core/FunctionPropositions/Diffeomorphism.lean b/SciLean/Core/FunctionPropositions/Diffeomorphism.lean index 0723b5ae..cd4a98aa 100644 --- a/SciLean/Core/FunctionPropositions/Diffeomorphism.lean +++ b/SciLean/Core/FunctionPropositions/Diffeomorphism.lean @@ -309,7 +309,7 @@ theorem Function.invFun.arg_f.cderiv_rule (f : X → Y → Z) (hf : ∀ x, Diffeomorphism K (f x)) (hf' : IsDifferentiable K (fun xy : X×Y => f xy.1 xy.2)) - : cderiv K (fun x => invFun (f x)) + : cderiv K (fun x z => invFun (f x) z) = fun x dx z => let y := invFun (f x) z diff --git a/SciLean/Core/FunctionTransformations/CDeriv.lean b/SciLean/Core/FunctionTransformations/CDeriv.lean index d0b8455b..806e7a8b 100644 --- a/SciLean/Core/FunctionTransformations/CDeriv.lean +++ b/SciLean/Core/FunctionTransformations/CDeriv.lean @@ -320,7 +320,7 @@ by sorry_proof @[ftrans] theorem id.arg_a.cderiv_rule - : cderiv K (id : X → X) + : cderiv K (fun x : X => id x) = fun _ => id := by unfold id; ftrans @@ -332,7 +332,7 @@ theorem id.arg_a.cderiv_rule theorem Function.comp.arg_a0.cderiv_rule_at (f : Y → Z) (g : X → Y) (x : X) (hf : IsDifferentiableAt K f (g x)) (hg : IsDifferentiableAt K g x) - : cderiv K (f ∘ g) x + : cderiv K (fun x => (f ∘ g) x) x = fun dx => cderiv K f (g x) (cderiv K g x dx) := @@ -343,7 +343,7 @@ by theorem Function.comp.arg_a0.cderiv_rule (f : Y → Z) (g : X → Y) (hf : IsDifferentiable K f) (hg : IsDifferentiable K g) - : cderiv K (f ∘ g) + : cderiv K (fun x => (f ∘ g) x) = fun x => cderiv K f (g x) ∘ (cderiv K g x) := by @@ -354,7 +354,7 @@ theorem Function.comp.arg_fg.cderiv_rule (f : W → Y → Z) (g : W → X → Y) (hf : IsDifferentiable K (fun wy : W×Y => f wy.1 wy.2)) (hg : IsDifferentiable K (fun wx : W×X => g wx.1 wx.2)) - : cderiv K (fun w => ((f w) ∘ (g w))) + : cderiv K (fun w x => ((f w) ∘ (g w)) x) = fun w dw x => let y := g w x diff --git a/SciLean/Core/FunctionTransformations/InvFun.lean b/SciLean/Core/FunctionTransformations/InvFun.lean index a068d043..b34d5a74 100644 --- a/SciLean/Core/FunctionTransformations/InvFun.lean +++ b/SciLean/Core/FunctionTransformations/InvFun.lean @@ -156,7 +156,7 @@ by sorry_proof @[ftrans] theorem id.arg_a.invFun_rule - : invFun (id : X → X) + : invFun (fun x : X => id x) = id := by unfold id; ftrans @@ -168,7 +168,7 @@ theorem id.arg_a.invFun_rule theorem Function.comp.arg_a0.invFun_rule (f : Y → Z) (g : X → Y) (hf : Bijective f) (hg : Bijective g) - : invFun (f ∘ g) + : invFun (fun x => (f ∘ g) x) = invFun g ∘ invFun f := by unfold Function.comp; ftrans diff --git a/SciLean/Core/FunctionTransformations/RevCDeriv.lean b/SciLean/Core/FunctionTransformations/RevCDeriv.lean index 34e006d0..a82f24e4 100644 --- a/SciLean/Core/FunctionTransformations/RevCDeriv.lean +++ b/SciLean/Core/FunctionTransformations/RevCDeriv.lean @@ -3,7 +3,7 @@ import SciLean.Core.FunctionPropositions.HasAdjDiff import SciLean.Core.FunctionTransformations.SemiAdjoint --- import SciLean.Data.EnumType +import SciLean.Data.Curry import SciLean.Util.Profile @@ -268,6 +268,19 @@ by sorry_proof +variable (X) +theorem pi_curryn_rule {IX : Type _} (Is : Type _) (n : Nat) [UncurryN n IX Is X] [CurryN n Is X IX] [SemiInnerProductSpace K IX] [EnumType Is] + (f : W → IX) (hf : HasAdjDiff K f) + : revCDeriv K (fun (w : W) (i : Is) => uncurryN n (f w) i) + = + fun w => + let ydf := revCDeriv K f w + (uncurryN n ydf.1, + fun dx => ydf.2 (curryN n dx)) := +by + sorry_proof +variable {X} + theorem pi_comp_rule_simple (f : Y → Z) (g : X → ι → Y) (hf : HasAdjDiff K f) @@ -478,31 +491,43 @@ by simp sorry_proof -theorem revCDeriv.pi_uncurry_rule {ι κ} [EnumType ι] [EnumType κ] - (f : X → ι → κ → Y) (hf : ∀ i j, HasAdjDiff K (f · i j)) - : (revCDeriv K fun x i j => f x i j) + +theorem pi_inv_rule + (f : X → κ → Y) (h : ι → κ) (h' : κ → ι) (hh : Function.Inverse h' h) + (hf : ∀ j, HasAdjDiff K (f · j)) + : (revCDeriv K fun x i => f x (h i)) = - fun x => - let ydf := revCDeriv K (fun x' (ij : ι×κ) => f x' ij.1 ij.2) x - (fun i j => ydf.1 (i,j), - fun dy => ydf.2 (fun ij : ι×κ => dy ij.1 ij.2)) := + fun x => + let ydf := revCDeriv K f x + (fun i => ydf.1 (h i), + fun dy => + ydf.2 (fun j => dy (h' j))) := by - have _ := fun i j => (hf i j).1 - have _ := fun i j => (hf i j).2 - unfold revCDeriv - funext _; ftrans; -- ftrans - semiAdjoint.pi_rule fails because of some universe issues - simp sorry_proof +theorem pi_rinv_rule' {ι₁ ι₂: Type _} [EnumType ι₁] [EnumType ι₂] + (f : X → κ → Y) (h : ι → κ) (h' : ι₁ → κ → ι) (hh : ∀ i₁, Function.RightInverse (h' i₁) h) + (p₁ : ι → ι₁) (p₂ : ι → ι₂) (q : ι₁ → ι₂ → ι) (dec : Meta.IsDecomposition p₁ p₂ q) + (hf : ∀ j, HasSemiAdjoint K (f · j)) + : (semiAdjoint K fun x i => f x (h i)) + = + fun x' => + let f' := semiAdjoint K f + f' fun j => ∑ i₁, (x' (h' i₁ j)) := +by + sorry_proof -theorem revCDeriv.pi_curry_rule {ι κ} [EnumType ι] [EnumType κ] - (f : X → ι → κ → Y) (hf : ∀ i j, HasAdjDiff K (f · i j)) - : (revCDeriv K fun x (ij : ι×κ) => f x ij.1 ij.2) +-- TODO these are not sufficient conditions for this to be true, we need that `h'` induces isomorphism `ι≃ι₁×κ` +theorem pi_rinv_rule {ι₁ : Type _} [EnumType ι₁] + (f : X → κ → Y) (h : ι → κ) (h' : ι₁ → κ → ι) (hh : ∀ i₁, Function.RightInverse (h' i₁) h) + (hf : ∀ j, HasAdjDiff K (f · j)) + : (revCDeriv K fun x i => f x (h i)) = - fun x => - let ydf := revCDeriv K (fun x' i j => f x' i j) x - (fun ij => ydf.1 ij.1 ij.2, - fun dy => ydf.2 (fun i j => dy (i,j))) := + fun x => + let ydf := revCDeriv K f x + (fun i => ydf.1 (h i), + fun dy => + ydf.2 (fun j => ∑ i₁, dy (h' i₁ j))) := by sorry_proof @@ -606,6 +631,12 @@ def ftransExt : FTransExt where #[ { proof := ← mkAppM ``pi_uncurry_rule #[K, f], origin := .decl ``pi_uncurry_rule, rfl := false} ] discharger e + piCurryNRule e f Is Y n := do + let .some K := e.getArg? 0 | return none + tryTheorems + #[ { proof := ← mkAppM ``pi_curryn_rule #[K, Y, Is, mkNatLit n, f], origin := .decl ``pi_curryn_rule, rfl := false} ] + discharger e + piCompRule e f g := do let .some K := e.getArg? 0 | return none tryTheorems @@ -624,13 +655,24 @@ def ftransExt : FTransExt where #[ { proof := ← mkAppM ``pi_prod_rule #[K, f, g], origin := .decl ``pi_prod_rule, rfl := false} ] discharger e - piLetRule e f g := do let .some K := e.getArg? 0 | return none tryTheorems #[ { proof := ← mkAppM ``pi_let_rule #[K, f, g], origin := .decl ``pi_let_rule, rfl := false} ] discharger e + piInvRule e f inv := do + let .some K := e.getArg? 0 | return none + tryTheorems + #[ { proof := ← mkAppM ``pi_inv_rule #[K, f, inv.f, inv.invFun, inv.is_inv], origin := .decl ``pi_inv_rule, rfl := false} ] + discharger e + + piRInvRule e f rinv := do + let .some K := e.getArg? 0 | return none + tryTheorems + #[ { proof := ← mkAppM ``pi_rinv_rule #[K, f, rinv.f, rinv.invFun, rinv.right_inv], origin := .decl ``pi_rinv_rule, rfl := false} ] + discharger e + discharger := discharger diff --git a/SciLean/Core/Monads/FwdDerivMonad.lean b/SciLean/Core/Monads/FwdDerivMonad.lean index f67b395a..cc4ce4db 100644 --- a/SciLean/Core/Monads/FwdDerivMonad.lean +++ b/SciLean/Core/Monads/FwdDerivMonad.lean @@ -503,10 +503,9 @@ theorem Pure.pure.arg_a0.fwdDerivM_rule by apply FwdDerivMonad.fwdDerivM_pure a0 ha0 - set_option linter.fpropDeclName false in @[fprop] -theorem Pure.pure.IsDifferentiableValM_rule (x : X) +theorem Pure.pure.arg.IsDifferentiableValM_rule (x : X) : IsDifferentiableValM K (pure (f:=m) x) := by unfold IsDifferentiableValM @@ -514,7 +513,6 @@ by fprop -set_option linter.ftransDeclName false in @[ftrans] theorem Pure.pure.fwdDerivValM_rule (x : X) : fwdDerivValM K (pure (f:=m) x) diff --git a/SciLean/Core/Monads/RevDerivMonad.lean b/SciLean/Core/Monads/RevDerivMonad.lean index 5ba42f59..aa720d29 100644 --- a/SciLean/Core/Monads/RevDerivMonad.lean +++ b/SciLean/Core/Monads/RevDerivMonad.lean @@ -534,9 +534,8 @@ by fprop -set_option linter.ftransDeclName false in @[ftrans] -theorem Pure.pure.revDerivValM_rule (x : X) +theorem Pure.pure.arg.revDerivValM_rule (x : X) : revDerivValM K (pure (f:=m) x) = pure (x,fun dy => pure 0) := diff --git a/SciLean/Core/Notation/Autodiff.lean b/SciLean/Core/Notation/Autodiff.lean index 1a6784e8..26e275ee 100644 --- a/SciLean/Core/Notation/Autodiff.lean +++ b/SciLean/Core/Notation/Autodiff.lean @@ -1,5 +1,6 @@ import SciLean.Core.Notation.Symdiff import SciLean.Tactic.LetNormalize +import SciLean.Data.Curry namespace SciLean @@ -10,7 +11,7 @@ macro "autodiff" : conv => do (simp (config := {failIfUnchanged := false, zeta := false}) only [cderiv_as_fwdCDeriv, scalarGradient, gradient, scalarCDeriv,revCDerivEval] ftrans only let_normalize - simp (config := {failIfUnchanged := false, zeta := false}))) + simp (config := {failIfUnchanged := false, zeta := false}) [uncurryN, UncurryN.uncurry, curryN, CurryN.curry])) macro "autodiff" : tactic => do `(tactic| conv => autodiff) diff --git a/SciLean/Core/Simp/Sum.lean b/SciLean/Core/Simp/Sum.lean index ab361d43..eae95750 100644 --- a/SciLean/Core/Simp/Sum.lean +++ b/SciLean/Core/Simp/Sum.lean @@ -4,53 +4,53 @@ namespace SciLean variable {ι κ} [EnumType ι] [EnumType κ] -@[simp] -theorem sum_if {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι) - : (∑ i, if i = j then f i else 0) - = - f j - := sorry_proof - -@[simp] -theorem sum_if' {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι) - : (∑ i, if j = i then f i else 0) - = - f j - := sorry_proof - -@[simp] -theorem sum_lambda_swap {α β : Type _} [AddCommMonoid β] (f : ι → α → β) - : ∑ i, (fun a => f i a) - = - fun a => ∑ i, f i a - := sorry_proof - - -@[simp] -theorem sum2_if {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) - : (∑ i, ∑ j, if ij = (i,j) then f i j else 0) - = - f ij.1 ij.2 - := sorry_proof - -@[simp] -theorem sum2'_if {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) - : (∑ j, ∑ i, if ij = (i,j) then f i j else 0) - = - f ij.1 ij.2 - := sorry_proof - - -@[simp] -theorem sum2_if' {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) - : (∑ i, ∑ j, if (i,j) = ij then f i j else 0) - = - f ij.1 ij.2 - := sorry_proof - -@[simp] -theorem sum2'_if' {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) - : (∑ j, ∑ i, if (i,j) = ij then f i j else 0) - = - f ij.1 ij.2 - := sorry_proof +-- @[simp] +-- theorem sum_if {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι) +-- : (∑ i, if i = j then f i else 0) +-- = +-- f j +-- := sorry_proof + +-- @[simp] +-- theorem sum_if' {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι) +-- : (∑ i, if j = i then f i else 0) +-- = +-- f j +-- := sorry_proof + +-- @[simp] +-- theorem sum_lambda_swap {α β : Type _} [AddCommMonoid β] (f : ι → α → β) +-- : ∑ i, (fun a => f i a) +-- = +-- fun a => ∑ i, f i a +-- := sorry_proof + + +-- @[simp] +-- theorem sum2_if {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) +-- : (∑ i, ∑ j, if ij = (i,j) then f i j else 0) +-- = +-- f ij.1 ij.2 +-- := sorry_proof + +-- @[simp] +-- theorem sum2'_if {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) +-- : (∑ j, ∑ i, if ij = (i,j) then f i j else 0) +-- = +-- f ij.1 ij.2 +-- := sorry_proof + + +-- @[simp] +-- theorem sum2_if' {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) +-- : (∑ i, ∑ j, if (i,j) = ij then f i j else 0) +-- = +-- f ij.1 ij.2 +-- := sorry_proof + +-- @[simp] +-- theorem sum2'_if' {β : Type _} [AddCommMonoid β] (f : ι → κ → β) (ij : ι×κ) +-- : (∑ j, ∑ i, if (i,j) = ij then f i j else 0) +-- = +-- f ij.1 ij.2 +-- := sorry_proof diff --git a/SciLean/Data/ArrayType/Properties.lean b/SciLean/Data/ArrayType/Properties.lean index 15d98c67..55d1d1b2 100644 --- a/SciLean/Data/ArrayType/Properties.lean +++ b/SciLean/Data/ArrayType/Properties.lean @@ -157,6 +157,23 @@ by have ⟨_,_⟩ := hf unfold revCDeriv; ftrans; ftrans; simp +@[ftrans] +theorem GetElem.getElem.arg_xs_idx.revCDeriv_rule + (f : X → Cont) (dom) + (hf : HasAdjDiff K f) + : revCDeriv K (fun x idx => getElem (f x) idx dom) + = + fun x => + let ydf := revCDeriv K f x + (fun idx => getElem ydf.1 idx dom, + fun delem => + let dx := introElem delem + ydf.2 dx) := +by + have ⟨_,_⟩ := hf + unfold revCDeriv; ftrans + sorry_proof + -- @[ftrans] -- this one is considered harmful as it introduces one hot vector theorem GetElem.getElem.arg_xs.revDerivUpdate_rule (f : X → Cont) (idx : Idx) (dom) diff --git a/SciLean/Data/Curry.lean b/SciLean/Data/Curry.lean index 6906a794..73edee41 100644 --- a/SciLean/Data/Curry.lean +++ b/SciLean/Data/Curry.lean @@ -1,6 +1,30 @@ - +import SciLean.Util.SorryProof namespace SciLean +-------------------------------------------------------------------------------- + +class FunNArgs (n : Nat) (F : Sort _) (Xs : outParam <| Sort _) (Y : outParam <| Sort _) where + uncurry : F → Xs → Y + curry : (Xs → Y) → F + is_equiv : curry ∘ uncurry = id ∧ uncurry ∘ curry = id + +attribute [reducible] FunNArgs.uncurry FunNArgs.curry + +@[reducible] +instance (priority := low) {X Y : Sort _} : FunNArgs 1 (X→Y) X Y where + uncurry := λ (f : X → Y) (x : X) => f x + curry := λ (f : X → Y) (x : X) => f x + is_equiv := by constructor <;> rfl + +@[reducible] +instance (priority := low) {X Xs Y F : Sort _} [fn : FunNArgs n F Xs Y] : FunNArgs (n+1) (X→F) (X×Xs) Y where + uncurry := λ (f : X → F) ((x,xs) : X×Xs) => fn.uncurry (n:=n) (f x) xs + curry := λ (f : X×Xs → Y) (x : X) => fn.curry (n:=n) (fun xs => f (x,xs)) + is_equiv := + by constructor + . funext x xs; simp; sorry_proof + . funext f (x,xs); simp; sorry_proof + -------------------------------------------------------------------------------- diff --git a/SciLean/Tactic/FTrans/Basic.lean b/SciLean/Tactic/FTrans/Basic.lean index cff16fd3..5bb1728f 100644 --- a/SciLean/Tactic/FTrans/Basic.lean +++ b/SciLean/Tactic/FTrans/Basic.lean @@ -285,7 +285,86 @@ def piLetCase (e : Expr) (ftransName : Name) (ext : FTransExt) (f : Expr) | _ => throwError "expected expression of the form `fun x i => let y := g x i; f x y i`" | _ => throwError "expected expression of the form `fun x i => let y := g x i; f x y i`" +private def peelOffBVarArgs (bvarId : Nat) (e : Expr) (n : Nat := 0) : Expr × Nat := + match e with + | .app f x => + if x.hasLooseBVar bvarId then + peelOffBVarArgs bvarId f (n+1) + else + (.app f x, n) + | e => (e, n) +def piBFVarAppCase (e : Expr) (ftransName : Name) (ext : FTransExt) (f : Expr) (ftrans : Expr → SimpM (Option Simp.Step)) : SimpM (Option Simp.Step) := do + match f with + | .lam xName xType (.lam iName iType body iBi) xBi => do + + let fn := body.getAppFn + if ¬(fn.isBVar || fn.isFVar) then + throwError "expected expression of the form `fun x i => f x i` where head of `f` is bvar or fvar" + + withLocalDecl xName default xType fun x => do + let g := (Expr.lam iName iType body iBi).instantiate1 x + let .some (Is,Y) := (← inferType g).arrow? + | throwError "unexpected type {← inferType g} in pi bvar/fvar app case +" + let (g',h') ← splitLambdaToComp g + + trace[Meta.Tactic.ftrans.step] "case pi change of variables\n{← ppExpr e}\n{← ppExpr g'}\n{← ppExpr h'}" + + let .some (hinv, goals) ← structuralInverse h' + | trace[Meta.Tactic.ftrans.step] "unable to invert {← ppExpr h'}" + return none + + match hinv with + | .full finv => + if (← isDefEq finv.invFun h') then + trace[Meta.Tactic.ftrans.step] "identity case, nothing to be done" + -- now we are expecting that we are dealing with expressions like + -- 1. bvar app - fun x i => f i.1 i.2.1 i.2.2 + -- 2. fvar app - fun x i => f x i.1 i.2.1 i.2.2 + let (body', n) := peelOffBVarArgs 0 body + + if body'.hasLooseBVar 0 then + trace[Meta.Tactic.ftrans.step] "unable to curry back trailing arguments" + return none + + let g := Expr.lam xName xType (body'.lowerLooseBVars 1 1) xBi + return ← ext.piCurryNRule e g Is Y n + else + trace[Meta.Tactic.ftrans.step] "computed inverse {← ppExpr finv.invFun}" + return ← ext.piInvRule e (← mkLambdaFVars #[x] g') finv + | .right rinv => + trace[Meta.Tactic.ftrans.step] "only right inverse, skipping for now" + return ← ext.piRInvRule e (← mkLambdaFVars #[x] g') rinv + + | _ => throwError "expected expression of the form `fun x i => f x i`" + + +def piConstAppCase (e : Expr) (ftransName : Name) (ext : FTransExt) (f : Expr) (ftrans : Expr → SimpM (Option Simp.Step)) : SimpM (Option Simp.Step) := do + match f with + | .lam xName xType (.lam iName iType body iBi) xBi => do + let fn := body.getAppFn + if ¬(fn.isConst) then + throwError "expected expression of the form `fun x i => f x i` where head of `f` is a constant" + + match fn with + | .const constName _ => + match (← getEnv).find? constName with + | none => return none + | some info => + let constArity := info.type.forallArity + let args := body.getAppArgs + + if args.size == constArity then + let (f',g') ← elemWiseSplitHighOrderLambdaToComp f + + if ¬(← isDefEq g' f) then + return ← ext.piElemWiseCompRule e f' g' + + return none + + | _ => throwError "expected expression of the form `fun x i => f x i` where head of `f` is a constant" + | _ => throwError "expected expression of the form `fun x i => f x i`" def piCase (e : Expr) (ftransName : Name) (ext : FTransExt) (f : Expr) (ftrans : Expr → SimpM (Option Simp.Step)) : SimpM (Option Simp.Step) := do @@ -334,47 +413,15 @@ def piCase (e : Expr) (ftransName : Name) (ext : FTransExt) (f : Expr) (ftrans : if ¬(← isDefEq f' f) then return ← ext.compRule e f' g' - match body.getAppFn with - | .bvar _ | .fvar _ => do - withLocalDecl xName default xType fun x => do - let g := (Expr.lam iName iType body iBi).instantiate1 x - let (g',h') ← splitLambdaToComp g - - trace[Meta.Tactic.ftrans.step] "case pi change of variables\n{← ppExpr e}\n{← ppExpr g'}\n{← ppExpr h'}" - - let .some (hinv, goals) ← structuralInverse h' - | trace[Meta.Tactic.ftrans.step] "unable to invert {← ppExpr h'}" - return none - - match hinv with - | .full finv => - if (← isDefEq finv.invFun h') then - trace[Meta.Tactic.ftrans.step] "identity case, nothing to be done" - return none - else - trace[Meta.Tactic.ftrans.step] "computed inverse {← ppExpr finv.invFun}" - return none - | .right rinv => - trace[Meta.Tactic.ftrans.step] "only right inverse, skipping for now" - return none - | Expr.const constName _ => do - match (← getEnv).find? constName with - | none => return none - | some info => - let constArity := info.type.forallArity - let args := body.getAppArgs - - if args.size == constArity then - let (f',g') ← elemWiseSplitHighOrderLambdaToComp f - - if ¬(← isDefEq g' f) then - return ← ext.piElemWiseCompRule e f' g' + let fn := body.getAppFn + + if (fn.isBVar || fn.isFVar) then + return ← piBFVarAppCase e ftransName ext f ftrans - return none + if (fn.isConst) then + return ← piConstAppCase e ftransName ext f ftrans - | _ => return none - - -- return none + return none | _ => throwError "expected expression of the form `fun x i => f x i`" diff --git a/SciLean/Tactic/FTrans/Init.lean b/SciLean/Tactic/FTrans/Init.lean index a862a452..7514a577 100644 --- a/SciLean/Tactic/FTrans/Init.lean +++ b/SciLean/Tactic/FTrans/Init.lean @@ -11,6 +11,9 @@ import SciLean.Lean.MergeMapDeclarationExtension import SciLean.Lean.Meta.Basic import SciLean.Tactic.StructuralInverse + +import SciLean.Data.Function +import SciLean.Data.ArraySet open Lean Meta.Simp Qq @@ -41,15 +44,14 @@ initialize registerOption `linter.ftransDeclName { defValue := true, descr := "s open Meta Simp -def _root_.Function.Inverse (g : β → α) (f : α → β) := - Function.LeftInverse g f ∧ Function.RightInverse g f - - /-- Data for `fun x i => f x (h i)` case when `h` is invertible -/ structure PiInvData where - -- {u v w u' v' w' : Level} - {X Y I I₁ I₂ J : Q(Type)} + {u v w w' : Level} + {X : Q(Type u)} + {Y : Q(Type v)} + {I : Q(Type w)} + {J : Q(Type w')} (f : Q($X → $J → $Y)) (h : Q($I → $J)) (h' : Q($J → $I)) @@ -98,6 +100,8 @@ structure FTransExt where piConstRule (expr f I : Expr) : SimpM (Option Simp.Step) := return none /-- Custom rule for transforming `fun x i j => f x i j` -/ piUncurryRule (expr f : Expr) : SimpM (Option Simp.Step) := return none + /-- Custom rule for transforming `fun x (is : Is) => uncurryN n (f x) is` where `uncurryN n (f x)` has type `Is → Y` -/ + piCurryNRule (expr f Is Y : Expr) (n : Nat) : SimpM (Option Simp.Step) := return none /-- Custom rule for transforming `fun x i => (f x i, g x i)` -/ piProdRule (expr f g : Expr) : SimpM (Option Simp.Step) := return none /-- Custom rule for transforming `fun x i => f (g x i) i` -/ @@ -109,9 +113,9 @@ structure FTransExt where /-- Custom rule for transforming `fun x i => f (g x i)` -/ piSimpleCompRule (expr f g : Expr) : SimpM (Option Simp.Step) := return none /-- Custom rule for transforming `fun x i => f x (h i)` when `h` has inverse -/ - piInvRule (expr : Expr) (data : PiInvData) : SimpM (Option Simp.Step) := return none + piInvRule (expr f : Expr) (inv : FullInverse) : SimpM (Option Simp.Step) := return none /-- Custom rule for transforming `fun x i => f x (h i)` when `h` has left inverse -/ - piLInvRule (expr : Expr) (data : PiLInvData) : SimpM (Option Simp.Step) := return none + piRInvRule (expr f : Expr) (rinv : RightInverse) : SimpM (Option Simp.Step) := return none /-- Custom discharger for this function transformation -/ discharger : Expr → SimpM (Option Expr) @@ -193,24 +197,56 @@ def getFTransFun? (e : Expr) : CoreM (Option Expr) := do initialize registerTraceClass `trace.Tactic.ftrans.new_property +structure FTransRule where + -- ftransName : Name + -- constName : Name + ruleName : Name + priority : Nat := 1000 + /-- Set of active argument indices in this rule + For example: + - rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) (g x)) = ...` has `argIds = #[4,5]` + - rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) y) = ...` has `argIds = #[4]` -/ + argIds : ArraySet Nat + /-- Set of trailing argument indices in this rule + For example: + - rule `∂ (fun x i => @getElem _ _ _ _ _ (f x) i dom ` has `piArgs = #[6]` + - rule `∂ (fun f x => @Function.invFun _ _ _ f x` has `piArgs = #[4]` + - rule `∂ (fun x => (f x) + (g x)` has `piArgs = #[]` + -/ + piIds : ArraySet Nat + +def FTransRule.cmp (a b : FTransRule) : Ordering := + match a.piIds.lexOrd b.piIds with + | .lt => .lt + | .gt => .gt + | .eq => + match a.argIds.lexOrd b.argIds with + | .lt => .lt + | .gt => .gt + | .eq => + match compare a.priority b.priority with + | .lt => .lt + | .gt => .gt + | .eq => a.ruleName.quickCmp b.ruleName + local instance : Ord Name := ⟨Name.quickCmp⟩ /-- -This holds a collection of property theorems for a fixed constant +This holds a collection of function transformation rules for a fixed constant -/ -def FTransRules := Std.RBMap Name (Std.RBSet Name compare /- maybe (Std.RBSet SimTheorem ...) -/) compare +def FTransRules := Std.RBMap Name (Std.RBSet FTransRule FTransRule.cmp) compare namespace FTransRules instance : Inhabited FTransRules := by unfold FTransRules; infer_instance - instance : ToString FTransRules := ⟨fun s => toString (s.toList.map fun (n,r) => (n,r.toList))⟩ + -- instance : ToString FTransRules := ⟨fun s => toString (s.toList.map fun (n,r) => (n,r.toList))⟩ variable (fp : FTransRules) - def insert (property : Name) (thrm : Name) : FTransRules := - fp.alter property (λ p? => + def insert (ftransName : Name) (rule : FTransRule) : FTransRules := + fp.alter ftransName (λ p? => match p? with - | some p => some (p.insert thrm) - | none => some (Std.RBSet.empty.insert thrm)) + | some p => some (p.insert rule) + | none => some (Std.RBSet.empty.insert rule)) def empty : FTransRules := Std.RBMap.empty @@ -251,36 +287,111 @@ To register function transformation call: ``` where is name of the function transformation and is corresponding `FTrans.Info`. " - let .some funName ← getFunHeadConst? f - | throwError "Function being transformed is in invalid form!" + + -- in rare cases `f` is not a function + -- for example this is case for monadic `fwdDerivValM` + if ¬f.isLambda then + let .const funName _ := f.getAppFn + | throwError "Function being transformed is in invalid form! The head of {← ppExpr f} is not a constant but it is {f.ctorName}!" + + if (← inferType f).isForall then + throwError "Function being transformed is in invalid form! Function has to appear in fully applied form!" + + let ftransRule : FTransRule := { + ruleName := ruleName + argIds := #[].toArraySet + piIds := #[].toArraySet + } + + FTransRulesExt.insert funName (FTransRules.empty.insert transName ftransRule) + + else + + lambdaTelescope f fun xs b => do + let .some x := xs[0]? + | throwError "Function being transformed is in invalid form! It has to be a lambda function!" + let .const funName _ := b.getAppFn + | throwError "Function being transformed is in invalid form! The head of {← ppExpr b} is not a constant but it is {b.ctorName}!" + if xs.size > 2 then + throwError "Function being transformed is in invalid form! Only one trailing argument is currently supported!" + + let xId := x.fvarId! + + let arity ← getConstArity funName + let args := b.getAppArgs + + if args.size ≠ arity then + throwError "Function being transformed is in invalid form! Function has to appear in fully applied form!" + + let argIds := + args.mapIdx (fun i arg => if arg.containsFVar xId then .some i.1 else none) + |>.filterMap id + + let piIds ← + if let .some y := xs[1]? then + let yId := y.fvarId! + let piArgs := + args.mapIdx (fun i arg => if arg.containsFVar yId then .some (i.1,arg) else none) + |>.filterMap id + if piArgs.size ≠ 1 then + throwError "Function being transformed is in invalid form! Trailing argument `{← ppExpr y}` can appear in only one argument, but it appears in `{← piArgs.mapM (fun (_,arg) => ppExpr arg)}`" + pure (piArgs.map (fun (i,_) => i)) + else + pure #[] + + let argNames ← getConstArgNames funName (fixAnonymousNames := true) + let depName := + argIds.map (fun i => argNames[i]?.getD default) + |>.foldl (·++·.toString) "" + + let piName := + piIds.map (fun i => argNames[i]?.getD default) + |>.foldl (·++·.toString) "" + + let argSuffix := + "arg" ++ if depName ≠ "" then "_" ++ depName else "" + ++ if piName ≠ "" then "_" ++ piName else "" + + let suggestedRuleName := + funName |>.append argSuffix + |>.append (transName.getString.append "_rule") + + if (← getBoolOption `linter.ftransDeclName true) && + ¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then + logWarning s!"suggested name for this rule is {suggestedRuleName}" + + let ftransRule : FTransRule := { + ruleName := ruleName + argIds := argIds.toArraySet + piIds := piIds.toArraySet + } + + FTransRulesExt.insert funName (FTransRules.empty.insert transName ftransRule) + ) - let depArgIds := - match f with - | .lam _ _ body _ => - body.getAppArgs - |>.mapIdx (fun i arg => if arg.hasLooseBVars then Option.some i.1 else none) - |>.filterMap id - | _ => #[f.getAppNumArgs] - let argNames ← getConstArgNames funName (fixAnonymousNames := true) - let depNames := depArgIds.map (fun i => argNames[i]?.getD default) +def getFTransRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do - let argSuffix := "arg_" ++ depNames.foldl (·++·.toString) "" + let .some rules ← FTransRulesExt.find? funName + | return #[] - let suggestedRuleName := - funName |>.append argSuffix - |>.append (transName.getString.append "_rule") + let .some rules := rules.find? ftransName + | return #[] + let rules : List SimpTheorem ← rules.toList.filterMapM fun r => do + if r.piIds.size ≠ 0 then + return none + else + return .some { + proof := mkConst r.ruleName + origin := .decl r.ruleName + rfl := false + } - if (← getBoolOption `linter.ftransDeclName true) && - ¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then - logWarning s!"suggested name for this rule is {suggestedRuleName}" + return rules.toArray - FTransRulesExt.insert funName (FTransRules.empty.insert transName ruleName) - ) -open Meta in -def getFTransRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do +def getFTransPiRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do let .some rules ← FTransRulesExt.find? funName | return #[] @@ -288,12 +399,15 @@ def getFTransRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do let .some rules := rules.find? ftransName | return #[] - let rules : List SimpTheorem ← rules.toList.mapM fun r => do - return { - proof := mkConst r - origin := .decl r - rfl := false - } + let rules : List SimpTheorem ← rules.toList.filterMapM fun r => do + if r.piIds.size = 0 then + return none + else + return .some { + proof := mkConst r.ruleName + origin := .decl r.ruleName + rfl := false + } return rules.toArray diff --git a/SciLean/Tactic/StructuralInverse.lean b/SciLean/Tactic/StructuralInverse.lean index 905d658b..8380d50e 100644 --- a/SciLean/Tactic/StructuralInverse.lean +++ b/SciLean/Tactic/StructuralInverse.lean @@ -1,7 +1,6 @@ import SciLean.Tactic.StructureDecomposition import SciLean.Tactic.LetNormalize - -import Mathlib.Logic.Function.Basic +import SciLean.Data.Function namespace SciLean.Meta @@ -164,6 +163,7 @@ structure FullInverse where {Y : Q(Type v)} (f : Q($X → $Y)) (invFun : Q($Y → $X)) + (is_inv : Q(Function.Inverse $invFun $f)) open Qq /-- @@ -250,14 +250,14 @@ def structuralInverse (f : Expr) : MetaM (Option (FunctionInverse × Array MVarI let f : Q($X → $Y) := f let b := (← mkAppM' xmk eqInv.xVals).headBeta - if eqInv.unresolvedXVars.size = 0 then let invFun : Q($Y → $X) ← mkLambdaFVars (#[y] ++ eqInv.letVars) b let invFun ← Meta.LetNormalize.letNormalize invFun {removeLambdaLet:=false} + let is_inv ← mkSorry q(Function.Inverse $invFun $f) false let finv : FullInverse := { - u := u, v := v, X := X, Y := Y, f := f, invFun := invFun + u := u, v := v, X := X, Y := Y, f := f, invFun := invFun, is_inv := is_inv } return .some (.full finv, goals) diff --git a/SciLean/Tactic/StructureDecomposition.lean b/SciLean/Tactic/StructureDecomposition.lean index 7e1ee591..91f6c408 100644 --- a/SciLean/Tactic/StructureDecomposition.lean +++ b/SciLean/Tactic/StructureDecomposition.lean @@ -7,6 +7,43 @@ set_option linter.unusedVariables false open Lean Meta Qq +/-- Is it structure containing only plain data i.e. no propositions, no types, no dependent types, no functions +-/ +def simpleDataStructure (structName : Name) : MetaM Bool := do + + let ctor := getStructureCtor (← getEnv) structName + + let .some info := getStructureInfo? (← getEnv) structName + | return false + + for finfo in info.fieldInfo do + let pinfo ← getConstInfo finfo.projFn + let stop ← forallTelescope pinfo.type fun xs b => do + if xs.size ≠ ctor.numParams + 1 then -- functions + pure true + else if b.isSort then -- types + pure true + else if (← isProp b) then -- proposition + pure true + else if (b.containsFVar xs[ctor.numParams]!.fvarId!) then -- dependent types + pure true + else + pure false + + if stop then + return false + + return true + + +private def buildMk (mk : Expr) (mks : List Expr) (vars vals : Array Expr) : MetaM Expr := + match mks with + | [] => mkLambdaFVars vars (mkAppN mk vals) + | mk' :: mks' => + lambdaTelescope mk' fun xs b => + buildMk mk mks' (vars++xs) (vals.push b) + + /-- Decomposes an element `e` of possible nested structure and returns a function put it back together. For example, calling this function on `x : (Nat×Nat)×Nat` returns `(#[x.1.1, x.1.2, x.1], fun a b c => ((a,b),c))` @@ -18,6 +55,9 @@ partial def splitStructureElem (e : Expr) : MetaM (Array Expr × Expr) := do let .const structName lvls := E.getAppFn' | return (#[e], idE) + if ¬(← simpleDataStructure structName) then + return (#[e], idE) + let .some info := getStructureInfo? (← getEnv) structName | return (#[e], idE) @@ -34,57 +74,55 @@ partial def splitStructureElem (e : Expr) : MetaM (Array Expr × Expr) := do mkAppM projFn #[e] >>= reduceProjOfCtor) let (eis,mks) := (← eis.mapM splitStructureElem).unzip - + -- this implementation of combining `mks` together works but it is probably not very efficient let mk := mkAppN (.const ctorVal.name lvls) E.getAppArgs - let mk ← mks.foldlM (init:=mk) - (fun mk mki => do - forallTelescope (← inferType mki) fun xs _ => do - let mk ← mkAppM' mk #[(←mkAppM' mki xs).headBeta] - forallTelescope (← inferType mk) fun ys _ => do - mkLambdaFVars (ys++xs) (←mkAppM' mk ys).headBeta) + let mk ← buildMk mk mks.toList #[] #[] return (eis.flatten, mk) + /-- Decomposes an element `e` that is a nested application of constructors For example, calling this function on `x : (Nat×Nat)×Nat` returns `(#[x.1.1, x.1.2, x.1], fun a b c => ((a,b),c))` -/ -partial def splitByCtors (e : Expr) : MetaM (Array Expr × Expr) := do +partial def splitByCtors (e : Expr) : MetaM (Array Expr × Array Expr × Expr) := do + let E ← inferType e let idE := .lam `x E (.bvar 0) default let .const structName lvls := E.getAppFn' - | return (#[e], idE) + | return (#[e], #[idE], idE) let .some info := getStructureInfo? (← getEnv) structName - | return (#[e], idE) + | return (#[e], #[idE], idE) let ctorVal := getStructureCtor (← getEnv) structName - if E.getAppNumArgs != ctorVal.numParams then - return (#[e], idE) + let fn := e.getAppFn + let args := e.getAppArgs' - if ctorVal.numFields ≤ 1 then - return (#[e], idE) - - let eis ← info.fieldNames.mapM (fun fname => do - let projFn := getProjFnForField? (← getEnv) structName fname |>.get! - mkAppM projFn #[e] >>= reduceProjOfCtor) + if fn.constName? ≠ .some ctorVal.name then + return (#[e], #[idE], idE) + + if args.size ≠ ctorVal.numParams + ctorVal.numFields then + return (#[e], #[idE], idE) - let (eis,mks) := (← eis.mapM splitStructureElem).unzip + let mk := mkAppN fn args[0:ctorVal.numParams] - -- this implementation of combining `mks` together works but it is probably not very efficient - let mk := mkAppN (.const ctorVal.name lvls) E.getAppArgs - let mk ← mks.foldlM (init:=mk) - (fun mk mki => do - forallTelescope (← inferType mki) fun xs _ => do - let mk ← mkAppM' mk #[(←mkAppM' mki xs).headBeta] - forallTelescope (← inferType mk) fun ys _ => do - mkLambdaFVars (ys++xs) (←mkAppM' mk ys).headBeta) - - return (eis.flatten, mk) + let fields : Array _ := args[ctorVal.numParams : ctorVal.numParams + ctorVal.numFields] + let (eis, tmp) := (← fields.mapM splitByCtors).unzip + let (projs, mks) := tmp.unzip + + let projs := projs + |>.mapIdx (fun idx projs' => + projs'.map (fun proj' => Expr.lam `x E (proj'.app (Expr.proj structName idx (.bvar 0))).headBeta default)) + |>.flatten + + let mk ← buildMk mk mks.toList #[] #[] + + return (eis.flatten, projs, mk) structure IsDecomposition (p₁ : X → X₁) (p₂ : X → X₂) (q : X₁ → X₂ → X) : Prop where @@ -92,6 +130,7 @@ structure IsDecomposition (p₁ : X → X₁) (p₂ : X → X₂) (q : X₁ → mk_proj₁ : ∀ x₁ x₂, p₁ (q x₁ x₂) = x₁ mk_proj₂ : ∀ x₁ x₂, p₂ (q x₁ x₂) = x₂ + structure StructureDecomposition where {u v w : Level} X : Q(Type u) diff --git a/test/basic_gradients.lean b/test/basic_gradients.lean index e2c4c2d4..442b4a74 100644 --- a/test/basic_gradients.lean +++ b/test/basic_gradients.lean @@ -73,7 +73,7 @@ example = fun x => ⊞ _ => (1:K) := by - (conv => lhs; autodiff) + (conv => lhs; unfold scalarGradient; ftrans only; autodiff) example : (∇ (x : Fin 10 → K), ∑ i, ‖x i‖₂²) @@ -82,10 +82,103 @@ example by (conv => lhs; autodiff) -example (A : Fin 5 → Fin 10 → K) - : (∇ (x : Fin 10 → K), fun i => ∑ j, A i j * x j) +set_option trace.Meta.Tactic.simp.rewrite true in +example (A : Idx 5 → Idx 10 → K) + : (∇ (x : K ^ Idx 10), fun i => ∑ j, A i j * x[j]) = - fun _ dy j => ∑ i, A i j * dy i := + fun _ dy => ⊞ j => ∑ i, A i j * dy i := by + (conv => + lhs + unfold gradient + ftrans only) + + +example + : (∇ (x : Fin 10 → K), fun i => x i) + = + fun _ dx => dx := +by + (conv => lhs; autodiff) + + +example + : (∇ (x : Fin 5 → Fin 10 → K), fun i j => x i j) + = + fun _ dx => dx := +by + (conv => lhs; autodiff) + + +example + : (∇ (x : Fin 5 → Fin 10 → Fin 15→ K), fun i j k => x i j k) + = + fun _ dx => dx := +by + (conv => lhs; autodiff) + + +example + : (∇ (x : Fin 5 → Fin 10 → Fin 15→ K), fun k i j => x i j k) + = + fun _ dx i j k => dx k i j := +by + (conv => lhs; autodiff) + + +-- TODO remove `hf'` assumption, is should be automatically deduced from `hf` once #23 is resolved +example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, HasAdjDiff K (f · i j k)) + (hf' : HasAdjDiff K f) + : (∇ (x : X), fun k i j => f x i j k) + = + fun x dy => + let ydf := <∂ f x + ydf.2 fun i j k => dy k i j := +by + (conv => lhs; autodiff) + + + +set_option trace.Meta.Tactic.simp.rewrite true in +example + : (<∂ (x : K ^ Idx 10), fun (ij : Idx 5 × Idx 10) => x[ij.snd]) + = + 0 := +by + conv => + lhs + ftrans only + let_normalize + + +set_option trace.Meta.Tactic.simp.rewrite +example + : (∇ (x : K ^ Idx 10), fun i => x[i]) + = + fun _ dx => ⊞ i => dx i := +by + (conv => lhs; unfold gradient; ftrans only) + + +example + : (∇ (x : K ^ (Idx 10 × Idx 5)), fun i j => x[(i,j)]) + = + fun _ dx => ⊞ ij => dx ij.1 ij.2 := +by (conv => lhs; autodiff) + +example + : (∇ (x : K ^ (Idx 5 × Idx 10 × Idx 15)), fun i j k => x[(k,i,j)]) + = + fun _ dx => ⊞ kij => dx kij.2.1 kij.2.2 kij.1 := +by + (conv => lhs; autodiff) + + +example + : (∇ (x : Fin 5 → Fin 10 → Fin 15→ K), fun k i j => x i j k) + = + fun _ dx i j k => dx k i j := +by + (conv => lhs; autodiff)