Skip to content

Commit

Permalink
some work on differentiating for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 15, 2023
1 parent 3f1521f commit 4dfd822
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 75 deletions.
52 changes: 34 additions & 18 deletions SciLean/Core/Function.lean
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,34 @@ variable
{W : Type _} [SemiInnerProductSpace K W]


/-- Reverse pass of a for loop
See `Forin.arg_bf.revCDeriv_rule_dataArrayImpl` how it relates to reverse derivative of for loop.
-- WARNING: `dx'` and `dw` behave differently
- `df'` computes gradient in `dx'`
- `df'` computes update to gradient in `dw`
The value of `df'` should be:
df' = fun w i x dx' dw =>
((∇ (x':=x;dx'), f w i x'), (dw + ∇ (w':=w;dx'), f w' i x))
-/
def ForIn.arg_bf.revPass_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W]
(df' : W → ι → X → X → W → X×W) (w : W) (xs : DataArray X) (dx' : X) : X×W := Id.run do
let n := Index.size ι
let mut dx' := dx'
let mut dw : W := 0
for i in [0:n.toNat] do
let j' : Idx n := ⟨n-i.toUSize-1,sorry_proof⟩
let j : ι := fromIdx j'
let xj := xs.get ⟨j'.1, sorry_proof⟩
let dxw' := df' w j xj dx' dw
dx' := dxw'.1
dw := dxw'.2
(dx',dw)


/--
-- WARNING: `dx'` and `dw` behave differently
- `df'` computes gradient in `dx'`
Expand All @@ -238,30 +266,16 @@ variable
def ForIn.arg_bf.revDeriv_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W]
(init : X) (f : W → ι → X → X) (df' : W → ι → X → X → W → X×W) (w : W)
: X×(X→X×W) :=
let n := Index.size ι
if n = 0 then
(init, fun _ => 0)
else Id.run do
Id.run do
let n := Index.size ι
-- forward pass
let mut xs : DataArray X := .mkEmpty n
let mut x := init
for i in fullRange ι do
xs := xs.push x
x := f w i x
(x, fun dx' => Id.run do
-- backward pass
-- TODO: implemente reverse ranges!
let mut dx' := dx'
let mut dw : W := 0
for i in [0:n.toNat] do
let j' : Idx n := ⟨n-i.toUSize-1,sorry_proof⟩
let j : ι := fromIdx j'
let xj := xs.get ⟨j'.1, sorry_proof⟩
let (dx'',dw') := df' w j xj dx' dw
dx' := dx''
dw := dw'
(dx',dw))

(x, fun dx' => ForIn.arg_bf.revPass_dataArrayImpl df' w xs dx')


@[ftrans]
theorem ForIn.forIn.arg_bf.revDerivM_rule' [Index ι] [PlainDataType X] [PlainDataType W]
Expand Down Expand Up @@ -480,3 +494,5 @@ theorem Function.reduceD.arg_fdefault.revCDeriv_rule


end OnSemiInnerProductSpace


194 changes: 143 additions & 51 deletions SciLean/Core/Monads/ForIn.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import SciLean.Core.Monads.FwdDerivMonad
import SciLean.Core.Monads.Id
import SciLean.Data.DataArray

set_option linter.unusedVariables false

Expand All @@ -8,18 +9,16 @@ namespace SciLean
variable
{K : Type _} [IsROrC K]

-- This is not true but lets assume it for now until I have
-- This is not true but lets assume it for now
instance [Vec K X] : Vec K (ForInStep X) := sorry

-- This is not true but lets assume it for now until I have
-- This is not true but lets assume it for now
instance [SemiInnerProductSpace K X] : SemiInnerProductSpace K (ForInStep X) := sorry


end SciLean
open SciLean

-- set_option linter.unusedVariables false

def ForInStep.val : ForInStep α → α
| .yield a => a
| .done a => a
Expand Down Expand Up @@ -60,6 +59,9 @@ variable
{Y : Type _} [Vec K Y]
{Z : Type _} [Vec K Z]

--------------------------------------------------------------------------------
-- ForIn.forIn -----------------------------------------------------------------
--------------------------------------------------------------------------------

-- we need some kind of lawful version of `ForIn` to be able to prove this
@[fprop]
Expand Down Expand Up @@ -101,6 +103,9 @@ by
simp [forIn,Std.Range.forIn,Std.Range.forIn.loop,Std.Range.forIn.loop.match_1]
ftrans

--------------------------------------------------------------------------------
-- ForInStep.yield -------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem ForInStep.yield.arg_a0.IsDifferentiable_rule
Expand Down Expand Up @@ -133,6 +138,11 @@ theorem ForInStep.done.arg_a0.IsDifferentiable_rule
(a0 : X → Y) (ha0 : IsDifferentiable K a0)
: IsDifferentiable K fun x => ForInStep.done (a0 x) := by sorry_proof


--------------------------------------------------------------------------------
-- ForInStep.done -------------------------------------------------------------
--------------------------------------------------------------------------------

-- this is a hack as with Id monad sometimes you do not need `pure` which trips `fprop` up
@[fprop]
theorem ForInStep.done.arg_a0.IsDifferentiableM_rule
Expand All @@ -157,6 +167,8 @@ theorem ForInStep.done.arg_a0.fwdCDeriv_rule
end OnVec


--------------------------------------------------------------------------------


section OnSemiInnerProductSpace

Expand All @@ -168,19 +180,32 @@ variable
{X : Type _} [SemiInnerProductSpace K X]
{Y : Type _} [SemiInnerProductSpace K Y]
{Z : Type _} [SemiInnerProductSpace K Z]
{W : Type _} [SemiInnerProductSpace K W]


--------------------------------------------------------------------------------
-- ForIn.forIn -----------------------------------------------------------------
--------------------------------------------------------------------------------

-- we need some kind of lawful version of `ForIn` to be able to prove this
@[fprop]
theorem ForIn.forIn.arg_bf.HasAdjDiffM_rule
(range : ρ) (init : X → Y) (f : X → α → Y → m (ForInStep Y))
(hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiffM K (fun (xy : X×Y) => f xy.1 a xy.2))
: HasAdjDiffM K (fun x => forIn range (init x) (f x)) := sorry_proof

@[fprop]
theorem ForIn.forIn.arg_bf.HasAdjDiff_rule [ForIn Id ρ α]
(range : ρ) (init : X → Y) (f : X → α → Y → (ForInStep Y))
(hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiff K (fun (xy : X×Y) => f xy.1 a xy.2))
: HasAdjDiff K (fun x => forIn (m:=Id) range (init x) (f x)) := sorry_proof


-- we need some kind of lawful version of `ForIn` to be able to prove this
@[ftrans]
theorem ForIn.forIn.arg_bf.revDerivM_rule
/-- This version of reverse derivative of a for loop builds a big lambda function
for the reverse pass during the forward pass. This is probably ok in for loops
where each iteration is very costly and there are not many iterations.
-/
theorem ForIn.forIn.arg_bf.revDerivM_rule_lazy
(range : ρ) (init : X → Y) (f : X → α → Y → m (ForInStep Y))
(hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiffM K (fun (xy : X×Y) => f xy.1 a xy.2))
: revDerivM K (fun x => forIn range (init x) (f x))
Expand All @@ -205,55 +230,117 @@ theorem ForIn.forIn.arg_bf.revDerivM_rule
by
sorry_proof

theorem ForIn.forIn.arg_bf.revDerivM_rule_alternative [ForIn Id ρ α]
(init : X → Y) (f : X → Nat → Y → Y)
(hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiff K (fun (xy : X×Y) => f xy.1 a xy.2))
: revDerivM K (fun x => forIn (m:=Id) (Std.Range.mk start stop step) (init x) (fun i y => .yield (f x i y)))
=
(fun x => Id.run do
let (y₀,dinit') := revCDeriv K init x
let (y,ys) ← forIn (Std.Range.mk start stop step) (y₀,#[]) (fun i (y,ys) =>
let y' := f x i y
.yield (y', ys.push y'))
pure (y,
fun dy => do
let (dx,dy) ← forIn (Std.Range.mk start stop step) ((0:X),dy) (fun i (dx,dy) => do
let j := stop - ((i-start) + 1)
let yᵢ : Y := ys[j]!
let (dx',dy) := (revCDeriv K (fun (xy : X×Y) => f xy.1 i xy.2) (x,yᵢ)).2 dy
.yield (dx + dx', dy))
pure (dx + dinit' dy))) :=
by
sorry_proof

/-- Reverse pass of a for loop implemented using `DataArray`
See `Forin.arg_bf.revCDeriv_rule_dataArrayImpl` how it relates to reverse derivative of for loop.
TODO: Index shoud support iterating in reverse order
WARNING: `dx'` and `dw` behave differently
- `df'` computes gradient in `dx'`
- `df'` computes update to gradient in `dw`
The value of `df'` should be:
df' = fun w i x dx' dw =>
((∇ (x':=x;dx'), f w i x'), (dw + ∇ (w':=w;dx'), f w' i x))
-/
def ForIn.arg_bf.revPass_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W]
(df' : W → ι → X → X → W → X×W) (w : W) (xs : DataArrayN X ι) (dx' : X) : X×W := Id.run do
let n := Index.size ι
let mut dx' := dx'
let mut dw : W := 0
for i in [0:n.toNat] do
let j : ι := fromIdx ⟨n-i.toUSize-1,sorry_proof⟩
let xj := xs[j]
let dxw' := df' w j xj dx' dw
dx' := dxw'.1
dw := dxw'.2
(dx',dw)


/-- Reverse derivative of a for loop
WARNING: `dx'` and `dw` behave differently
- `df'` computes gradient in `dx'`
- `df'` computes update to gradient in `dw`
The value of `df'` should be:
df' = fun w i x dx' dw =>
((∇ (x':=x;dx'), f w i x'), (dw + ∇ (w':=w;dx'), f w' i x))
-/
def ForIn.arg_bf.revDeriv_dataArrayImpl [Index ι] [PlainDataType X] [PlainDataType W] [Zero X] [Zero W]
(init : X) (f : W → ι → X → X) (df' : W → ι → X → X → W → X×W) (w : W)
: X×(X→X×W) :=
Id.run do
let n := Index.size ι
-- forward pass
let mut xs : DataArray X := .mkEmpty n
let mut x := init
for i in fullRange ι do
xs := xs.push x
x := f w i x
let xs' : DataArrayN X ι := ⟨xs, sorry_proof⟩
(x, fun dx' => ForIn.arg_bf.revPass_dataArrayImpl df' w xs' dx')


/-- The do notation leaves the for loop body in a strange form `do pure PUnit.unit; pure <| ForInStep.yield (f w i y))`
Marking this theorem with `ftrans` is a bit of a hack. It normalizes the body to `ForInStep.yield (f w i y)`.
-/
@[ftrans]
theorem ForIn.forIn.arg_bf.revDerivM_rule_normalization [Index ι]
(init : W → X) (f : W → ι → X → X)
: revDerivM K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => do pure PUnit.unit; pure <| ForInStep.yield (f w i y)))
=
revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => ForInStep.yield (f w i y))) := by rfl

-- Proof that the above theorem is true for the range [0:3] and function that does not break the for loop
example
(init : X → Y) (f : X → Nat → Y → m Y)
(hinit : HasAdjDiff K init) (hf : ∀ a, HasAdjDiffM K (fun (xy : X×Y) => f xy.1 a xy.2))
: revDerivM K (fun x => forIn [0:3] (init x) (fun i y => do pure (ForInStep.yield (← f x i y))))
=
(fun x => do
let ydinit := revCDeriv K init x
let ydf ← forIn [0:3] (ydinit.1, fun (dy:Y) => pure (f:=m') ((0:X),dy))
(fun a ydf => do
let ydf' ← revDerivM K (fun (xy : X×Y) => f xy.1 a xy.2) (x,ydf.1)
let df : Y → m' (X×Y) :=
fun dy : Y => do
let dxy ← ydf'.2 dy
let dxy' ← ydf.2 dxy.2
pure (dxy.1 + dxy'.1, dxy'.2)
pure (ForInStep.yield (ydf'.1, df)))
pure (ydf.1,
fun dy => do
let dxy ← ydf.2 dy
pure (dxy.1 + ydinit.2 dxy.2))) :=


/-- The do notation leaves the for loop body in a strange form `do pure PUnit.unit; pure <| ForInStep.yield (f w i y))`
Marking this theorem with `ftrans` is a bit of a hack. It normalizes the body to `ForInStep.yield (f w i y)`.
-/
@[ftrans]
theorem ForIn.forIn.arg_bf.revCDeriv_rule_normalization [Index ι]
(init : W → X) (f : W → ι → X → X)
: revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => do pure PUnit.unit; pure <| ForInStep.yield (f w i y)))
=
revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => ForInStep.yield (f w i y))) := by rfl


@[ftrans]
theorem ForIn.forIn.arg_bf.revCDeriv_rule_def [Index ι] [PlainDataType X] [PlainDataType W]
(init : W → X) (f : W → ι → X → X)
(hinit : HasAdjDiff K init) (hf : ∀ i, HasAdjDiff K (fun (w,x) => f w i x))
: revCDeriv K (fun w => forIn (m:=Id) (fullRange ι) (init w) (fun i y => ForInStep.yield (f w i y)))
=
fun w => (Id.run do
let n := Index.size ι
let initdinit := revCDeriv K init w

-- forward pass
let mut xs : DataArray X := .mkEmpty n
let mut x := initdinit.1
for i in fullRange ι do
xs := xs.push x
x := f w i x
let xs' : DataArrayN X ι := ⟨xs, sorry_proof⟩

let revPassBody := hold fun w i x dx' dw =>
let dwx' := gradient K (fun (w',x') => f w' i x') (w,x) dx'
(dwx'.2, dw + dwx'.1)

(x,
fun dx' =>
-- reverse pass
let dxw' := ForIn.arg_bf.revPass_dataArrayImpl revPassBody w xs' dx'
initdinit.2 dxw'.1 + dxw'.2)) :=
by
simp [revCDeriv,forIn,Std.Range.forIn,Std.Range.forIn.loop,Std.Range.forIn.loop.match_1, revCDeriv]
ftrans
simp[add_assoc, revCDeriv]
sorry_proof


--------------------------------------------------------------------------------
-- ForInStep.yield -------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem ForInStep.yield.arg_a0.HasAdjDiff_rule
Expand Down Expand Up @@ -315,4 +402,9 @@ theorem ForInStep.done.arg_a0.revDerivM_rule
(.done ydf.1, fun y => ydf.2 y.val)
:= by sorry_proof


end OnSemiInnerProductSpace




Loading

0 comments on commit 4dfd822

Please sign in to comment.