diff --git a/SciLean/Core/Notation/Gradient.lean b/SciLean/Core/Notation/Gradient.lean index 9968f087..f9d1ece4 100644 --- a/SciLean/Core/Notation/Gradient.lean +++ b/SciLean/Core/Notation/Gradient.lean @@ -1,14 +1,20 @@ import SciLean.Core.FunctionTransformations.RevCDeriv +import SciLean.Core.Notation.Autodiff import SciLean.Core.Notation.CDeriv + namespace SciLean.NotationOverField -scoped syntax "∇ " term:66 : term +scoped syntax (name:=gradNotation1) "∇ " term:66 : term scoped syntax "∇ " diffBinder ", " term:66 : term scoped syntax "∇ " "(" diffBinder ")" ", " term:66 : term +scoped syntax "∇! " term:66 : term +scoped syntax "∇! " diffBinder ", " term:66 : term +scoped syntax "∇! " "(" diffBinder ")" ", " term:66 : term + open Lean Elab Term Meta in -elab_rules : term +elab_rules (kind:=gradNotation1) : term | `(∇ $f) => do let K := mkIdent (← currentFieldName.get) let KExpr ← elabTerm (← `($K)) none @@ -21,16 +27,23 @@ elab_rules : term else throwUnsupportedSyntax --- in this case we do not want to call scalarGradient -| `(∇ $x:ident := $val:term; $codir:term, $b) => do - let K := mkIdent (← currentFieldName.get) - elabTerm (← `(gradient $K (fun $x => $b) $val $codir)) none false +-- open Lean Elab Term Meta in +-- elab_rules (kind:=gradNotation1) : term +-- | `(∇ $x:ident := $val:term; $codir:term, $b) => do +-- let K := mkIdent (← currentFieldName.get) +-- elabTerm (← `(gradient $K (fun $x => $b) $val $codir)) none false macro_rules | `(∇ $x:ident, $f) => `(∇ fun $x => $f) | `(∇ $x:ident : $type:term, $f) => `(∇ fun $x : $type => $f) | `(∇ $x:ident := $val:term, $f) => `((∇ fun $x => $f) $val) | `(∇ ($b:diffBinder), $f) => `(∇ $b, $f) +| `(∇! $f) => `((∇ $f) rewrite_by autodiff) +| `(∇! $x:ident, $f) => `(∇! fun $x => $f) +| `(∇! $x:ident : $type:term, $f) => `(∇! fun $x : $type => $f) +| `(∇! $x:ident := $val:term, $f) => `((∇! fun $x => $f) $val) +| `(∇! ($b:diffBinder), $f) => `(∇! $b, $f) + @[app_unexpander gradient] def unexpandGradient : Lean.PrettyPrinter.Unexpander