Skip to content

Commit

Permalink
basic tests for gradietn and reverse mode AD
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Aug 31, 2023
1 parent 148ec62 commit 80b5e5f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 7 deletions.
1 change: 1 addition & 0 deletions SciLean/Core/Simp/Sum.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ namespace SciLean

variable {ι} [EnumType ι]

@[simp]
theorem sum_if {β : Type _} [AddCommMonoid β] (f : ι → β) (j : ι)
: (∑ i, if i = j then f i else 0)
=
Expand Down
7 changes: 0 additions & 7 deletions SciLean/Data/ArrayType/Notation.lean
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ def unexpandIntroElemNotation : Lean.PrettyPrinter.Unexpander
`(⊞ $x:ident => $b)
| _ => throw ()

/-- Convert `introElem` to `introElemNotation` if possible to get nicer pretty printing.
-/
@[simp]
theorem introElem_introElemNotation {Cont Idx Elem} [ArrayType Cont Idx Elem] [ArrayTypeNotation Cont Idx Elem] (f : Idx → Elem)
: introElem (Cont:=Cont) f = introElemNotation f := by rfl


open Lean Elab Term in
elab:40 (priority:=high) x:term:41 " ^ " y:term:42 : term =>
try
Expand Down
85 changes: 85 additions & 0 deletions test/basic_gradients.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import SciLean
import SciLean.Util.Profile
import SciLean.Tactic.LetNormalize
import SciLean.Util.RewriteBy

import SciLean.Core.Simp.Sum

open SciLean

variable
{K : Type} [RealScalar K]
{X : Type} [SemiInnerProductSpace K X]
{Y : Type} [SemiInnerProductSpace K Y]
{Z : Type} [SemiInnerProductSpace K Z]
{ι : Type} [EnumType ι]
{E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)]

set_default_scalar K

example
: (∇ (x : Fin 10 → K), fun i => x i)
=
fun x dx => dx :=
by
(conv => lhs; autodiff)

example
: (∇ (x : Fin 10 → K), ∑ i, x i)
=
fun x i => 1 :=
by
(conv => lhs; autodiff)

example
: (∇ (x : Fin 10 → K), ∑ i, ‖x i‖₂²)
=
fun x i => 2 * (x i) :=
by
(conv => lhs; autodiff)

example (A : Fin 5 → Fin 10 → K)
: (∇ (x : Fin 10 → K), fun i => ∑ j, A i j * x j)
=
fun _ dy j => ∑ i, A i j * dy i :=
by
(conv => lhs; autodiff)

variable [PlainDataType K]

example
: (∇ (x : K ^ Idx 10), fun i => x[i])
=
fun _ x => ⊞ i => x i :=
by
(conv => lhs; autodiff)

example
: (∇ (x : K ^ Idx 10), ⊞ i => x[i])
=
fun _ x => x :=
by
(conv => lhs; autodiff)

example
: (∇ (x : Fin 10 → K), ∑ i, x i)
=
fun x i => 1 :=
by
(conv => lhs; autodiff)

example
: (∇ (x : Fin 10 → K), ∑ i, ‖x i‖₂²)
=
fun x i => 2 * (x i) :=
by
(conv => lhs; autodiff)

example (A : Fin 5 → Fin 10 → K)
: (∇ (x : Fin 10 → K), fun i => ∑ j, A i j * x j)
=
fun _ dy j => ∑ i, A i j * dy i :=
by
(conv => lhs; autodiff)


68 changes: 68 additions & 0 deletions test/basic_revCDeriv.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import SciLean
import SciLean.Util.Profile
import SciLean.Tactic.LetNormalize
import SciLean.Util.RewriteBy

open SciLean

variable
{K : Type} [RealScalar K]
{X : Type} [SemiInnerProductSpace K X]
{Y : Type} [SemiInnerProductSpace K Y]
{Z : Type} [SemiInnerProductSpace K Z]
{ι : Type} [EnumType ι]
{E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)]

set_default_scalar K

example
: <∂ xy : X×Y, (xy.1,xy.2)
=
fun x => (x, fun dyz => dyz) :=
by
conv => lhs; autodiff

example
: <∂ xy : X×Y, (xy.2,xy.1)
=
fun x => ((x.snd, x.fst), fun dyz => (dyz.snd, dyz.fst)) :=
by
conv => lhs; autodiff

variable (f : Y → X → X)
(hf : HasAdjDiff K (fun yx : Y×X => f yx.1 yx.2))
(hf₁ : ∀ x, HasAdjDiff K (fun y => f y x))
(hf₂ : ∀ y, HasAdjDiff K (fun x => f y x))
(x : X)

example
: <∂ yy : Y×Y×Y, f yy.1 (f yy.2.1 (f yy.2.2 x))
=
fun x_1 =>
let zdf := <∂ (x0:=x_1.snd.snd), f x0 x;
let zdf_1 := <∂ (x0x1:=(x_1.snd.fst, zdf.fst)), f x0x1.fst x0x1.snd;
let zdf_2 := <∂ (x0x1:=(x_1.fst, zdf_1.fst)), f x0x1.fst x0x1.snd;
(zdf_2.fst, fun dz =>
let dy := Prod.snd zdf_2 dz;
let dy_1 := Prod.snd zdf_1 dy.snd;
let dy_2 := Prod.snd zdf dy_1.snd;
(dy.fst, dy_1.fst, dy_2)) :=
by
conv => lhs; autodiff

example
: <∂ yy : Y×Y×Y×Y, f yy.1 (f yy.2.1 (f yy.2.2.1 (f yy.2.2.2 x)))
=
fun x_1 =>
let zdf := <∂ (x0:=x_1.snd.snd.snd), f x0 x;
let zdf_1 := <∂ (x0x1:=(x_1.snd.snd.fst, zdf.fst)), f x0x1.fst x0x1.snd;
let zdf_2 := <∂ (x0x1:=(x_1.snd.fst, zdf_1.fst)), f x0x1.fst x0x1.snd;
let zdf_3 := <∂ (x0x1:=(x_1.fst, zdf_2.fst)), f x0x1.fst x0x1.snd;
(zdf_3.fst, fun dz =>
let dy := Prod.snd zdf_3 dz;
let dy_1 := Prod.snd zdf_2 dy.snd;
let dy_2 := Prod.snd zdf_1 dy_1.snd;
let dy_3 := Prod.snd zdf dy_2.snd;
(dy.fst, dy_1.fst, dy_2.fst, dy_3)) :=
by
conv => lhs; autodiff

0 comments on commit 80b5e5f

Please sign in to comment.