Skip to content

Commit

Permalink
bang version of gradient notation
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 27, 2023
1 parent 9c4714a commit d535f9c
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions SciLean/Core/Notation/Gradient.lean
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import SciLean.Core.FunctionTransformations.RevCDeriv
import SciLean.Core.Notation.Autodiff
import SciLean.Core.Notation.CDeriv


namespace SciLean.NotationOverField

scoped syntax "∇ " term:66 : term
scoped syntax (name:=gradNotation1) "∇ " term:66 : term
scoped syntax "∇ " diffBinder ", " term:66 : term
scoped syntax "∇ " "(" diffBinder ")" ", " term:66 : term
scoped syntax "∇! " term:66 : term
scoped syntax "∇! " diffBinder ", " term:66 : term
scoped syntax "∇! " "(" diffBinder ")" ", " term:66 : term


open Lean Elab Term Meta in
elab_rules : term
elab_rules (kind:=gradNotation1) : term
| `(∇ $f) => do
let K := mkIdent (← currentFieldName.get)
let KExpr ← elabTerm (← `($K)) none
Expand All @@ -21,16 +27,23 @@ elab_rules : term
else
throwUnsupportedSyntax

-- in this case we do not want to call scalarGradient
| `(∇ $x:ident := $val:term; $codir:term, $b) => do
let K := mkIdent (← currentFieldName.get)
elabTerm (← `(gradient $K (fun $x => $b) $val $codir)) none false
-- open Lean Elab Term Meta in
-- elab_rules (kind:=gradNotation1) : term
-- | `(∇ $x:ident := $val:term; $codir:term, $b) => do
-- let K := mkIdent (← currentFieldName.get)
-- elabTerm (← `(gradient $K (fun $x => $b) $val $codir)) none false

macro_rules
| `(∇ $x:ident, $f) => `(∇ fun $x => $f)
| `(∇ $x:ident : $type:term, $f) => `(∇ fun $x : $type => $f)
| `(∇ $x:ident := $val:term, $f) => `((∇ fun $x => $f) $val)
| `(∇ ($b:diffBinder), $f) => `(∇ $b, $f)
| `(∇! $f) => `((∇ $f) rewrite_by autodiff)
| `(∇! $x:ident, $f) => `(∇! fun $x => $f)
| `(∇! $x:ident : $type:term, $f) => `(∇! fun $x : $type => $f)
| `(∇! $x:ident := $val:term, $f) => `((∇! fun $x => $f) $val)
| `(∇! ($b:diffBinder), $f) => `(∇! $b, $f)


@[app_unexpander gradient] def unexpandGradient : Lean.PrettyPrinter.Unexpander

Expand Down

0 comments on commit d535f9c

Please sign in to comment.