From 4dfd8229e5431d24b1d3edf8887ed2626db2a000 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Wed, 15 Nov 2023 17:18:22 -0500 Subject: [PATCH] some work on differentiating for loops --- SciLean/Core/Function.lean | 52 ++++--- SciLean/Core/Monads/ForIn.lean | 194 ++++++++++++++++++------- SciLean/Core/Monads/ForInTests.lean | 56 ++++++- SciLean/Core/Monads/Id.lean | 65 +++++++++ SciLean/Core/Monads/RevDerivMonad.lean | 2 +- 5 files changed, 294 insertions(+), 75 deletions(-) diff --git a/SciLean/Core/Function.lean b/SciLean/Core/Function.lean index 4c110b50..71e26f7d 100644 --- a/SciLean/Core/Function.lean +++ b/SciLean/Core/Function.lean @@ -226,6 +226,34 @@ variable {W : Type _} [SemiInnerProductSpace K W] +/-- Reverse pass of a for loop + + See `Forin.arg_bf.revCDeriv_rule_dataArrayImpl` how it relates to reverse derivative of for loop. + + -- WARNING: `dx'` and `dw` behave differently + - `df'` computes gradient in `dx'` + - `df'` computes update to gradient in `dw` + + The value of `df'` should be: + df' = fun w i x dx' dw => + ((∇ (x':=x;dx'), f w i x'), (dw + ∇ (w':=w;dx'), f w' i x)) + + -/ +def ForIn.arg_bf.revPass_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W] + (df' : W → ι → X → X → W → X×W) (w : W) (xs : DataArray X) (dx' : X) : X×W := Id.run do + let n := Index.size ι + let mut dx' := dx' + let mut dw : W := 0 + for i in [0:n.toNat] do + let j' : Idx n := ⟨n-i.toUSize-1,sorry_proof⟩ + let j : ι := fromIdx j' + let xj := xs.get ⟨j'.1, sorry_proof⟩ + let dxw' := df' w j xj dx' dw + dx' := dxw'.1 + dw := dxw'.2 + (dx',dw) + + /-- -- WARNING: `dx'` and `dw` behave differently - `df'` computes gradient in `dx'` @@ -238,30 +266,16 @@ variable def ForIn.arg_bf.revDeriv_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W] (init : X) (f : W → ι → X → X) (df' : W → ι → X → X → W → X×W) (w : W) : X×(X→X×W) := - let n := Index.size ι - if n = 0 then - (init, fun _ => 0) - else Id.run do + Id.run do + let n := 
Index.size ι -- forward pass let mut xs : DataArray X := .mkEmpty n let mut x := init for i in fullRange ι do xs := xs.push x x := f w i x - (x, fun dx' => Id.run do - -- backward pass - -- TODO: implemente reverse ranges! - let mut dx' := dx' - let mut dw : W := 0 - for i in [0:n.toNat] do - let j' : Idx n := ⟨n-i.toUSize-1,sorry_proof⟩ - let j : ι := fromIdx j' - let xj := xs.get ⟨j'.1, sorry_proof⟩ - let (dx'',dw') := df' w j xj dx' dw - dx' := dx'' - dw := dw' - (dx',dw)) - + (x, fun dx' => ForIn.arg_bf.revPass_dataArrayImpl df' w xs dx') + @[ftrans] theorem ForIn.forIn.arg_bf.revDerivM_rule' [Index ι] [PlainDataType X] [PlainDataType W] @@ -480,3 +494,5 @@ theorem Function.reduceD.arg_fdefault.revCDeriv_rule end OnSemiInnerProductSpace + + diff --git a/SciLean/Core/Monads/ForIn.lean b/SciLean/Core/Monads/ForIn.lean index b3a30d85..3edd0f73 100644 --- a/SciLean/Core/Monads/ForIn.lean +++ b/SciLean/Core/Monads/ForIn.lean @@ -1,5 +1,6 @@ import SciLean.Core.Monads.FwdDerivMonad import SciLean.Core.Monads.Id +import SciLean.Data.DataArray set_option linter.unusedVariables false @@ -8,18 +9,16 @@ namespace SciLean variable {K : Type _} [IsROrC K] --- This is not true but lets assume it for now until I have +-- This is not true but lets assume it for now instance [Vec K X] : Vec K (ForInStep X) := sorry --- This is not true but lets assume it for now until I have +-- This is not true but lets assume it for now instance [SemiInnerProductSpace K X] : SemiInnerProductSpace K (ForInStep X) := sorry end SciLean open SciLean --- set_option linter.unusedVariables false - def ForInStep.val : ForInStep α → α | .yield a => a | .done a => a @@ -60,6 +59,9 @@ variable {Y : Type _} [Vec K Y] {Z : Type _} [Vec K Z] +-------------------------------------------------------------------------------- +-- ForIn.forIn ----------------------------------------------------------------- +-------------------------------------------------------------------------------- -- we need some kind of 
lawful version of `ForIn` to be able to prove this @[fprop] @@ -101,6 +103,9 @@ by simp [forIn,Std.Range.forIn,Std.Range.forIn.loop,Std.Range.forIn.loop.match_1] ftrans +-------------------------------------------------------------------------------- +-- ForInStep.yield ------------------------------------------------------------- +-------------------------------------------------------------------------------- @[fprop] theorem ForInStep.yield.arg_a0.IsDifferentiable_rule @@ -133,6 +138,11 @@ theorem ForInStep.done.arg_a0.IsDifferentiable_rule (a0 : X → Y) (ha0 : IsDifferentiable K a0) : IsDifferentiable K fun x => ForInStep.done (a0 x) := by sorry_proof + +-------------------------------------------------------------------------------- +-- ForInStep.done ------------------------------------------------------------- +-------------------------------------------------------------------------------- + -- this is a hack as with Id monad sometimes you do not need `pure` which trips `fprop` up @[fprop] theorem ForInStep.done.arg_a0.IsDifferentiableM_rule @@ -157,6 +167,8 @@ theorem ForInStep.done.arg_a0.fwdCDeriv_rule end OnVec +-------------------------------------------------------------------------------- + section OnSemiInnerProductSpace @@ -168,8 +180,13 @@ variable {X : Type _} [SemiInnerProductSpace K X] {Y : Type _} [SemiInnerProductSpace K Y] {Z : Type _} [SemiInnerProductSpace K Z] + {W : Type _} [SemiInnerProductSpace K W] +-------------------------------------------------------------------------------- +-- ForIn.forIn ----------------------------------------------------------------- +-------------------------------------------------------------------------------- + -- we need some kind of lawful version of `ForIn` to be able to prove this @[fprop] theorem ForIn.forIn.arg_bf.HasAdjDiffM_rule @@ -177,10 +194,18 @@ theorem ForIn.forIn.arg_bf.HasAdjDiffM_rule (hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiffM K (fun (xy : X×Y) => f xy.1 a xy.2)) : HasAdjDiffM K 
(fun x => forIn range (init x) (f x)) := sorry_proof +@[fprop] +theorem ForIn.forIn.arg_bf.HasAdjDiff_rule [ForIn Id ρ α] + (range : ρ) (init : X → Y) (f : X → α → Y → (ForInStep Y)) + (hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiff K (fun (xy : X×Y) => f xy.1 a xy.2)) + : HasAdjDiff K (fun x => forIn (m:=Id) range (init x) (f x)) := sorry_proof + --- we need some kind of lawful version of `ForIn` to be able to prove this -@[ftrans] -theorem ForIn.forIn.arg_bf.revDerivM_rule +/-- This version of reverse derivative of a for loop builds a big lambda function + for the reverse pass during the forward pass. This is probably ok in for loops + where each iteration is very costly and there are not many iterations. +-/ +theorem ForIn.forIn.arg_bf.revDerivM_rule_lazy (range : ρ) (init : X → Y) (f : X → α → Y → m (ForInStep Y)) (hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiffM K (fun (xy : X×Y) => f xy.1 a xy.2)) : revDerivM K (fun x => forIn range (init x) (f x)) @@ -205,55 +230,117 @@ theorem ForIn.forIn.arg_bf.revDerivM_rule by sorry_proof -theorem ForIn.forIn.arg_bf.revDerivM_rule_alternative [ForIn Id ρ α] - (init : X → Y) (f : X → Nat → Y → Y) - (hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiff K (fun (xy : X×Y) => f xy.1 a xy.2)) - : revDerivM K (fun x => forIn (m:=Id) (Std.Range.mk start stop step) (init x) (fun i y => .yield (f x i y))) - = - (fun x => Id.run do - let (y₀,dinit') := revCDeriv K init x - let (y,ys) ← forIn (Std.Range.mk start stop step) (y₀,#[]) (fun i (y,ys) => - let y' := f x i y - .yield (y', ys.push y')) - pure (y, - fun dy => do - let (dx,dy) ← forIn (Std.Range.mk start stop step) ((0:X),dy) (fun i (dx,dy) => do - let j := stop - ((i-start) + 1) - let yᵢ : Y := ys[j]! 
- let (dx',dy) := (revCDeriv K (fun (xy : X×Y) => f xy.1 i xy.2) (x,yᵢ)).2 dy - .yield (dx + dx', dy)) - pure (dx + dinit' dy))) := -by - sorry_proof +/-- Reverse pass of a for loop implemented using `DataArray` + + See `Forin.arg_bf.revCDeriv_rule_dataArrayImpl` for how it relates to the reverse derivative of a for loop. + + TODO: Index should support iterating in reverse order + + WARNING: `dx'` and `dw` behave differently + - `df'` computes gradient in `dx'` + - `df'` computes update to gradient in `dw` + The value of `df'` should be: + df' = fun w i x dx' dw => + ((∇ (x':=x;dx'), f w i x'), (dw + ∇ (w':=w;dx'), f w' i x)) + -/ +def ForIn.arg_bf.revPass_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W] + (df' : W → ι → X → X → W → X×W) (w : W) (xs : DataArrayN X ι) (dx' : X) : X×W := Id.run do + let n := Index.size ι + let mut dx' := dx' + let mut dw : W := 0 + for i in [0:n.toNat] do + let j : ι := fromIdx ⟨n-i.toUSize-1,sorry_proof⟩ + let xj := xs[j] + let dxw' := df' w j xj dx' dw + dx' := dxw'.1 + dw := dxw'.2 + (dx',dw) + + +/-- Reverse derivative of a for loop + + WARNING: `dx'` and `dw` behave differently + - `df'` computes gradient in `dx'` + - `df'` computes update to gradient in `dw` + + The value of `df'` should be: + df' = fun w i x dx' dw => + ((∇ (x':=x;dx'), f w i x'), (dw + ∇ (w':=w;dx'), f w' i x)) +-/ +def ForIn.arg_bf.revDeriv_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W] + (init : X) (f : W → ι → X → X) (df' : W → ι → X → X → W → X×W) (w : W) + : X×(X→X×W) := + Id.run do + let n := Index.size ι + -- forward pass + let mut xs : DataArray X := .mkEmpty n + let mut x := init + for i in fullRange ι do + xs := xs.push x + x := f w i x + let xs' : DataArrayN X ι := ⟨xs, sorry_proof⟩ + (x, fun dx' => ForIn.arg_bf.revPass_dataArrayImpl df' w xs' dx') + + +/-- The do notation leaves the for loop body in a strange form `do pure PUnit.unit; pure <| ForInStep.yield (f w i y))` + Marking this theorem with 
`ftrans` is a bit of a hack. It normalizes the body to `ForInStep.yield (f w i y)`. + -/ +@[ftrans] +theorem ForIn.forIn.arg_bf.revDerivM_rule_normalization [Index ι] + (init : W → X) (f : W → ι → X → X) + : revDerivM K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => do pure PUnit.unit; pure <| ForInStep.yield (f w i y))) + = + revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => ForInStep.yield (f w i y))) := by rfl --- Proof that the above theorem is true for the range [0:3] and function that does not break the for loop -example - (init : X → Y) (f : X → Nat → Y → m Y) - (hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiffM K (fun (xy : X×Y) => f xy.1 a xy.2)) - : revDerivM K (fun x => forIn [0:3] (init x) (fun i y => do pure (ForInStep.yield (← f x i y)))) - = - (fun x => do - let ydinit := revCDeriv K init x - let ydf ← forIn [0:3] (ydinit.1, fun (dy:Y) => pure (f:=m') ((0:X),dy)) - (fun a ydf => do - let ydf' ← revDerivM K (fun (xy : X×Y) => f xy.1 a xy.2) (x,ydf.1) - let df : Y → m' (X×Y) := - fun dy : Y => do - let dxy ← ydf'.2 dy - let dxy' ← ydf.2 dxy.2 - pure (dxy.1 + dxy'.1, dxy'.2) - pure (ForInStep.yield (ydf'.1, df))) - pure (ydf.1, - fun dy => do - let dxy ← ydf.2 dy - pure (dxy.1 + ydinit.2 dxy.2))) := + + +/-- The do notation leaves the for loop body in a strange form `do pure PUnit.unit; pure <| ForInStep.yield (f w i y))` + Marking this theorem with `ftrans` is a bit of a hack. It normalizes the body to `ForInStep.yield (f w i y)`. 
+ -/ +@[ftrans] +theorem ForIn.forIn.arg_bf.revCDeriv_rule_normalization [Index ι] + (init : W → X) (f : W → ι → X → X) + : revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => do pure PUnit.unit; pure <| ForInStep.yield (f w i y))) + = + revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => ForInStep.yield (f w i y))) := by rfl + + +@[ftrans] +theorem ForIn.forIn.arg_bf.revCDeriv_rule_def [Index ι] [PlainDataType X] [PlainDataType W] + (init : W → X) (f : W → ι → X → X) + (hinit : HasAdjDiff K init) (hf : ∀ i, HasAdjDiff K (fun (w,x) => f w i x)) + : revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => ForInStep.yield (f w i y))) + = + fun w => (Id.run do + let n := Index.size ι + let initdinit := revCDeriv K init w + + -- forward pass + let mut xs : DataArray X := .mkEmpty n + let mut x := initdinit.1 + for i in fullRange ι do + xs := xs.push x + x := f w i x + let xs' : DataArrayN X ι := ⟨xs, sorry_proof⟩ + + let revPassBody := hold fun w i x dx' dw => + let dwx' := gradient K (fun (w',x') => f w' i x') (w,x) dx' + (dwx'.2, dw + dwx'.1) + + (x, + fun dx' => + -- reverse pass + let dxw' := ForIn.arg_bf.revPass_dataArrayImpl revPassBody w xs' dx' + initdinit.2 dxw'.1 + dxw'.2)) := by - simp [revCDeriv,forIn,Std.Range.forIn,Std.Range.forIn.loop,Std.Range.forIn.loop.match_1, revCDeriv] - ftrans - simp[add_assoc, revCDeriv] + sorry_proof + +-------------------------------------------------------------------------------- +-- ForInStep.yield ------------------------------------------------------------- +-------------------------------------------------------------------------------- @[fprop] theorem ForInStep.yield.arg_a0.HasAdjDiff_rule @@ -315,4 +402,9 @@ theorem ForInStep.done.arg_a0.revDerivM_rule (.done ydf.1, fun y => ydf.2 y.val) := by sorry_proof + end OnSemiInnerProductSpace + + + + diff --git a/SciLean/Core/Monads/ForInTests.lean b/SciLean/Core/Monads/ForInTests.lean index 52d07f32..5b31ec99 100644 --- 
a/SciLean/Core/Monads/ForInTests.lean +++ b/SciLean/Core/Monads/ForInTests.lean @@ -1,5 +1,7 @@ import SciLean.Core.Monads.ForIn import SciLean.Tactic.LetNormalize +import SciLean.Core.FloatAsReal +import SciLean.Core.Notation open SciLean @@ -13,6 +15,7 @@ variable {Y : Type _} [Vec K Y] {Z : Type _} [Vec K Z] + -- set_option pp.notation false in example @@ -70,13 +73,56 @@ example : fwdDerivM K (fun x : K => show m K from do pure ydy) := by - (conv => lhs; ftrans only; let_normalize; ftrans only; simp (config := {zeta := false})) - simp - funext x dx - congr - funext a (y,dy) + (conv => lhs; ftrans; ftrans; simp (config := {zeta := false})) simp +-- @[ftrans_simp] +-- theorem revDerivM_eq_revCDeriv_on_Id' +-- [SemiInnerProductSpace K X] [SemiInnerProductSpace K Y] (x : X) (f : X → Y) +-- : revDerivM K fun f = revCDeriv K f := by set_option pp.all true in rfl + + +#eval + ((gradient Float (fun x : Float ^ Idx 3 => Id.run do + let mut y := 1.0 + for i in fullRange (Idx 3) do + y := y * x[i] + y)) + rewrite_by + unfold gradient + ftrans + ftrans + ftrans + unfold gradient + ftrans) + ⊞[5.0,6,7] 1 + +variable (y y' : Id Float × (Id Float → Float)) (b' : Id Float × (Id Float → Float) → Float) (b : Float × (Float → Float) → Float) (q : y' = y) (h : ∀ (x : Id Float × (Id Float → Float)), x = y → let z : Float × (Float → Float) := x; (b z = b' z)) + +#check Tactic.LSimp.let_congr_eq (α := Float × (Float → Float)) q h + +set_option trace.Meta.Tactic.simp.discharge true in +set_option trace.Meta.Tactic.fprop.discharge true in +-- set_option trace.Meta.Tactic.fprop.step true in +set_option trace.Meta.Tactic.ftrans.step true in +set_option pp.notation false in +#eval + ((gradient Float (fun x : Float ^ Idx 10 => Id.run do + let mut s := x[(⟨0,sorry⟩ : Idx 10)] + for i in [0:9] do + let i : Idx 10 := ⟨i.toUSize+1,sorry⟩ + s := s + x[i] + s)) + rewrite_by + unfold gradient + ftrans + ftrans + ftrans + unfold gradient + ftrans) + ⊞[5.0,6,7,8,9,10,11,12,13,14] 1 + + -- example : 
fwdDerivM K (fun x : K => show m K from do -- let mut y := x -- for i in [0:5] do diff --git a/SciLean/Core/Monads/Id.lean b/SciLean/Core/Monads/Id.lean index e882a8ea..921df4e9 100644 --- a/SciLean/Core/Monads/Id.lean +++ b/SciLean/Core/Monads/Id.lean @@ -89,4 +89,69 @@ theorem Id.run.arg_x.revCDeriv_rule (a : X → Id Y) = fun x => (revDerivM K a x) := by rfl + +-- some normalizations for the Id monad because it is painful to work with, +-- as one can often abuse defEq + +@[ftrans_simp] +theorem revDerivM_eq_revCDeriv_on_Id (f : X → Y) + : revDerivM (m:=Id) K f = fun x => pure (revCDeriv K f x) := by rfl + +@[ftrans_simp] +theorem revDerivM_eq_revCDeriv_on_Id' (f : X → Id Y) + : revDerivM K f = revCDeriv K f := by set_option pp.all true in rfl + +@[fprop] +theorem Pure.pure.arg_a0.HasAdjDiff_rule + (a0 : X → Y) + (ha0 : HasAdjDiff K a0) + : HasAdjDiff K (fun x => Pure.pure (f:=Id) (a0 x)) := +by + simp[Pure.pure]; fprop + +@[fprop] +theorem Bind.bind.arg_a0a1.HasAdjDiff_rule_on_Id + (a0 : X → Y) (a1 : X → Y → Z) + (ha0 : HasAdjDiff K a0) (ha1 : HasAdjDiff K (fun (x,y) => a1 x y)) + : HasAdjDiff K (fun x => Bind.bind (m:=Id) (a0 x) (a1 x)) := by simp[Bind.bind]; fprop + + +@[ftrans] +theorem Bind.bind.arg_a0a1.revDerivM_rule_on_Id + (a0 : X → Y) (a1 : X → Y → Z) + (ha0 : HasAdjDiff K a0) (ha1 : HasAdjDiff K (fun (x,y) => a1 x y)) + : (revDerivM (m:=Id) K (fun x => Bind.bind (a0 x) (a1 x))) + = + fun x => + let ydg' := revCDeriv K a0 x + let zdf' := revCDeriv K (fun (x,y) => a1 x y) (x,ydg'.1) + (zdf'.1, + fun dz' => + let dxy' := zdf'.2 dz' + let dx' := ydg'.2 dxy'.2 + dxy'.1 + dx') := +by + simp[revDerivM]; ftrans + +-- @[ftrans] +-- This theorem causes some downstream issue in simp when applying congruence lemmas +-- The issue seems to be that there is some defEq abuse that stops working +theorem Bind.bind.arg_a0a1.revCDeriv_rule_on_Id + (a0 : X → Y) (a1 : X → Y → Z) + (ha0 : HasAdjDiff K a0) (ha1 : HasAdjDiff K (fun (x,y) => a1 x y)) + : (revCDeriv K (fun x => 
Bind.bind (m:=Id) (a0 x) (a1 x))) + = + fun x => + let ydg' := revCDeriv K a0 x + let zdf' := revCDeriv K (fun (x,y) => a1 x y) (x,ydg'.1) + (zdf'.1, + fun dz' => + let dxy' := zdf'.2 dz' + let dx' := ydg'.2 dxy'.2 + dxy'.1 + dx') := +by + simp (config := {zeta:=false}) [Bind.bind]; ftrans; rfl + + + end OnSemiInnerProductSpace diff --git a/SciLean/Core/Monads/RevDerivMonad.lean b/SciLean/Core/Monads/RevDerivMonad.lean index 5bea42ee..39a64561 100644 --- a/SciLean/Core/Monads/RevDerivMonad.lean +++ b/SciLean/Core/Monads/RevDerivMonad.lean @@ -607,7 +607,7 @@ by simp [RevDerivMonad.revDerivM_pair a0 ha0] - +-------------------------------------------------------------------------------- -- d/ite ----------------------------------------------------------------------- --------------------------------------------------------------------------------