Skip to content

Commit

Permalink
more test for complicated derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Aug 29, 2023
1 parent 20b78a9 commit 34f6260
Showing 1 changed file with 87 additions and 53 deletions.
140 changes: 87 additions & 53 deletions SciLean/DoodleRevCDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ import SciLean.Util.RewriteBy
open SciLean

variable
{K : Type} [IsROrC K]
{K : Type} [RealScalar K]
{X : Type} [SemiInnerProductSpace K X]
{Y : Type} [SemiInnerProductSpace K Y]
{Z : Type} [SemiInnerProductSpace K Z]
{ι : Type} [Fintype ι]
{ι : Type} [EnumType ι]
{E : ι → Type _} [∀ i, SemiInnerProductSpace K (E i)]


Expand All @@ -26,6 +26,21 @@ theorem cderiv.uncurry_rule (f : X → Y → Z)
cderiv K (fun y => f xy.1 y) xy.2 dxy.2 := sorry_proof

theorem revCDeriv.prod_rule
(f : X → Y → Z) (hf : HasAdjDiff K (fun x : X×Y => f x.1 x.2))
: revCDeriv K (fun xy : X × Y => f xy.1 xy.2)
=
fun x =>
let ydf := revCDeriv K (fun x : X×Y => f x.1 x.2) (x.1,x.2)
(ydf.1,
fun dz =>
let dxy := ydf.2 dz
(dxy.1, dxy.2)) :=
by
-- simp [cderiv.uncurry_rule _ hf]
sorry_proof


theorem revCDeriv.prod_rule'
(f : X → Y → X×Y → Z) (hf : HasAdjDiff K (fun x : X×Y×(X×Y) => f x.1 x.2.1 x.2.2))
: revCDeriv K (fun xy : X × Y => f xy.1 xy.2 xy)
=
Expand All @@ -39,7 +54,7 @@ by
-- simp [cderiv.uncurry_rule _ hf]
sorry_proof

open BigOperators

theorem revCDeriv.pi_rule_v1
(f : (i : ι) → X → (ι → X) → Y) (hf : ∀ i, HasAdjDiff K (fun x : X×(ι→X) => f i x.1 x.2))
: (revCDeriv K fun (x : ι → X) (i : ι) => f i (x i) x)
Expand All @@ -65,8 +80,6 @@ by
simp;
set_option trace.Meta.Tactic.simp.discharge true in
ftrans only
rw[semiAdjoint.pi_rule _ _ sorry_proof]
simp
ftrans


Expand Down Expand Up @@ -100,65 +113,86 @@ by


example
: revCDeriv K (fun xy : X×Y => (xy.1,xy.2))
: <∂ xy : X×Y, (xy.1,xy.2)
=
fun xy =>
(xy, fun dxy => dxy) :=
fun x => (x, fun dyz => dyz) :=
by
conv =>
lhs; autodiff; enter[x]; let_normalize

conv => lhs; autodiff

example
: revCDeriv K (fun xy : X×Y => (xy.2,xy.1))
: <∂ xy : X×Y, (xy.2,xy.1)
=
fun xy =>
((xy.2,xy.1), fun dxy => (dxy.2, dxy.1)) :=
fun x => ((x.snd, x.fst), fun dyz => (dyz.snd, dyz.fst)) :=
by
conv =>
lhs; autodiff; enter[x]; let_normalize

#eval 0

#check
(∇ (x : Fin 10 → K), fun i => x i)
rewrite_by
unfold gradient
rw[revCDeriv.pi_rule_v1 (K:=K) (fun i x _ => x) (by fprop)]
symdiff
-- rw[revCDeriv.pi_rule _ _ (by fprop)]
-- ftrans
-- simp
conv => lhs; autodiff


variable (f : Y → X → X)
(hf : HasAdjDiff K (fun yx : Y×X => f yx.1 yx.2))
(hf₁ : ∀ x, HasAdjDiff K (fun y => f y x))
(hf₂ : ∀ y, HasAdjDiff K (fun x => f y x))
(x : X)

/--
info: fun x_1 =>
let zdf := <∂ (x0:=x_1.snd.snd), f x0 x;
let zdf_1 := <∂ (x0x1:=(x_1.snd.fst, zdf.fst)), f x0x1.fst x0x1.snd;
let zdf_2 := <∂ (x0x1:=(x_1.fst, zdf_1.fst)), f x0x1.fst x0x1.snd;
(zdf_2.fst, fun dz =>
let dy := Prod.snd zdf_2 dz;
let dy_1 := Prod.snd zdf_1 dy.snd;
let dy_2 := Prod.snd zdf dy_1.snd;
(dy.fst, dy_1.fst, dy_2)) : Y × Y × Y → X × (X → Y × Y × Y)
-/
#guard_msgs in
#check
<∂ yy : Y×Y×Y, f yy.1 (f yy.2.1 (f yy.2.2 x))
rewrite_by autodiff

-- ftrans

example
: <∂ yy : Y×Y×Y×Y, f yy.1 (f yy.2.1 (f yy.2.2.1 (f yy.2.2.2 x)))
=
fun x_1 =>
let zdf := <∂ (x0:=x_1.snd.snd.snd), f x0 x;
let zdf_1 := <∂ (x0x1:=(x_1.snd.snd.fst, zdf.fst)), f x0x1.fst x0x1.snd;
let zdf_2 := <∂ (x0x1:=(x_1.snd.fst, zdf_1.fst)), f x0x1.fst x0x1.snd;
let zdf_3 := <∂ (x0x1:=(x_1.fst, zdf_2.fst)), f x0x1.fst x0x1.snd;
(zdf_3.fst, fun dz =>
let dy := Prod.snd zdf_3 dz;
let dy_1 := Prod.snd zdf_2 dy.snd;
let dy_2 := Prod.snd zdf_1 dy_1.snd;
let dy_3 := Prod.snd zdf dy_2.snd;
(dy.fst, dy_1.fst, dy_2.fst, dy_3)) :=
by
conv => lhs; autodiff

#exit

example : ∀ (i : SciLean.Idx n), SciLean.HasAdjDiff Float (fun x : Float^Idx n => ‖x[i]‖₂²) := by fprop
example
: (∇ (x : Fin 10 → K), fun i => x i)
=
fun x dx => dx :=
by
autodiff; admit

set_option trace.Meta.Tactic.ftrans.step true in
set_option trace.Meta.Tactic.simp.rewrite true in
set_option trace.Meta.Tactic.simp.discharge true in
#check
∇ (x : Idx n → Float), (fun i => x i)
rewrite_by
symdiff
example
: (∇ (x : Fin 10 → K), ∑ i, x i)
=
fun x i => 1 :=
by
autodiff; admit

set_option trace.Meta.Tactic.ftrans.step true in
set_option trace.Meta.Tactic.fprop.step true in
#check
(<∂ x : Idx n → Float, fun i => x i)
rewrite_by
rw [revCDeriv.pi_rule _ _ (by fprop)]
ftrans; simp
example
: (∇ (x : Fin 10 → K), ∑ i, ‖x i‖₂²)
=
fun x => 2 • x :=
by
autodiff; admit

example (A : Fin 5 → Fin 10 → K)
: (∇ (x : Fin 10 → K), fun i => ∑ j, A i j * x j)
=
fun _ dy j => ∑ i, A i j * dy i :=
by
autodiff; admit

set_option trace.Meta.Tactic.simp.rewrite true in
set_option trace.Meta.Tactic.simp.discharge true in
set_option trace.Meta.Tactic.ftrans.step true in
#check
∇ (x : Idx n → Float), ∑ i, x i
rewrite_by
symdiff

0 comments on commit 34f6260

Please sign in to comment.