Skip to content

Commit

Permalink
reworked how suggested name works for ftrans attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Sep 21, 2023
1 parent 0912ca1 commit 31ca09d
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 96 deletions.
2 changes: 1 addition & 1 deletion SciLean/Core/FunctionPropositions/Diffeomorphism.lean
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ by
-- Which rule is preferable? This one or the second one?
-- Probably the second as it has Function.inv fully applied
@[ftrans]
theorem Function.invFun.arg_f.cderiv_rule
theorem Function.invFun.arg_f_a1.cderiv_rule
(f : X → Y → Z)
(hf : ∀ x, Diffeomorphism K (f x))
(hf' : IsDifferentiable K (fun xy : X×Y => f xy.1 xy.2))
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/FunctionTransformations/CDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ by
unfold Function.comp; ftrans

@[ftrans]
theorem Function.comp.arg_fg.cderiv_rule
theorem Function.comp.arg_fg_a0.cderiv_rule
(f : W → Y → Z) (g : W → X → Y)
(hf : IsDifferentiable K (fun wy : W×Y => f wy.1 wy.2))
(hg : IsDifferentiable K (fun wx : W×X => g wx.1 wx.2))
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/Meta/GenerateRevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def generateHasAdjDiff (constName : Name) (mainNames trailingNames : Array Name)

addDecl (.thmDecl info)


FProp.funTransRuleAttr.attr.add name (← `(attr|fprop)) .global


open Lean.Parser.Tactic.Conv
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/Monads/FwdDerivMonad.lean
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ by


@[ftrans]
theorem Pure.pure.fwdDerivValM_rule (x : X)
theorem Pure.pure.arg.fwdDerivValM_rule (x : X)
: fwdDerivValM K (pure (f:=m) x)
=
pure (x,0) :=
Expand Down
2 changes: 1 addition & 1 deletion SciLean/Core/Monads/Id.lean
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ variable


@[fprop]
theorem Id.run.arg_x.HasAdjDiffM_rule
theorem Id.run.arg_x.HasAdjDiff_rule
(a : X → Id Y) (ha : HasAdjDiffM K a)
: HasAdjDiff K (fun x => Id.run (a x)) := ha

Expand Down
2 changes: 1 addition & 1 deletion SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ by
unfold revCDeriv; ftrans; ftrans; simp

@[ftrans]
theorem GetElem.getElem.arg_xs_idx.revCDeriv_rule
theorem GetElem.getElem.arg_xs_i.revCDeriv_rule
(f : X → Cont) (dom)
(hf : HasAdjDiff K f)
: revCDeriv K (fun x idx => getElem (f x) idx dom)
Expand Down
3 changes: 3 additions & 0 deletions SciLean/Lean/Array.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ do
idx := idx + 1
return (bs, cs, ids)

def splitM {α : Type _} {m : Type _ → Type _} [Monad m] (as : Array α) (p : α → m Bool) : m (Array α × Array α) :=
as.foldlM (init := (#[], #[])) fun (as, bs) a => do
pure <| if ← p a then (as.push a, bs) else (as, bs.push a)

/-- Splits array into two based on function p. It also returns indices that can be used to merge two array back together.
-/
Expand Down
7 changes: 7 additions & 0 deletions SciLean/Lean/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,10 @@ def bindingBodyRec : Expr → Expr
| .forallE _ _ b _ => b.bindingBodyRec
| .mdata _ e => e.bindingBodyRec
| e => e


def letBodyRec' (e : Expr) : Expr :=
match e with
| .letE _ _ _ b _ => b.letBodyRec'
| .mdata _ e => e.letBodyRec'
| e => e
2 changes: 1 addition & 1 deletion SciLean/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def getFunHeadConst? (e : Expr) : MetaM (Option Name) :=
match e.consumeMData with
| .const name _ => return name
| .app f _ => return f.getAppFn'.constName?
| .lam _ _ b _ => return b.getAppFn'.constName?
| .lam _ _ b _ => return b.letBodyRec'.getAppFn'.constName?
| .proj structName idx _ => do
let .some info := getStructureInfo? (← getEnv) structName
| return none
Expand Down
198 changes: 109 additions & 89 deletions SciLean/Tactic/FTrans/Init.lean
Original file line number Diff line number Diff line change
Expand Up @@ -195,32 +195,114 @@ def getFTransFun? (e : Expr) : CoreM (Option Expr) := do
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

structure ConstLambdaData where
constName : Name
mainArgs : Array Name
unusedArgs : Array Name
trailingArgs : Array Name
mainIds : ArraySet Nat
trailingIds : ArraySet Nat
declSuffix : String


/-- Analyze function appearing in fprop rules or on lhs of ftrans rules -/
def analyzeConstLambda (e : Expr) : MetaM ConstLambdaData := do
let e ← zetaReduce e
lambdaTelescope e fun xs b => do

let constName ←
match b.getAppFn' with
| .const name _ => pure name
| _ => throwError "{← ppExpr b.getAppFn'} is expected to be a constant"

let args := b.getAppArgs
let argNames ← getConstArgNames constName true

-- check we have enough arguments
if args.size < argNames.size then
throwError "expression {← ppExpr b} is not fully applied, missing arguments {argNames[args.size:]}"
if args.size > argNames.size then
throwError "expression {← ppExpr b} is overly applied, surplus arguments {args[argNames.size:]}"

if xs.size = 0 then
return {
constName := constName
mainArgs := #[]
unusedArgs := argNames
trailingArgs := #[]
mainIds := #[].toArraySet
trailingIds := #[].toArraySet
declSuffix := "arg"
}

let x := xs[0]!
let xId := x.fvarId!
let ys := xs[1:].toArray

let mut main : Array Name := #[]
let mut unused : Array Name := #[]
let mut trailing : Array Name := #[]
let mut mainIds : Array Nat := #[]
let mut trailingIds : Array Nat := #[]

for arg in args, i in [0:args.size] do
let ys' := ys.filter (fun y => arg.containsFVar y.fvarId!)
if ys'.size > 0 then
if arg != ys'[0]! then
throwError "invalid argument {← ppExpr arg}, trailing arguments {← ys.mapM ppExpr} can appear only as they are"
trailing := trailing.push argNames[i]!
trailingIds := trailingIds.push i
else if arg.containsFVar xId then
main := main.push argNames[i]!
mainIds := mainIds.push i
else
unused := unused.push argNames[i]!

let mut declSuffix := "arg"
if main.size ≠ 0 then
declSuffix := declSuffix ++ "_" ++ main.joinl (fun n => toString n) (·++·)
if trailing.size ≠ 0 then
declSuffix := declSuffix ++ "_" ++ trailing.joinl (fun n => toString n) (·++·)

return {
constName := constName
mainArgs := main
unusedArgs := unused
trailingArgs := trailing
mainIds := mainIds.toArraySet
trailingIds := trailingIds.toArraySet
declSuffix := declSuffix
}

--------------------------------------------------------------------------------


initialize registerTraceClass `trace.Tactic.ftrans.new_property

structure FTransRule where
-- ftransName : Name
-- constName : Name
ruleName : Name
priority : Nat := 1000
/-- Set of active argument indices in this rule
/-- Set of main argument indices in this rule
For example:
- rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) (g x)) = ...` has `argIds = #[4,5]`
- rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) y) = ...` has `argIds = #[4]` -/
argIds : ArraySet Nat
- rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) (g x)) = ...` has `mainIds = #[4,5]`
- rule `∂ (fun x => @HAdd.hAdd _ _ _ _ (f x) y) = ...` has `mainIds = #[4]` -/
mainIds : ArraySet Nat
/-- Set of trailing argument indices in this rule
For example:
- rule `∂ (fun x i => @getElem _ _ _ _ _ (f x) i dom ` has `piArgs = #[6]`
- rule `∂ (fun f x => @Function.invFun _ _ _ f x` has `piArgs = #[4]`
- rule `∂ (fun x => (f x) + (g x)` has `piArgs = #[]`
-/
piIds : ArraySet Nat
trailingIds : ArraySet Nat

def FTransRule.cmp (a b : FTransRule) : Ordering :=
match a.piIds.lexOrd b.piIds with
match a.trailingIds.lexOrd b.trailingIds with
| .lt => .lt
| .gt => .gt
| .eq =>
match a.argIds.lexOrd b.argIds with
match a.mainIds.lexOrd b.mainIds with
| .lt => .lt
| .gt => .gt
| .eq =>
Expand Down Expand Up @@ -287,86 +369,24 @@ To register function transformation call:
```
where <name> is name of the function transformation and <info> is corresponding `FTrans.Info`.
"

-- in rare cases `f` is not a function
-- for example this is case for monadic `fwdDerivValM`
if ¬f.isLambda then
let .const funName _ := f.getAppFn
| throwError "Function being transformed is in invalid form! The head of {← ppExpr f} is not a constant but it is {f.ctorName}!"

if (← inferType f).isForall then
throwError "Function being transformed is in invalid form! Function has to appear in fully applied form!"

let ftransRule : FTransRule := {
ruleName := ruleName
argIds := #[].toArraySet
piIds := #[].toArraySet
}

FTransRulesExt.insert funName (FTransRules.empty.insert transName ftransRule)

else

lambdaTelescope f fun xs b => do
let .some x := xs[0]?
| throwError "Function being transformed is in invalid form! It has to be a lambda function!"
let .const funName _ := b.getAppFn
| throwError "Function being transformed is in invalid form! The head of {← ppExpr b} is not a constant but it is {b.ctorName}!"
if xs.size > 2 then
throwError "Function being transformed is in invalid form! Only one trailing argument is currently supported!"

let xId := x.fvarId!

let arity ← getConstArity funName
let args := b.getAppArgs

if args.size ≠ arity then
throwError "Function being transformed is in invalid form! Function has to appear in fully applied form!"

let argIds :=
args.mapIdx (fun i arg => if arg.containsFVar xId then .some i.1 else none)
|>.filterMap id

let piIds ←
if let .some y := xs[1]? then
let yId := y.fvarId!
let piArgs :=
args.mapIdx (fun i arg => if arg.containsFVar yId then .some (i.1,arg) else none)
|>.filterMap id
if piArgs.size ≠ 1 then
throwError "Function being transformed is in invalid form! Trailing argument `{← ppExpr y}` can appear in only one argument, but it appears in `{← piArgs.mapM (fun (_,arg) => ppExpr arg)}`"
pure (piArgs.map (fun (i,_) => i))
else
pure #[]

let argNames ← getConstArgNames funName (fixAnonymousNames := true)
let depName :=
argIds.map (fun i => argNames[i]?.getD default)
|>.foldl (·++·.toString) ""

let piName :=
piIds.map (fun i => argNames[i]?.getD default)
|>.foldl (·++·.toString) ""

let argSuffix :=
"arg" ++ if depName ≠ "" then "_" ++ depName else ""
++ if piName ≠ "" then "_" ++ piName else ""

let suggestedRuleName :=
funName |>.append argSuffix
|>.append (transName.getString.append "_rule")

if (← getBoolOption `linter.ftransDeclName true) &&
¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then
logWarning s!"suggested name for this rule is {suggestedRuleName}"

let ftransRule : FTransRule := {
ruleName := ruleName
argIds := argIds.toArraySet
piIds := piIds.toArraySet
}

FTransRulesExt.insert funName (FTransRules.empty.insert transName ftransRule)
let data ← analyzeConstLambda f

let suggestedRuleName :=
data.constName
|>.append data.declSuffix
|>.append (transName.getString.append "_rule")

if (← getBoolOption `linter.ftransDeclName true) &&
¬(suggestedRuleName.toString.isPrefixOf ruleName.toString) then
logWarning s!"suggested name for this rule is {suggestedRuleName}"

let ftransRule : FTransRule := {
ruleName := ruleName
mainIds := data.mainIds
trailingIds := data.trailingIds
}

FTransRulesExt.insert data.constName (FTransRules.empty.insert transName ftransRule)
)


Expand All @@ -379,7 +399,7 @@ def getFTransRules (funName ftransName : Name) : CoreM (Array SimpTheorem) := do
| return #[]

let rules : List SimpTheorem ← rules.toList.filterMapM fun r => do
if r.piIds.size ≠ 0 then
if r.trailingIds.size ≠ 0 then
return none
else
return .some {
Expand All @@ -400,7 +420,7 @@ def getFTransPiRules (funName ftransName : Name) : CoreM (Array SimpTheorem) :=
| return #[]

let rules : List SimpTheorem ← rules.toList.filterMapM fun r => do
if r.piIds.size = 0 then
if r.trailingIds.size = 0 then
return none
else
return .some {
Expand Down

0 comments on commit 31ca09d

Please sign in to comment.