diff --git a/SciLean/Analysis/Calculus/Notation/FwdDeriv.lean b/SciLean/Analysis/Calculus/Notation/FwdDeriv.lean index bedc1538..89363a3f 100644 --- a/SciLean/Analysis/Calculus/Notation/FwdDeriv.lean +++ b/SciLean/Analysis/Calculus/Notation/FwdDeriv.lean @@ -21,7 +21,9 @@ syntax "∂>! " "(" diffBinder ")" ", " term:66 : term open Lean Elab Term Meta in elab_rules : term | `(∂> $f $x $xs*) => do - elabTerm (← `(fwdFDeriv defaultScalar% $f $x $xs*)) none + let X ← inferType (← elabTerm x none) + let sX ← exprToSyntax X + elabTerm (← `(fwdFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none | `(∂> $f) => do elabTerm (← `(fwdFDeriv defaultScalar% $f)) none diff --git a/SciLean/Analysis/Calculus/Notation/Gradient.lean b/SciLean/Analysis/Calculus/Notation/Gradient.lean index 1f0c2d53..151e0ea9 100644 --- a/SciLean/Analysis/Calculus/Notation/Gradient.lean +++ b/SciLean/Analysis/Calculus/Notation/Gradient.lean @@ -22,12 +22,14 @@ elab_rules (kind:=gradNotation1) : term let XY ← mkArrow X Y -- Y might also be infered by the function `f` let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false + let sX ← exprToSyntax X let .some (_,Y) := (← inferType fExpr).arrow? | return ← throwUnsupportedSyntax if (← isDefEq K Y) then - elabTerm (← `(fgradient $f $x $xs*)) none false + elabTerm (← `(fgradient (X:=$sX) $f $x $xs*)) none false else - elabTerm (← `(adjointFDeriv defaultScalar% $f $x $xs*)) none false + elabTerm (← `(adjointFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none false + | `(∇ $f) => do let K ← elabTerm (← `(defaultScalar%)) none diff --git a/SciLean/Analysis/Calculus/Notation/RevDeriv.lean b/SciLean/Analysis/Calculus/Notation/RevDeriv.lean index 76e4fd50..706f0b45 100644 --- a/SciLean/Analysis/Calculus/Notation/RevDeriv.lean +++ b/SciLean/Analysis/Calculus/Notation/RevDeriv.lean @@ -19,8 +19,10 @@ syntax "<∂! " "(" diffBinder ")" ", " term:66 : term open Lean Elab Term Meta in elab_rules : term -| `(<∂ $f $xs*) => do - elabTerm (← `(revFDeriv defaultScalar% $f $xs*)) none +| `(<∂ $f $x $xs*) => do + let X ← inferType (← elabTerm x none) + let sX ← exprToSyntax X + elabTerm (← `(revFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none | `(<∂ $f) => do elabTerm (← `(revFDeriv defaultScalar% $f)) none