From 31ca09d794720fda3afd07b83a10a8a63fd9493a Mon Sep 17 00:00:00 2001 From: lecopivo Date: Thu, 21 Sep 2023 13:36:31 -0400 Subject: [PATCH] reworked how suggested name works for `ftrans` attribute --- .../FunctionPropositions/Diffeomorphism.lean | 2 +- .../Core/FunctionTransformations/CDeriv.lean | 2 +- SciLean/Core/Meta/GenerateRevCDeriv.lean | 2 +- SciLean/Core/Monads/FwdDerivMonad.lean | 2 +- SciLean/Core/Monads/Id.lean | 2 +- SciLean/Data/ArrayType/Properties.lean | 2 +- SciLean/Lean/Array.lean | 3 + SciLean/Lean/Expr.lean | 7 + SciLean/Lean/Meta/Basic.lean | 2 +- SciLean/Tactic/FTrans/Init.lean | 198 ++++++++++-------- 10 files changed, 126 insertions(+), 96 deletions(-) diff --git a/SciLean/Core/FunctionPropositions/Diffeomorphism.lean b/SciLean/Core/FunctionPropositions/Diffeomorphism.lean index cd4a98aa..2fbbecbc 100644 --- a/SciLean/Core/FunctionPropositions/Diffeomorphism.lean +++ b/SciLean/Core/FunctionPropositions/Diffeomorphism.lean @@ -305,7 +305,7 @@ by -- Which rule is preferable? This one or the second one? -- Probably the second as it has Function.inv fully applied @[ftrans] -theorem Function.invFun.arg_f.cderiv_rule +theorem Function.invFun.arg_f_a1.cderiv_rule (f : X → Y → Z) (hf : ∀ x, Diffeomorphism K (f x)) (hf' : IsDifferentiable K (fun xy : X×Y => f xy.1 xy.2)) diff --git a/SciLean/Core/FunctionTransformations/CDeriv.lean b/SciLean/Core/FunctionTransformations/CDeriv.lean index 806e7a8b..3f2af6fe 100644 --- a/SciLean/Core/FunctionTransformations/CDeriv.lean +++ b/SciLean/Core/FunctionTransformations/CDeriv.lean @@ -350,7 +350,7 @@ by unfold Function.comp; ftrans @[ftrans] -theorem Function.comp.arg_fg.cderiv_rule +theorem Function.comp.arg_fg_a0.cderiv_rule (f : W → Y → Z) (g : W → X → Y) (hf : IsDifferentiable K (fun wy : W×Y => f wy.1 wy.2)) (hg : IsDifferentiable K (fun wx : W×X => g wx.1 wx.2)) diff --git a/SciLean/Core/Meta/GenerateRevCDeriv.lean b/SciLean/Core/Meta/GenerateRevCDeriv.lean index 7ac55d50..d4b41a3e 100644 --- a/SciLean/Core/Meta/GenerateRevCDeriv.lean +++ b/SciLean/Core/Meta/GenerateRevCDeriv.lean @@ -355,7 +355,7 @@ def generateHasAdjDiff (constName : Name) (mainNames trailingNames : Array Name) addDecl (.thmDecl info) - + FProp.funTransRuleAttr.attr.add name (← `(attr|fprop)) .global open Lean.Parser.Tactic.Conv diff --git a/SciLean/Core/Monads/FwdDerivMonad.lean b/SciLean/Core/Monads/FwdDerivMonad.lean index cc4ce4db..6ed189f5 100644 --- a/SciLean/Core/Monads/FwdDerivMonad.lean +++ b/SciLean/Core/Monads/FwdDerivMonad.lean @@ -514,7 +514,7 @@ by @[ftrans] -theorem Pure.pure.fwdDerivValM_rule (x : X) +theorem Pure.pure.arg.fwdDerivValM_rule (x : X) : fwdDerivValM K (pure (f:=m) x) = pure (x,0) := diff --git a/SciLean/Core/Monads/Id.lean b/SciLean/Core/Monads/Id.lean index b9abb672..ac00bc41 100644 --- a/SciLean/Core/Monads/Id.lean +++ b/SciLean/Core/Monads/Id.lean @@ -79,7 +79,7 @@ variable @[fprop] -theorem Id.run.arg_x.HasAdjDiffM_rule +theorem Id.run.arg_x.HasAdjDiff_rule (a : X → Id Y) (ha : HasAdjDiffM K a) : HasAdjDiff K (fun x => Id.run (a x)) := ha diff --git a/SciLean/Data/ArrayType/Properties.lean b/SciLean/Data/ArrayType/Properties.lean index 55d1d1b2..8d4da76a 100644 --- a/SciLean/Data/ArrayType/Properties.lean +++ b/SciLean/Data/ArrayType/Properties.lean @@ -158,7 +158,7 @@ by unfold revCDeriv; ftrans; ftrans; simp @[ftrans] -theorem GetElem.getElem.arg_xs_idx.revCDeriv_rule +theorem GetElem.getElem.arg_xs_i.revCDeriv_rule (f : X → Cont) (dom) (hf : HasAdjDiff K f) : revCDeriv K (fun x idx => getElem (f x) idx dom) diff --git a/SciLean/Lean/Array.lean b/SciLean/Lean/Array.lean index f4080101..a46e0772 100644 --- a/SciLean/Lean/Array.lean +++ b/SciLean/Lean/Array.lean @@ -24,6 +24,9 @@ do idx := idx + 1 return (bs, cs, ids) +def splitM {α : Type _} {m : Type _ → Type _} [Monad m] (as : Array α) (p : α → m Bool) : m (Array α × Array α) := + as.foldlM (init := (#[], #[])) fun (as, bs) a => do + pure <| if ← p a then (as.push a, bs) else (as, bs.push a) /-- Splits array into two based on function p. It also returns indices that can be used to merge two array back together. -/ diff --git a/SciLean/Lean/Expr.lean b/SciLean/Lean/Expr.lean index bb9648d9..97f1e850 100644 --- a/SciLean/Lean/Expr.lean +++ b/SciLean/Lean/Expr.lean @@ -300,3 +300,10 @@ def bindingBodyRec : Expr → Expr | .forallE _ _ b _ => b.bindingBodyRec | .mdata _ e => e.bindingBodyRec | e => e + + +def letBodyRec' (e : Expr) : Expr := + match e with + | .letE _ _ _ b _ => b.letBodyRec' + | .mdata _ e => e.letBodyRec' + | e => e diff --git a/SciLean/Lean/Meta/Basic.lean b/SciLean/Lean/Meta/Basic.lean index 62216870..947ed220 100644 --- a/SciLean/Lean/Meta/Basic.lean +++ b/SciLean/Lean/Meta/Basic.lean @@ -62,7 +62,7 @@ def getFunHeadConst? (e : Expr) : MetaM (Option Name) := match e.consumeMData with | .const name _ => return name | .app f _ => return f.getAppFn'.constName? - | .lam _ _ b _ => return b.getAppFn'.constName? + | .lam _ _ b _ => return b.letBodyRec'.getAppFn'.constName? | .proj structName idx _ => do let .some info := getStructureInfo? (← getEnv) structName | return none diff --git a/SciLean/Tactic/FTrans/Init.lean b/SciLean/Tactic/FTrans/Init.lean index 7514a577..10e0e570 100644 --- a/SciLean/Tactic/FTrans/Init.lean +++ b/SciLean/Tactic/FTrans/Init.lean @@ -195,6 +195,88 @@ def getFTransFun? (e : Expr) : CoreM (Option Expr) := do -------------------------------------------------------------------------------- -------------------------------------------------------------------------------- +structure ConstLambdaData where + constName : Name + mainArgs : Array Name + unusedArgs : Array Name + trailingArgs : Array Name + mainIds : ArraySet Nat + trailingIds : ArraySet Nat + declSuffix : String + + +/-- Analyze function appearing in fprop rules or on lhs of ftrans rules -/ +def analyzeConstLambda (e : Expr) : MetaM ConstLambdaData := do + let e ← zetaReduce e + lambdaTelescope e fun xs b => do + + let constName ← + match b.getAppFn' with + | .const name _ => pure name + | _ => throwError "{← ppExpr b.getAppFn'} is expected to be a constant" + + let args := b.getAppArgs + let argNames ← getConstArgNames constName true + + -- check we have enough arguments + if args.size < argNames.size then + throwError "expression {← ppExpr b} is not fully applied, missing arguments {argNames[args.size:]}" + if args.size > argNames.size then + throwError "expression {← ppExpr b} is overly applied, surplus arguments {args[argNames.size:]}" + + if xs.size = 0 then + return { + constName := constName + mainArgs := #[] + unusedArgs := argNames + trailingArgs := #[] + mainIds := #[].toArraySet + trailingIds := #[].toArraySet + declSuffix := "arg" + } + + let x := xs[0]! + let xId := x.fvarId! + let ys := xs[1:].toArray + + let mut main : Array Name := #[] + let mut unused : Array Name := #[] + let mut trailing : Array Name := #[] + let mut mainIds : Array Nat := #[] + let mut trailingIds : Array Nat := #[] + + for arg in args, i in [0:args.size] do + let ys' := ys.filter (fun y => arg.containsFVar y.fvarId!) + if ys'.size > 0 then + if arg != ys'[0]! then + throwError "invalid argument {← ppExpr arg}, trailing arguments {← ys.mapM ppExpr} can appear only as they are" + trailing := trailing.push argNames[i]! + trailingIds := trailingIds.push i + else if arg.containsFVar xId then + main := main.push argNames[i]! + mainIds := mainIds.push i + else + unused := unused.push argNames[i]! + + let mut declSuffix := "arg" + if main.size ≠ 0 then + declSuffix := declSuffix ++ "_" ++ main.joinl (fun n => toString n) (·++·) + if trailing.size ≠ 0 then + declSuffix := declSuffix ++ "_" ++ trailing.joinl (fun n => toString n) (·++·) + + return { + constName := constName + mainArgs := main + unusedArgs := unused + trailingArgs := trailing + mainIds := mainIds.toArraySet + trailingIds := trailingIds.toArraySet + declSuffix := declSuffix + } + +-------------------------------------------------------------------------------- + + initialize registerTraceClass `trace.Tactic.ftrans.new_property structure FTransRule where @@ -202,25 +284,25 @@ structure FTransRule where -- constName : Name ruleName : Name priority : Nat := 1000 - /-- Set of active argument indices in this rule + /-- Set of main argument indices in this rule For example: - - rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) (g x)) = ...` has `argIds = #[4,5]` - - rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) y) = ...` has `argIds = #[4]` -/ - argIds : ArraySet Nat + - rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) (g x)) = ...` has `mainIds = #[4,5]` + - rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) y) = ...` has `mainIds = #[4]` -/ + mainIds : ArraySet Nat /-- Set of trailing argument indices in this rule For example: - rule `∂ (fun x i => @getElem _ _ _ _ _ (f x) i dom ` has `piArgs = #[6]` - rule `∂ (fun f x => @Function.invFun _ _ _ f x` has `piArgs = #[4]` - rule `∂ (fun x => (f x) + (g x)` has `piArgs = #[]` -/ - piIds : ArraySet Nat + trailingIds : ArraySet Nat def FTransRule.cmp (a b : FTransRule) : Ordering := - match a.piIds.lexOrd b.piIds with + match a.trailingIds.lexOrd b.trailingIds with | .lt => .lt | .gt => .gt | .eq => - match a.argIds.lexOrd b.argIds with + match a.mainIds.lexOrd b.mainIds with | .lt => .lt | .gt => .gt | .eq => @@ -287,86 +369,24 @@ To register function transformation call: ``` where is name of the function transformation and is corresponding `FTrans.Info`. " - - -- in rare cases `f` is not a function - -- for example this is case for monadic `fwdDerivValM` - if ¬f.isLambda then - let .const funName _ := f.getAppFn - | throwError "Function being transformed is in invalid form! The head of {← ppExpr f} is not a constant but it is {f.ctorName}!" - - if (← inferType f).isForall then - throwError "Function being transformed is in invalid form! Function has to appear in fully applied form!" - - let ftransRule : FTransRule := { - ruleName := ruleName - argIds := #[].toArraySet - piIds := #[].toArraySet - } - - FTransRulesExt.insert funName (FTransRules.empty.insert transName ftransRule) - - else - - lambdaTelescope f fun xs b => do - let .some x := xs[0]? - | throwError "Function being transformed is in invalid form! It has to be a lambda function!" - let .const funName _ := b.getAppFn - | throwError "Function being transformed is in invalid form! The head of {← ppExpr b} is not a constant but it is {b.ctorName}!" - if xs.size > 2 then - throwError "Function being transformed is in invalid form! Only one trailing argument is currently supported!" - - let xId := x.fvarId! - - let arity ← getConstArity funName - let args := b.getAppArgs - - if args.size ≠ arity then - throwError "Function being transformed is in invalid form! Function has to appear in fully applied form!" - - let argIds := - args.mapIdx (fun i arg => if arg.containsFVar xId then .some i.1 else none) - |>.filterMap id - - let piIds ← - if let .some y := xs[1]? then - let yId := y.fvarId! - let piArgs := - args.mapIdx (fun i arg => if arg.containsFVar yId then .some (i.1,arg) else none) - |>.filterMap id - if piArgs.size ≠ 1 then - throwError "Function being transformed is in invalid form! Trailing argument `{← ppExpr y}` can appear in only one argument, but it appears in `{← piArgs.mapM (fun (_,arg) => ppExpr arg)}`" - pure (piArgs.map (fun (i,_) => i)) - else - pure #[] - - let argNames ← getConstArgNames funName (fixAnonymousNames := true) - let depName := - argIds.map (fun i => argNames[i]?.getD default) - |>.foldl (·++·.toString) "" - - let piName := - piIds.map (fun i => argNames[i]?.getD default) - |>.foldl (·++·.toString) "" - - let argSuffix := - "arg" ++ if depName ≠ "" then "_" ++ depName else "" - ++ if piName ≠ "" then "_" ++ piName else "" - - let suggestedRuleName := - funName |>.append argSuffix - |>.append (transName.getString.append "_rule") - - if (← getBoolOption `linter.ftransDeclName true) && - ¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then - logWarning s!"suggested name for this rule is {suggestedRuleName}" - - let ftransRule : FTransRule := { - ruleName := ruleName - argIds := argIds.toArraySet - piIds := piIds.toArraySet - } - - FTransRulesExt.insert funName (FTransRules.empty.insert transName ftransRule) + let data ← analyzeConstLambda f + + let suggestedRuleName := + data.constName + |>.append data.declSuffix + |>.append (transName.getString.append "_rule") + + if (← getBoolOption `linter.ftransDeclName true) && + ¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then + logWarning s!"suggested name for this rule is {suggestedRuleName}" + + let ftransRule : FTransRule := { + ruleName := ruleName + mainIds := data.mainIds + trailingIds := data.trailingIds + } + + FTransRulesExt.insert data.constName (FTransRules.empty.insert transName ftransRule) ) @@ -379,7 +399,7 @@ def getFTransRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do | return #[] let rules : List SimpTheorem ← rules.toList.filterMapM fun r => do - if r.piIds.size ≠ 0 then + if r.trailingIds.size ≠ 0 then return none else return .some { @@ -400,7 +420,7 @@ def getFTransPiRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := | return #[] let rules : List SimpTheorem ← rules.toList.filterMapM fun r => do - if r.piIds.size = 0 then + if r.trailingIds.size = 0 then return none else return .some {