Skip to content

Commit

Permalink
improve type inference in notation for derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Aug 20, 2024
1 parent 29bbd12 commit 6754d6e
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
4 changes: 3 additions & 1 deletion SciLean/Analysis/Calculus/Notation/FwdDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ syntax "∂>! " "(" diffBinder ")" ", " term:66 : term
open Lean Elab Term Meta in
elab_rules : term
| `(∂> $f $x $xs*) => do
elabTerm (← `(fwdFDeriv defaultScalar% $f $x $xs*)) none
let X ← inferType (← elabTerm x none)
let sX ← exprToSyntax X
elabTerm (← `(fwdFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none

| `(∂> $f) => do
elabTerm (← `(fwdFDeriv defaultScalar% $f)) none
Expand Down
6 changes: 4 additions & 2 deletions SciLean/Analysis/Calculus/Notation/Gradient.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ elab_rules (kind:=gradNotation1) : term
let XY ← mkArrow X Y
-- Y might also be infered by the function `f`
let fExpr ← withoutPostponing <| elabTermEnsuringType f XY false
let sX ← exprToSyntax X
let .some (_,Y) := (← inferType fExpr).arrow?
| return ← throwUnsupportedSyntax
if (← isDefEq K Y) then
elabTerm (← `(fgradient $f $x $xs*)) none false
elabTerm (← `(fgradient (X:=$sX) $f $x $xs*)) none false
else
elabTerm (← `(adjointFDeriv defaultScalar% $f $x $xs*)) none false
elabTerm (← `(adjointFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none false


| `(∇ $f) => do
let K ← elabTerm (← `(defaultScalar%)) none
Expand Down
6 changes: 4 additions & 2 deletions SciLean/Analysis/Calculus/Notation/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ syntax "<∂! " "(" diffBinder ")" ", " term:66 : term

open Lean Elab Term Meta in
elab_rules : term
| `(<∂ $f $xs*) => do
elabTerm (← `(revFDeriv defaultScalar% $f $xs*)) none
| `(<∂ $f $x $xs*) => do
let X ← inferType (← elabTerm x none)
let sX ← exprToSyntax X
elabTerm (← `(revFDeriv (X:=$sX) defaultScalar% $f $x $xs*)) none
| `(<∂ $f) => do
elabTerm (← `(revFDeriv defaultScalar% $f)) none

Expand Down

0 comments on commit 6754d6e

Please sign in to comment.