Skip to content

Commit

Permalink
better form of hmul and hpow roms of revDeriv rules
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 30, 2023
1 parent 3843717 commit e858883
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions SciLean/Core/FunctionTransformations/RevDeriv.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1274,9 +1274,11 @@ theorem HMul.hMul.arg_a0a1.revDeriv_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 * zdg.1, fun dx' =>
let dx' := dx'
(ydf.2 (conj zdg.1 * dx') (zdg.2 (conj ydf.1 * dx')))) :=
(ydf.1 * zdg.1,
fun dx' =>
let dx₁ := (conj zdg.1 * dx')
let dx₂ := (conj ydf.1 * dx')
ydf.2 dx₁ (zdg.2 dx₂)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand All @@ -1292,14 +1294,14 @@ theorem HMul.hMul.arg_a0a1.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 * zdg.1, fun dx' dx =>
let dx' := dx'
(ydf.2 (conj zdg.1 * dx') (zdg.2 (conj ydf.1 * dx') dx))) :=
(ydf.1 * zdg.1,
fun dx' dx =>
let dx₁ := (conj zdg.1 * dx')
let dx₂ := (conj ydf.1 * dx')
ydf.2 dx₁ (zdg.2 dx₂ dx)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate; unfold revDeriv; simp; ftrans; ftrans;
simp [smul_push,add_assoc]
unfold revDerivUpdate; simp; ftrans; ftrans;
simp [smul_push,add_assoc,revDerivUpdate]

@[ftrans]
theorem HMul.hMul.arg_a0a1.revDerivProj_rule
Expand All @@ -1310,9 +1312,11 @@ theorem HMul.hMul.arg_a0a1.revDerivProj_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 * zdg.1, fun _ dy =>
let dy := dy
ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy))) :=
(ydf.1 * zdg.1,
fun _ dy =>
let dy₁ := (conj zdg.1)*dy
let dy₂ := (conj ydf.1)* dy
ydf.2 dy₁ (zdg.2 dy₂)) :=
by
unfold revDerivProj
ftrans; simp[oneHot, structMake]
Expand All @@ -1326,9 +1330,11 @@ theorem HMul.hMul.arg_a0a1.revDerivProjUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDerivUpdate K g x
(ydf.1 * zdg.1, fun _ dy dx =>
let dy := dy
ydf.2 ((conj zdg.1)*dy) (zdg.2 (conj ydf.1* dy) dx)) :=
(ydf.1 * zdg.1,
fun _ dy dx =>
let dy₁ := (conj zdg.1)*dy
let dy₂ := (conj ydf.1)*dy
ydf.2 dy₁ (zdg.2 dy₂ dx)) :=
by
unfold revDerivProjUpdate
ftrans; simp[revDerivUpdate,add_assoc]
Expand All @@ -1348,7 +1354,9 @@ theorem HSMul.hSMul.arg_a0a1.revDeriv_rule
fun x =>
let ydf := revDerivUpdate K f x
let zdg := revDeriv K g x
(ydf.1 • zdg.1, fun dx' => ydf.2 (inner zdg.1 dx') (conj ydf.1 • zdg.2 dx')) :=
(ydf.1 • zdg.1,
fun dx' =>
ydf.2 (inner zdg.1 dx') (conj ydf.1 • zdg.2 dx')) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
Expand Down Expand Up @@ -1489,7 +1497,9 @@ def HPow.hPow.arg_a0.revDeriv_rule
=
fun x =>
let ydf := revDeriv K f x
(ydf.1 ^ n, fun dx' => ydf.2 ((n : K) * (conj ydf.1 ^ (n-1)) * dx')) :=
let y' := (n : K) * (conj ydf.1 ^ (n-1))
(ydf.1 ^ n,
fun dx' => ydf.2 (y' * dx')) :=
by
have ⟨_,_⟩ := hf
funext x
Expand All @@ -1502,8 +1512,9 @@ def HPow.hPow.arg_a0.revDerivUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K f x
let y' := n * (conj ydf.1 ^ (n-1))
(ydf.1 ^ n,
fun dy dx => ydf.2 (n * (conj ydf.1 ^ (n-1)) * dy) dx) :=
fun dy dx => ydf.2 (y' * dy) dx) :=
by
unfold revDerivUpdate
funext x; ftrans; simp[mul_assoc,mul_comm,add_assoc]
Expand All @@ -1515,7 +1526,8 @@ def HPow.hPow.arg_a0.revDerivProj_rule
=
fun x =>
let ydf := revDeriv K f x
(ydf.1 ^ n, fun _ dx' => ydf.2 ((n : K) * (conj ydf.1 ^ (n-1)) * dx')) :=
let y' := (n : K) * (conj ydf.1 ^ (n-1))
(ydf.1 ^ n, fun _ dx' => ydf.2 (y' * dx')) :=
by
unfold revDerivProj; ftrans; simp[oneHot,structMake]

Expand All @@ -1526,8 +1538,9 @@ def HPow.hPow.arg_a0.revDerivProjUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K f x
let y' := n * (conj ydf.1 ^ (n-1))
(ydf.1 ^ n,
fun _ dy dx => ydf.2 (n * (conj ydf.1 ^ (n-1)) * dy) dx) :=
fun _ dy dx => ydf.2 (y' * dy) dx) :=
by
unfold revDerivProjUpdate; ftrans; simp[oneHot,structMake,revDerivUpdate]

Expand Down

0 comments on commit e858883

Please sign in to comment.