Skip to content

Commit

Permalink
fix to downstream changes to revDerivUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 22, 2023
1 parent c2e09b2 commit 06a84f3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion SciLean/Core/Monads/ForIn.lean
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ theorem ForIn.forIn.arg_bf.revCDeriv_rule_def' [Index ι] [PlainDataType X] [Pla
let mut x := xxs.1

let revPassBody := hold fun i x dw dx' =>
(revDerivUpdate K (fun (w',x') => (f w' i x').val) (w,x)).2 dx' 1 (dw,0)
(revDerivUpdate K (fun (w',x') => (f w' i x').val) (w,x)).2 dx' (dw,0)

(x,
fun dx' =>
Expand Down
6 changes: 3 additions & 3 deletions SciLean/Core/Monads/ForInStep.lean
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ theorem ForInStep.yield.arg_a0.revDerivUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K a0 x
(.yield ydf.1, fun dy k dx => ydf.2 dy.val k dx)
(.yield ydf.1, fun dy dx => ydf.2 dy.val dx)
:= by sorry_proof


Expand Down Expand Up @@ -213,7 +213,7 @@ theorem ForInStep.done.arg_a0.revDerivUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K a0 x
(.done ydf.1, fun dy k dx => ydf.2 dy.val k dx)
(.done ydf.1, fun dy dx => ydf.2 dy.val dx)
:= by sorry_proof

@[ftrans]
Expand Down Expand Up @@ -258,7 +258,7 @@ theorem ForInStep.val.arg_a0.revDerivUpdate_rule
=
fun x =>
let ydf := revDerivUpdate K a0 x
(ydf.1.val, fun dy k dx => ydf.2 (.yield dy) k dx)
(ydf.1.val, fun dy dx => ydf.2 (.yield dy) dx)
:= by sorry_proof

@[ftrans]
Expand Down
10 changes: 5 additions & 5 deletions SciLean/Core/Monads/MProd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ theorem MProd.mk.arg_fstsnd.revDerivUpdate_rule
let xdf' := revDerivUpdate K f w
let ydg' := revDerivUpdate K g w
(MProd.mk xdf'.1 ydg'.1,
fun dxy k dw =>
xdf'.2 dxy.1 k (ydg'.2 dxy.2 k dw)) :=
fun dxy dw =>
xdf'.2 dxy.1 (ydg'.2 dxy.2 dw)) :=
by
have ⟨_,_⟩ := hf
have ⟨_,_⟩ := hg
unfold revDerivUpdate
ftrans; funext x; simp
funext dy k dx
funext dy dx
ftrans
sorry_proof

Expand All @@ -228,7 +228,7 @@ theorem MProd.fst.arg_self.revDerivUpdate_rule
=
fun w =>
let xydxy := revDerivUpdate K f w
(xydxy.1.1, fun dx' k dw => xydxy.2 (MProd.mk dx' 0) k dw) := by sorry_proof
(xydxy.1.1, fun dx' dw => xydxy.2 (MProd.mk dx' 0) dw) := by sorry_proof


@[fprop]
Expand All @@ -252,7 +252,7 @@ theorem MProd.snd.arg_self.revDerivUpdate_rule
=
fun w =>
let xydxy := revDerivUpdate K f w
(xydxy.1.2, fun dy' k dw => xydxy.2 (MProd.mk 0 dy') k dw) := by sorry_proof
(xydxy.1.2, fun dy' dw => xydxy.2 (MProd.mk 0 dy') dw) := by sorry_proof


end OnSemiInnerProductSpace
Expand Down
18 changes: 9 additions & 9 deletions SciLean/Data/ArrayType/Properties.lean
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,9 @@ theorem GetElem.getElem.arg_xs.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
(getElem ydf.1 idx dom,
fun delem (k : K) dx =>
fun delem dx =>
let dcont := introElem fun i => if i = idx then delem else 0
ydf.2 dcont k dx) :=
ydf.2 dcont dx) :=
by
have ⟨_,_⟩ := hf
unfold revDerivUpdate; ftrans; ftrans; simp
Expand All @@ -198,8 +198,8 @@ theorem GetElem.getElem.arg_xs.revDerivUpdate_rule_simple
=
fun cont =>
(getElem cont idx dom,
fun delem k dcont =>
let dcont := ArrayType.modifyElem dcont idx (fun elem => elem + k • delem)
fun delem dcont =>
let dcont := ArrayType.modifyElem dcont idx (fun elem => elem + delem)
dcont) :=
by
unfold revDerivUpdate; ftrans; sorry_proof
Expand All @@ -213,13 +213,13 @@ theorem GetElem.getElem.arg_xs_i.revDerivUpdate_rule
fun x =>
let ydf := revDerivUpdate K f x
(fun idx => getElem ydf.1 idx dom,
fun delem (k : K) dx =>
fun delem dx =>
let dcont := introElem delem
ydf.2 dcont k dx) :=
ydf.2 dcont dx) :=
by
have ⟨_,_⟩ := hf
unfold revDerivUpdate; ftrans;
funext x; simp; funext dy k dx; simp
-- funext x; simp; funext dy k dx; simp
-- ftrans -- fails to apply `semiAdjoint.pi_rule` because of some universe issues
sorry_proof

Expand Down Expand Up @@ -438,9 +438,9 @@ theorem SetElem.setElem.arg_contelem.revDerivUpdate_rule
let cdc := revDerivUpdate K cont x
let ede := revDerivUpdate K elem x
(setElem cdc.1 idx ede.1,
fun dcont' k dx =>
fun dcont' dx =>
let delem' := dcont'[idx]
ede.2 delem' k (cdc.2 (setElem dcont' idx 0) k dx)
ede.2 delem' (cdc.2 (setElem dcont' idx 0) dx)
) :=
by
have ⟨_,_⟩ := hcont
Expand Down

0 comments on commit 06a84f3

Please sign in to comment.