From d18a62b6d894de1820038917ea73bbfcebf179db Mon Sep 17 00:00:00 2001 From: lecopivo Date: Wed, 27 Sep 2023 14:20:40 -0400 Subject: [PATCH] some modifications to `fprop` in fvar app case, partial solution to #23 --- .../Core/FunctionPropositions/HasAdjDiff.lean | 4 +- .../FunctionPropositions/HasAdjDiffAt.lean | 5 +- .../Core/FunctionTransformations/Adjoint.lean | 4 +- SciLean/Core/Meta/GenerateFwdCDeriv.lean | 12 +- SciLean/Core/Meta/GenerateRevCDeriv.lean | 7 +- SciLean/Lean/Meta/Basic.lean | 10 + SciLean/Tactic/AnalyzeLambda.lean | 309 ++++++++++++++++++ SciLean/Tactic/FProp/Basic.lean | 97 +++++- test/generate_ftrans.lean | 21 ++ test/issues/23.lean | 44 ++- 10 files changed, 474 insertions(+), 39 deletions(-) create mode 100644 SciLean/Tactic/AnalyzeLambda.lean diff --git a/SciLean/Core/FunctionPropositions/HasAdjDiff.lean b/SciLean/Core/FunctionPropositions/HasAdjDiff.lean index e5650c70..de62719d 100644 --- a/SciLean/Core/FunctionPropositions/HasAdjDiff.lean +++ b/SciLean/Core/FunctionPropositions/HasAdjDiff.lean @@ -362,10 +362,10 @@ by @[fprop] theorem SciLean.EnumType.sum.arg_f.HasAdjDiff_rule - (f : X → ι → Y) (hf : HasAdjDiff K f) + (f : X → ι → Y) (hf : ∀ i, HasAdjDiff K (f · i)) : HasAdjDiff K (fun x => ∑ i, f x i) := by - have ⟨_,_⟩ := hf + have := fun i => (hf i).1 constructor; fprop; ftrans; fprop diff --git a/SciLean/Core/FunctionPropositions/HasAdjDiffAt.lean b/SciLean/Core/FunctionPropositions/HasAdjDiffAt.lean index 98bedeb5..ec96a1a3 100644 --- a/SciLean/Core/FunctionPropositions/HasAdjDiffAt.lean +++ b/SciLean/Core/FunctionPropositions/HasAdjDiffAt.lean @@ -368,10 +368,11 @@ by @[fprop] theorem SciLean.EnumType.sum.arg_f.HasAdjDiffAt_rule - (f : X → ι → Y) (x : X) (hf : HasAdjDiffAt K f x) + (f : X → ι → Y) (x : X) (hf : ∀ i, HasAdjDiffAt K (f · i) x) : HasAdjDiffAt K (fun x => ∑ i, f x i) x := by - have ⟨_,_⟩ := hf + have := fun i => (hf i).1 + have := fun i => (hf i).2 constructor; fprop; ftrans; fprop diff --git a/SciLean/Core/FunctionTransformations/Adjoint.lean b/SciLean/Core/FunctionTransformations/Adjoint.lean index 82f8e817..38bb8955 100644 --- a/SciLean/Core/FunctionTransformations/Adjoint.lean +++ b/SciLean/Core/FunctionTransformations/Adjoint.lean @@ -113,7 +113,7 @@ by open BigOperators in theorem pi_rule - (f : X → (i : ι) → E i) (hf : IsContinuousLinearMap K f) + (f : X → (i : ι) → E i) (hf : ∀ i, IsContinuousLinearMap K (f · i)) : ((fun (x : X) =>L[K] fun (i : ι) => f x i) : X →L[K] PiLp 2 E)† = (fun x' =>L[K] ∑ i, (fun x =>L[K] f x i)† (x' i)) @@ -463,7 +463,7 @@ by open BigOperators in @[ftrans] theorem Finset.sum.arg_f.adjoint_rule - (f : X → ι → Y) (hf : IsContinuousLinearMap K f) + (f : X → ι → Y) (hf : ∀ i, IsContinuousLinearMap K (f · i)) : (fun x =>L[K] ∑ i, f x i)† = (fun y =>L[K] ∑ i, (fun x =>L[K] f x i)† y) := diff --git a/SciLean/Core/Meta/GenerateFwdCDeriv.lean b/SciLean/Core/Meta/GenerateFwdCDeriv.lean index b5fe0d8d..a46ebf49 100644 --- a/SciLean/Core/Meta/GenerateFwdCDeriv.lean +++ b/SciLean/Core/Meta/GenerateFwdCDeriv.lean @@ -57,7 +57,11 @@ def generateFwdCDeriv (constName : Name) (mainNames trailingNames : Array Name) mkLocalDecls (n:=TermElabM) (mainNames.map (fun n => n.appendBefore "h")) .default - (← mainArgs.mapM (fun arg => mkAppM ``IsDifferentiable #[K,arg])) + (← mainArgs.mapM (fun arg => do + lambdaTelescope (← etaExpand arg) fun xs b => do + let f := (← mkLambdaFVars #[xs[0]!] b).eta + let prop ← mkAppM ``IsDifferentiable #[K,f] + mkForallFVars xs[1:] prop)) withLocalDecls decls fun mainArgProps => do @@ -73,11 +77,11 @@ def generateFwdCDeriv (constName : Name) (mainNames trailingNames : Array Name) let (rhs, proof) ← elabConvRewrite rhs conv let isDiffProof ← elabProof isDiff tac - let .lam _ _ (.lam _ _ rhsBody _) _ := rhs - | throwError "unexpected result after function transformation, expecting `fun w dw => ...` but got\n{←ppExpr rhs}" + -- let .lam _ _ (.lam _ _ rhsBody _) _ := rhs + -- | throwError "unexpected result after function transformation, expecting `fun w dw => ...` but got\n{←ppExpr rhs}" withLocalDecl `dw .default W fun dw => do - let rhsBody := rhsBody.instantiate #[dw,w] + let rhsBody := (mkAppN rhs #[w,dw]).headBeta -- rhsBody.instantiate #[dw,w] let dargs ← mainArgs.mapM (fun arg => mkAppM ``fwdCDeriv #[K,arg,w,dw]) let fwdDerivFun ← diff --git a/SciLean/Core/Meta/GenerateRevCDeriv.lean b/SciLean/Core/Meta/GenerateRevCDeriv.lean index 0df714de..0bb9feda 100644 --- a/SciLean/Core/Meta/GenerateRevCDeriv.lean +++ b/SciLean/Core/Meta/GenerateRevCDeriv.lean @@ -58,7 +58,12 @@ def generateRevCDeriv (constName : Name) (mainNames trailingNames : Array Name) mkLocalDecls (n:=TermElabM) (mainNames.map (fun n => n.appendBefore "h")) .default - (← mainArgs.mapM (fun arg => mkAppM ``HasAdjDiff #[K,arg])) + (← mainArgs.mapM (fun arg => do + lambdaTelescope (← etaExpand arg) fun xs b => do + let f := (← mkLambdaFVars #[xs[0]!] b).eta + let prop ← mkAppM ``HasAdjDiff #[K,f] + mkForallFVars xs[1:] prop)) + withLocalDecls decls fun mainArgProps => do diff --git a/SciLean/Lean/Meta/Basic.lean b/SciLean/Lean/Meta/Basic.lean index d63c404e..670ae3b8 100644 --- a/SciLean/Lean/Meta/Basic.lean +++ b/SciLean/Lean/Meta/Basic.lean @@ -103,6 +103,16 @@ def getExplicitArgs (e : Expr) : MetaM (Option (Name×Array Expr)) := do return (funName, explicitArgs) +/-- Eta expansion, but adds at most `n` binders +-/ +def etaExpandN (e : Expr) (n : Nat) : MetaM Expr := + withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs[0:n] (mkAppN e xs[0:n]) + +/-- Eta expansion, it also beta reduces the body +-/ +def etaExpand' (e : Expr) : MetaM Expr := + withDefault do forallTelescopeReducing (← inferType e) fun xs _ => mkLambdaFVars xs (mkAppN e xs).headBeta + /-- Same as `mkAppM` but does not leave trailing implicit arguments. diff --git a/SciLean/Tactic/AnalyzeLambda.lean b/SciLean/Tactic/AnalyzeLambda.lean new file mode 100644 index 00000000..da0ca045 --- /dev/null +++ b/SciLean/Tactic/AnalyzeLambda.lean @@ -0,0 +1,309 @@ +import Lean +import Qq + +import SciLean.Lean.Meta.Basic +import SciLean.Data.ArraySet + +open Lean Meta + +namespace SciLean + + +inductive HeadFunInfo + | const (constName : Name) (arity : Nat) + | fvar (id : FVarId) (arity : Nat) + | bvar (i : Nat) (arity : Nat) + + +def HeadFunInfo.arity (info : HeadFunInfo) : Nat := + match info with + | .const _ n => n + | .fvar _ n => n + | .bvar _ n => n + +def HeadFunInfo.ctorName (info : HeadFunInfo) : Name := + match info with + | .const _ _ => ``const + | .fvar _ _ => ``fvar + | .bvar _ _ => ``bvar + +def HeadFunInfo.isFVar (info : HeadFunInfo) (id : FVarId) : Bool := + match info with + | .const _ _ => false + | .fvar id' _ => id == id' + | .bvar _ _ => false + + +inductive MainArgCase where + /-- there are no main arguments -/ + | noMainArg + /-- Main arguments are just `x` i.e. `x = (a'₁, ..., a'ₖ)` where `a' = mainIds.map (fun i => aᵢ)` are main arguments -/ + | trivialUncurried + /-- Main arguments are just functions of `x` and do not depend on `yⱼ` + + This allows to write the lambda function as composition + ``` + fun x y₀ ... yₙ₋₁ => f a₀ ... aₘ₋₁ + = + f' ∘ g' + = + (fun (a'₁, ..., a'ₖ) y₀ ... yₙ₋₁ => f a₀ ... aₘ₋₁) ∘ (fun x => (a'₁, ..., a'ₖ)) + ``` + where the function `f'` is in `MainArgCase.trivialUncurried` case -/ + | nonTrivailNoTrailing + /-- Main arguments depend on `x` and `yⱼ` -/ + | nonTrivialWithTrailing +deriving DecidableEq, Repr + +inductive TrailingArgCase where + /-- there are no trailing arguments -/ + | noTrailingArg + + /-- Trailing arguments are exactly equal to `yⱼ` + i.e. `yⱼ = a''ⱼ` where `a'' := trailingArgs.map (fun i => aᵢ)` -/ + | trivial + + /-- Traling arguments are just `y₀` i.e. `n=1` and `y₀ = (a''₁, ..., a''ₖ)` + where `a'' := trailingIds.map (fun i => aᵢ)` + + It is guaranteed that `k>1`, when `k=1` then we are in `TrailingArgCase.trivial` case -/ + | trivialUncurried + + /-- Trailing arguments are non trivial function of `yⱼ` + + In this case we usually want to find inverse map `h` mapping `a''` to `yⱼ` + ``` + fun x y₀ ... yₙ₋₁ => f a₀ ... aₘ₋₁ + = + (·∘h) ∘ f' + = + (·∘h) ∘ (fun x a''₁ ... a''ₖ => f a₀ ... aₘ₋₁ + ``` + where the function `f'` is now in `TrailingArgCase.trivial` case + (constructing such `f'` is a bit tricky as it potentially requires to also + use `h` to replace `yⱼ` with `a''` in main arguments) + -/ + | nonTrivial +deriving DecidableEq, Repr + + +/-- Info about lambda function `fun x y₀ ... yₙ₋₁ => f a₀ ... aₘ₋₁` +-/ +structure LambdaInfo where + -- /-- the lambda function itself -/ + -- fn : Expr + /-- number of lambda binders -/ + arity : Nat -- n+1 + /-- number of function arguments in the body -/ + argNum : Nat -- m + /-- info on the head function `f` -/ + headFunInfo : HeadFunInfo + /-- Set of argument indices `i` saying that `aᵢ` depends on `x`, they might depend `yⱼ` too -/ + mainIds : ArraySet Nat + /-- Set of argument indices `i` saying that `aᵢ` depends on at least one of `yⱼ` but not on `x` -/ + trailingIds : ArraySet Nat + -- /-- Set of argument indices `i` saying that `aᵢ` does not depend of `x` or `yⱼ` + -- This is a complement of `mainIds ∪ trailinIds` -/ + -- unusedIds : ArraySet Nat + mainArgCase : MainArgCase + trailingArgCase : TrailingArgCase + + +/-- Analyze head function `f` of lambda `fun x₁ ... xₙ => f ...` where `xs = #[x₁, ..., xₙ]` + +Returns `HeadFunInfo.bvar` if the head function is fvar and one of `xs` +-/ +private def analyzeHeadFun (fn : Expr) (xs : Array Expr) : MetaM HeadFunInfo := do + match fn with + | .const name _ => + pure (.const name (← getConstArity name)) + | .fvar id => + let arity := (← inferType fn).forallArity + if let .some i := xs.findIdx? (fun x => x.fvarId! == id) then + pure (.bvar i arity) + else + pure (.fvar id arity) + | _ => throwError s!"invalid head function {← ppExpr fn}" + +/-- +Decompose function as `fun x i₁ ... iₙ => f (g x) (h i₁ ... iₙ)` + +`f = fun y₁ ... yₘ i₁ ... iₙ => f' y₁ ... yₘ` +-/ +partial def analyzeLambda (e : Expr) : MetaM LambdaInfo := do + IO.println s!"analyzing {← ppExpr e}" + lambdaTelescope e fun xs body => do + + -- if `body` is a projection turn it into application of projection function + let body := (← revertStructureProj body).headBeta + + + let fn := body.getAppFn' + let args := body.getAppArgs + + let fnInfo ← analyzeHeadFun fn xs + + let x := xs[0]! + let xId := x.fvarId! + -- let xName ← xId.getUserName + let ys := xs[1:].toArray + + let mut as' : Array Expr := #[] + let mut as'' : Array Expr := #[] + + let mut mainIds : Array Nat := #[] + let mut trailingIds : Array Nat := #[] + + let mut mainCase : MainArgCase := .noMainArg + let mut trailingCase : TrailingArgCase := .noTrailingArg + + for arg in args, i in [0:args.size] do + let ys' := ys.filter (fun y => arg.containsFVar y.fvarId!) + + if arg.containsFVar xId then + mainIds := mainIds.push i + as' := as'.push arg + if ys'.size ≠ 0 then + mainCase := .nonTrivialWithTrailing + else if ys'.size ≠ 0 then + trailingIds := trailingIds.push i + as'' := as''.push arg + + + -- determina main arg case + let a' ← mkProdElem as' + if as'.size ≠ 0 && mainCase ≠ .nonTrivialWithTrailing then + if (← isDefEq x a') then + mainCase := .trivialUncurried + else + mainCase := .nonTrivailNoTrailing + + -- determina trailing arg case + if as''.size ≠ 0 then + trailingCase := .nonTrivial + + if ys.size = as''.size then + if ← (Array.range ys.size).allM (fun i => isDefEq ys[i]! as''[i]!) then + trailingCase := .trivial + + if ys.size = 1 && as''.size > 1 then + let a'' ← mkProdElem as'' + if ← isDefEq ys[0]! a'' then + trailingCase := .trivialUncurried + + + + return { + arity := xs.size + argNum := args.size + headFunInfo := fnInfo + mainIds := mainIds.toArraySet + trailingIds := trailingIds.toArraySet + mainArgCase := mainCase + trailingArgCase := trailingCase + } + +open Qq + + +def LambdaInfo.print (info : LambdaInfo) : IO Unit := do + IO.println s!"arity: {info.arity}" + IO.println s!"argNum: {info.argNum}" + IO.println s!"headFunction ctor: {info.headFunInfo.ctorName}" + IO.println s!"headFunction arity: {info.headFunInfo.arity}" + IO.println s!"main ids: {info.mainIds}" + IO.println s!"trailing ids: {info.trailingIds}" + IO.println s!"main arg case: {repr info.mainArgCase}" + IO.println s!"trailing arg case: {repr info.trailingArgCase}" + + +#eval show MetaM Unit from do + withLocalDeclQ `f .default q(Float → Float → Float) fun f => do + let e := q(fun (x,y) (z,h) => $f h ($f z ($f x y))) + IO.println (← ppExpr e) + let e ← lambdaTelescope e fun xs b => do mkLambdaFVars xs (← whnf b) + IO.println (← ppExpr e) + + +#eval show MetaM Unit from do + withLocalDeclQ `x .default q(Float×Float) fun x => do + let e ← mkLambdaFVars #[x] (Expr.proj ``Prod 0 x) + let info ← analyzeLambda e + info.print + +#eval show MetaM Unit from do + let e := q(fun (x : Float × Float) => x.1) + let info ← analyzeLambda e + info.print + +#eval show MetaM Unit from do + let e := q(fun (x : (Fin 10 → Float) × (Fin 5 → Float)) => x.1) + let info ← analyzeLambda (← etaExpand' e) + info.print + +#eval show MetaM Unit from do + let e := q(fun (x : (Fin 10 → Float) × (Fin 5 → Float)) i => x.1 i) + let info ← analyzeLambda e + info.print + + +#eval show MetaM Unit from do + let e := q(fun (x : (Fin 10 × Fin 20 → Float) × (Fin 5 → Float)) i j => x.1 (i,j)) + let info ← analyzeLambda e + info.print + + +#eval show MetaM Unit from do + let e := q(fun (A : (Fin 10 → Fin 5 → Float)) (ij : Fin 10 × Fin 5) => A ij.1 ij.2) + let info ← analyzeLambda e + info.print + +#eval show MetaM Unit from do + let e := q(fun (A : (Fin 10 → Fin 5 → Float)) i j => A i j) + let info ← analyzeLambda e + info.print + + #eval show MetaM Unit from do + let e := q(fun (A : (Fin 10 → Fin 5 → Float)) j i => A i j) + let info ← analyzeLambda e + info.print + +#eval show MetaM Unit from do + let e := q(@Prod.fst (Fin 10 → Float) (Fin 5 → Float)) + let info ← analyzeLambda e + info.print + +#eval show MetaM Unit from do + let e := q(fun (x : (Fin 10 → Float) × (Fin 5 → Float)) (i : Fin 10 × Fin 5) => x.1 i.1) + let info ← analyzeLambda e + info.print + +#eval show MetaM Unit from do + + withLocalDeclQ `i₂ .default q(Fin 2) fun i₂ => do + withLocalDeclQ `i₃ .default q(Fin 3) fun i₃ => do + withLocalDeclQ `i₅ .default q(Fin 5) fun i₅ => do + let e := q(fun (A : (Fin 1 → Fin 2 → Fin 3 → Fin 4 → Fin 5 → Fin 6 → Float)) i₁ i₄ i₆ => A i₁ $i₂ $i₃ i₄ $i₅ i₆) + IO.println (← ppExpr e.eta) + let _ ← analyzeLambda e + + +#eval show MetaM Unit from do + + withLocalDeclQ `i₂ .default q(Fin 2) fun i₂ => do + withLocalDeclQ `i₃ .default q(Fin 3) fun i₃ => do + withLocalDeclQ `i₅ .default q(Fin 5) fun i₅ => do + let e := q(fun (A : (Fin 1 → Fin 2 → Fin 3 → Fin 4 → Fin 5 → Fin 6 → Float)) (i : Fin 1 × Fin 4 × Fin 6) => A i.1 $i₂ $i₃ i.2.1 $i₅ i.2.2) + + let _ ← analyzeLambda e + + +#eval show MetaM Unit from do + + withLocalDeclQ `i₂ .default q(Fin 2) fun i₂ => do + withLocalDeclQ `i₃ .default q(Fin 3) fun i₃ => do + withLocalDeclQ `i₅ .default q(Fin 5) fun i₅ => do + let e := q(fun (A : (Fin 1 → Fin 2 → Fin 3 → Fin 4 → Fin 5 → Fin 6 → Float)) (i : Fin 4 × Fin 1 × Fin 6) => A i.2.1 $i₂ $i₃ i.1 $i₅ i.2.2) + + let _ ← analyzeLambda e + diff --git a/SciLean/Tactic/FProp/Basic.lean b/SciLean/Tactic/FProp/Basic.lean index 51b55096..db13e837 100644 --- a/SciLean/Tactic/FProp/Basic.lean +++ b/SciLean/Tactic/FProp/Basic.lean @@ -1,4 +1,5 @@ import SciLean.Tactic.FProp.Init +import SciLean.Tactic.AnalyzeLambda open Lean Meta Qq @@ -129,7 +130,18 @@ def getLocalRules (fpropName : Name) : MetaM (Array SimpTheorem) := do return arr - +structure LocalRule where + fvar : FVarId + proof : Expr + mainIds : ArraySet Nat + trailingIds : ArraySet Nat + +def toFullyAppliedForm (f : Expr) : MetaM Expr := do + lambdaTelescope f fun xs b => do + let b ← whnf b + withDefault do forallTelescopeReducing (← inferType b) fun xs' _ => + mkLambdaFVars (xs++xs') (mkAppN b xs').headBeta + def tryLocalTheorems (e : Expr) (fpropName : Name) (ext : FPropExt) (fprop : Expr → FPropM (Option Expr)) : FPropM (Option Expr) := do @@ -142,6 +154,55 @@ def tryLocalTheorems (e : Expr) (fpropName : Name) (ext : FPropExt) return none +def getLocalRulesForFVar (fId : FVarId) (fpropName : Name) (ext : FPropExt) : MetaM (Array LocalRule) := do + + let mut arr : Array LocalRule := #[] + + let lctx ← getLCtx + for var in lctx do + if (var.kind = Lean.LocalDeclKind.auxDecl) then + continue + + let type ← instantiateMVars var.type + + let rule? : Option LocalRule ← + forallTelescopeReducing var.type fun xs type => do + if ¬(type.isAppOf' fpropName) then + return none + let .some f := ext.getFPropFun? type + | return none + let f ← toFullyAppliedForm f + let info ← analyzeLambda f + if (info.headFunInfo.isFVar fId) then + return .some { + fvar := var.fvarId + proof := var.toExpr + mainIds := info.mainIds + trailingIds := info.trailingIds + } + pure none + + let .some rule := rule? + | continue + + arr := arr.push rule + + return arr + +-- def tryLocalTheoremsForFVar (e : Expr) (fpropName : Name) (ext : FPropExt) + +-- (fprop : Expr → FPropM (Option Expr)) +-- : FPropM (Option Expr) := do + +-- let candidates ← getLocalRules fpropName + +-- for thm in candidates do +-- if let some proof ← tryTheorem?' e thm ext.discharger fprop then +-- return proof + +-- return none + + def unfoldFunHead? (e : Expr) : MetaM (Option Expr) := do lambdaLetTelescope e fun xs b => do @@ -201,19 +262,28 @@ def fvarAppCase (e : Expr) (fpropName : Name) (ext : FPropExt) (f : Expr) -- trivial case, this prevents an infinite loop if (← isDefEq f' f) then - -- this is a bit of a hack - if let .some (f', g') ← evalSplit f then - trace[Meta.Tactic.fprop.step] "fvar app case: decomposed into `({← ppExpr f'}) ∘ ({← ppExpr g'})`" - let step? ← - try - ext.compRule e f' g' - catch e => - pure none - let .some step := step? | pure () - return step + -- -- this is a bit of a hack + -- if let .some (f', g') ← evalSplit f then + -- trace[Meta.Tactic.fprop.step] "fvar app case: decomposed into `({← ppExpr f'}) ∘ ({← ppExpr g'})`" + -- let step? ← + -- try + -- ext.compRule e f' g' + -- catch e => + -- pure none + -- let .some step := step? | pure () + -- return step trace[Meta.Tactic.fprop.step] "fvar app case: trivial" - tryLocalTheorems e fpropName ext fprop + let step? ← tryLocalTheorems e fpropName ext fprop + + if let .some step := step? then + return step + + if let .some (_,Y) := (← inferType f).arrow? then + if Y.isForall then + return ← fprop (ext.replaceFPropFun e (← etaExpand f)) + + return none else trace[Meta.Tactic.fprop.step] "fvar app case: decomposed into `({← ppExpr f'}) ∘ ({← ppExpr g'})`" ext.compRule e f' g' @@ -409,6 +479,9 @@ mutual | .mvar _ => do fprop (← instantiateMVars e) + | .fvar _ => do + fprop (ext.replaceFPropFun e (← etaExpand f)) + | .proj typeName idx _ => do let .some info := getStructureInfo? (← getEnv) typeName | return none let .some projName := info.getProjFn? idx | return none diff --git a/test/generate_ftrans.lean b/test/generate_ftrans.lean index 9755f543..729c708e 100644 --- a/test/generate_ftrans.lean +++ b/test/generate_ftrans.lean @@ -108,6 +108,27 @@ def matmul (A : ι → κ → K) (x : κ → K) (i : ι) : K := ∑ j, A i j * prop_by unfold matmul; fprop trans_by unfold matmul; autodiff; autodiff + +-- set_option trace.Meta.Tactic.ftrans.step true in +-- #check +-- (∂> (x : Fin 10 → Fin 10 → K), x 1 4) +-- rewrite_by +-- ftrans only + + +-- #generate_fwdCDeriv matmul A x +-- prop_by unfold matmul; fprop +-- trans_by unfold matmul; autodiff; autodiff + +-- #generate_fwdCDeriv matmul A | i +-- prop_by unfold matmul; fprop +-- trans_by unfold matmul; ftrans only; autodiff; autodiff + +-- #generate_fwdCDeriv matmul x | i +-- prop_by unfold matmul; fprop +-- trans_by unfold matmul; ftrans only; autodiff; autodiff + + -- TODO: right name is not being generated!!! -- it should be `matmul.arg_A_i.revCDeriv` -- #check matmul.arg_A.revCDeriv diff --git a/test/issues/23.lean b/test/issues/23.lean index 5d038108..c787d24d 100644 --- a/test/issues/23.lean +++ b/test/issues/23.lean @@ -6,28 +6,40 @@ variable {K : Type _} [IsROrC K] {X : Type _} [Vec K X] --- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) --- : IsDifferentiable K f := by fprop --- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) --- : IsDifferentiable K (fun x => f x i j) := by fprop +example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) + : IsDifferentiable K f := by fprop --- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) --- : IsDifferentiable K (fun x => f x) := by fprop +example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) + : IsDifferentiable K (fun x => f x i j) := by fprop --- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) --- : IsDifferentiable K (fun x i j => f x i j) := by fprop +example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) + : IsDifferentiable K (fun x => f x) := by fprop + +example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) + : IsDifferentiable K (fun x i j => f x i j) := by fprop + +example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) + : IsDifferentiable K (fun x i j k => f x i j k) := by fprop --- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j k, IsDifferentiable K (f · i j k)) --- : IsDifferentiable K (fun x i j k => f x i j k) := by fprop +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j, IsDifferentiable K (f · i j)) +-- : IsDifferentiable K (fun x i => f x i) := by fprop +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i j, IsDifferentiable K (f · i j)) (j k) +-- : IsDifferentiable K (fun x i => f x i j k) := by fprop -example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : IsDifferentiable K f) (i j k) - : IsDifferentiable K (fun x => f x i j k) := by fprop +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i k, IsDifferentiable K (fun j => f · i j k)) (j k) +-- : IsDifferentiable K (fun x i => f x i j k) := by fprop + +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : ∀ i k, IsDifferentiable K (fun j => f · i j k)) (j k) +-- : IsDifferentiable K (fun x i j => f x i j) := by fprop + +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : IsDifferentiable K f) (i j k) +-- : IsDifferentiable K (fun x => f x i j k) := by fprop -example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : IsDifferentiable K f) (j k) - : IsDifferentiable K (fun x i => f x i j k) := by fprop +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : IsDifferentiable K f) (j k) +-- : IsDifferentiable K (fun x i => f x i j k) := by fprop -example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : IsDifferentiable K f) (j) - : IsDifferentiable K (fun x i k => f x i j k) := by fprop +-- example (f : X → Fin 5 → Fin 10 → Fin 15→ K) (hf : IsDifferentiable K f) (j) +-- : IsDifferentiable K (fun x i k => f x i j k) := by fprop