Skip to content

Commit

Permalink
differentiation rules for Function.foldl and some api for DataArray
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 14, 2023
1 parent d82425f commit 14f0781
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 40 deletions.
157 changes: 157 additions & 0 deletions SciLean/Core/Function.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import SciLean.Core.FunctionTransformations
import SciLean.Data.Function
import SciLean.Data.DataArray

open SciLean

set_option linter.unusedVariables false

variable
{α β ι : Type _}

section OnEnumType

variable [EnumType ι]

/-- Reverse derivative of `Function.foldl` w.r.t. `f` and `init`. It is implemented using `Array`.
TODO:
1. needs beter implementation but that requires refining EnumType and Index
2. add a version with DataArray
-/
def Function.foldl.fwdDeriv [Add α] [Add β]
(f df : ι → α) (dop : β → α → β → α → β×β) (init dinit : β) : β × β := Id.run do
let mut bdb := (init,dinit)
for i in fullRange ι do
bdb := dop bdb.1 (f i) bdb.2 (df i)
bdb

variable
{K : Type _} [IsROrC K]
{X : Type _} [Vec K X]
{Y : Type _} [Vec K Y]
{Z : Type _} [Vec K Z]
{W : Type _} [Vec K W]

@[fprop]
theorem Function.foldl.arg_finit.IsDifferentiable_rule
(f : W → (ι → X)) (op : Y → X → Y) (init : W → Y)
(hf : IsDifferentiable K f) (hop : IsDifferentiable K (fun (y,x) => op y x)) (hinit : IsDifferentiable K init)
: IsDifferentiable K (fun w => Function.foldl (f w) op (init w)) := by sorry_proof

@[ftrans]
theorem Function.foldl.arg_finit.fwdCDeriv_rule
(f : W → (ι → X)) (op : Y → X → Y) (init : W → Y)
(hf : IsDifferentiable K f) (hop : IsDifferentiable K (fun ((y,x) : Y×X) => op y x)) (hinit : IsDifferentiable K init)
: fwdCDeriv K (fun w => Function.foldl (f w) op (init w))
=
fun w dw =>
let fdf := fwdCDeriv K f w dw
let initdinit := fwdCDeriv K init w dw
let dop := fun y x dy dx => fwdCDeriv K (fun (y,x) => op y x) (y,x) (dy,dx)
Function.foldl.fwdDeriv fdf.1 fdf.2 dop initdinit.1 initdinit.2
:= by sorry_proof


end OnEnumType


section OnIndexType

variable [Index ι]

/-- Reverse derivative of `Function.foldl` w.r.t. `f` and `init`. It is implemented using `Array`.
TODO:
1. needs beter implementation but that requires refining EnumType and Index
2. add a version with DataArray
-/
def Function.foldl.revDeriv_arrayImpl [Add α] [Add β]
(f : ι → α) (op : β → α → β) (dop : β → α → β → β×α) (init : β) : β × (β → Array α×β) := Id.run do
let n := (Index.size ι).toNat
let mut bs : Array β := .mkEmpty n
let mut b := init
for i in fullRange ι do
bs := bs.push b
b := op b (f i)
(b,
fun db => Id.run do
let mut das : Array α := .mkEmpty n
let mut db : β := db
for i in [0:n] do
let j : ι := fromIdx ⟨n.toUSize-i.toUSize-1, sorry_proof⟩
let aj := f j
let bj := bs[n-i-1]'sorry_proof
let (db',da) := dop bj aj db
das := das.push da
db := db'
das := das.reverse
(das, db))


/-- Reverse derivative of `Function.foldl` w.r.t. `f` and `init`. It is implemented using `Array`.
TODO:
1. needs beter implementation but that requires refining EnumType and Index
2. add a version with DataArray
-/
def Function.foldl.revDeriv_dataArrayImpl [Add α] [Add β] [PlainDataType α] [PlainDataType β]
(f : ι → α) (op : β → α → β) (dop : β → α → β → β×α) (init : β) : β × (β → DataArrayN α ι×β) := Id.run do
let n := Index.size ι
let mut bs : DataArray β := .mkEmpty n
let mut b := init
for i in fullRange ι do
bs := bs.push b
b := op b (f i)
(b,
fun db => Id.run do
let mut das : DataArray α := .mkEmpty n
let mut db : β := db
for i in [0:n.toNat] do
let j' : Idx n := ⟨n-i.toUSize-1, sorry_proof⟩
let j : ι := fromIdx j'
let aj := f j
let bj := bs.get ⟨j'.1, sorry_proof⟩
let (db',da) := dop bj aj db
das := das.push da
db := db'
das := das.reverse
(⟨das, sorry_proof⟩, db))


variable
{K : Type _} [IsROrC K]
{X : Type _} [SemiInnerProductSpace K X]
{Y : Type _} [SemiInnerProductSpace K Y]
{Z : Type _} [SemiInnerProductSpace K Z]
{W : Type _} [SemiInnerProductSpace K W]


@[fprop]
theorem Function.foldl.arg_finit.HasAdjDiff_rule
(f : W → (ι → X)) (op : Y → X → Y) (init : W → Y)
(hf : HasAdjDiff K f) (hop : HasAdjDiff K (fun (y,x) => op y x)) (hinit : HasAdjDiff K init)
: HasAdjDiff K (fun w => Function.foldl (f w) op (init w)) := by sorry_proof

@[ftrans]
theorem Function.foldl.arg_finit.revCDeriv_rule [PlainDataType X] [PlainDataType Y]
(f : W → (ι → X)) (op : Y → X → Y) (init : W → Y)
(hf : HasAdjDiff K f) (hop : HasAdjDiff K (fun (y,x) => op y x)) (hinit : HasAdjDiff K init)
: revCDeriv K (fun w => Function.foldl (f w) op (init w))
=
fun w =>
let fdf := revCDeriv K f w
let initdinit := revCDeriv K init w
let dop := fun y x => gradient K (fun (y,x) => op y x) (y,x)
let ydy := Function.foldl.revDeriv_dataArrayImpl fdf.1 op dop initdinit.1
(ydy.1,
fun dy =>
let dfdinit := ydy.2 dy
let dw₁ := fdf.2 (fun i => dfdinit.1[i])
let dw₂ := initdinit.2 dfdinit.2
dw₁ + dw₂)
:= by sorry_proof

end OnIndexType


36 changes: 29 additions & 7 deletions SciLean/Data/DataArray/DataArray.lean
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,29 @@ def DataArray.set (arr : DataArray α) (i : Idx arr.size) (val : α) : DataArray
/-- Capacity of an array. The return type is `Squash Nat` as the capacity is is just an implementation detail and should not affect semantics of the program. -/
def DataArray.capacity (arr : DataArray α) : Squash USize := Quot.mk _ (pd.capacity (arr.byteData.size.toUSize))
/-- Makes sure that `arr` fits at least `n` elements of `α` -/
def DataArray.reserve (arr : DataArray α) (n : USize) : DataArray α :=
if (pd.capacity (arr.byteData.size.toUSize)) ≤ n then
def DataArray.reserve (arr : DataArray α) (capacity : USize) : DataArray α :=
if (pd.capacity (arr.byteData.size.toUSize)) ≤ capacity then
arr
else Id.run do
let newBytes := pd.bytes n
let newBytes := pd.bytes capacity
let mut arr' : DataArray α := ⟨ByteArray.mkEmpty newBytes.toNat, arr.size, sorry_proof⟩
-- copy over the old data
for i in fullRange (Idx arr.size) do
arr' := ⟨arr'.byteData.push 0, arr.size, sorry_proof⟩
arr' := arr'.set i (arr.get i)
arr'

def DataArray.mkEmpty (capacity : USize) : DataArray α := Id.run do
let mut a : DataArray α :=
{ byteData := .mkEmpty 0
size := 0
h_size := by sorry_proof }
a.reserve capacity


def DataArray.drop (arr : DataArray α) (k : USize) : DataArray α := ⟨arr.byteData, arr.size - k, sorry_proof⟩

def DataArray.push (arr : DataArray α) (k : USize := 1) (val : α) : DataArray α := Id.run do
def DataArray.push (arr : DataArray α) (val : α) (k : USize := 1) : DataArray α := Id.run do
let oldSize := arr.size
let newSize := arr.size + k
let mut arr' := arr.reserve newSize
Expand All @@ -78,6 +85,22 @@ Currently this is inconsistent, we need to turn DataArray into quotient!
-/
theorem DataArray.ext (d d' : DataArray α) : (h : d.size = d'.size) → (∀ i, d.get i = d'.get (h ▸ i)) → d = d' := sorry_proof

def DataArray.swap (arr : DataArray α) (i j : Idx arr.size) : DataArray α :=
let ai := arr.get i
let aj := arr.get j
let arr := arr.set i aj
let arr := arr.set ⟨j.1, sorry_proof⟩ ai
arr

def DataArray.reverse (arr : DataArray α) : DataArray α := Id.run do
let mut arr := arr
let n := arr.size
for i in [0:n.toNat/2] do
let i' : Idx arr.size := ⟨i.toUSize, sorry_proof⟩
let j' : Idx arr.size := ⟨n - i.toUSize - 1, sorry_proof⟩
arr := arr.swap i' j'
arr

@[irreducible]
def DataArray.intro (f : ι → α) : DataArray α := Id.run do
let bytes := (pd.bytes (Index.size ι))
Expand Down Expand Up @@ -115,9 +138,8 @@ instance : ArrayType (DataArrayN α ι) ι α where

instance : ArrayTypeNotation (DataArrayN α ι) ι α := ⟨⟩

-- These instance might clasth with previous ones
instance : PushElem (λ n => DataArrayN α (Idx n)) α where
pushElem k val xs := ⟨xs.1.push k val, sorry_proof⟩
pushElem k val xs := ⟨xs.1.push val k, sorry_proof⟩

instance : DropElem (λ n => DataArrayN α (Idx n)) α where
dropElem k xs := ⟨xs.1.drop k, sorry_proof⟩
Expand All @@ -132,7 +154,7 @@ instance : LinearArrayType (λ n => DataArrayN α (Idx n)) α where
reserveElem_id := sorry_proof


instance {Cont ι α : Type} [ArrayType Cont ι α] [Index ι] [Inhabited α] [pd : PlainDataType α] : PlainDataType Cont where
instance {Cont ι α : Type} [ArrayType Cont ι α] [Index ι] [Inhabited α] [pd : PlainDataType α] : PlainDataType Cont where
btype := match pd.btype with
| .inl αBitType =>
-- TODO: Fixme !!!!
Expand Down
33 changes: 0 additions & 33 deletions SciLean/Data/Function.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,38 +32,5 @@ def Function.joinlD (f : ι → α) (op : α → α → α) (default : α) : α
return a


/-- Reverse derivative of `Function.foldl` w.r.t. `f` and `init`. It is implemented using `Array`.

TODO:
1. needs beter implementation but that requires refining EnumType and Index
2. add a version with DataArray
-/
def Function.foldl.revDeriv_arrayImpl {α β : Type} [Add α] [Add β] [ToString β]
(f : ι → α) (op : β → α → β) (dop : β → α → β → β×α) (init : β) : β × (β → Array α×β) := Id.run do
let n := (Index.size ι).toNat
let mut bs : Array β := .mkEmpty n
let mut b := init
for i in fullRange ι do
bs := bs.push b
b := op b (f i)
dbg_trace bs
(b,
fun db => Id.run do
let mut das : Array α := .mkEmpty n
let mut db : β := db
for i in [0:n] do
let j : ι := fromIdx ⟨n.toUSize-i.toUSize-1, sorry_proof⟩
let aj := f j
let bj := bs[n-i-1]'sorry_proof
let (db',da) := dop bj aj db
das := das.push da
db := db'
das := das.reverse
(das, db))



#eval Function.foldl.revDeriv_arrayImpl (β:=USize) (fun i : Idx 5 => i.1) (fun s x => s + x*x*x) (fun s x d => (d, 3*x*x*d)) 0 |>.snd 1

#eval Function.foldl.revDeriv_arrayImpl (β:=Float) (fun i : Idx 3 => (i.1+5).toNat.toFloat) (fun s x => s/x) (fun s x d => (d/x, -s*d/(x*x))) 1 |>.snd 1

0 comments on commit 14f0781

Please sign in to comment.