Skip to content

Commit

Permalink
messing around with convolution
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 10, 2023
1 parent ba8615a commit dd18c83
Showing 1 changed file with 64 additions and 62 deletions.
126 changes: 64 additions & 62 deletions SciLean/Modules/ML/Convolution.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import SciLean.Data.DataArray
import SciLean.Data.Prod
import SciLean.Core.Meta.GenerateRevCDeriv'

import SciLean.Core.FloatAsReal


namespace SciLean

Expand All @@ -15,21 +17,35 @@ variable {α β κ ι} [Index α] [Index κ] [Index β] [Index ι] [PlainDataTyp

#check AddAction

-- This is probably broken when overflow happens
def idxShift (j : Idx m) : Idx n ≃ Idx n where
toFun i := ⟨((n + i.1 + j.1) - (m >>> 1)) % n, sorry_proof⟩
invFun i := ⟨((n + i.1 + (m >>> 1)) - j.1) % n, sorry_proof⟩
left_inv := sorry_proof
right_inv := sorry_proof
variable (κ)
/--
@param α indexing set of
-/
def convolutionLazy [Nonempty ι]
(indexAction : κ → ι≃ι)
(weights : κ → R) (bias : ι → R) (x : ι → R)
(i : ι) : R :=
∑ j, weights j * x (indexAction j i) + bias i

variable {κ}

#generate_revCDeriv' convolutionLazy weights bias x | i
prop_by unfold convolutionLazy; fprop
trans_by
unfold convolutionLazy
autodiff

--------------------------------------------------------------------------------
-- Index actions used for concrete convolutions --------------------------------
--------------------------------------------------------------------------------

-- This is probably broken when overflow happens
def idxShift' (j : Idx' m₁ m₂) : Idx n ≃ Idx n where
def idxShift (j : Idx m) : Idx n ≃ Idx n where
toFun i := ⟨((n + i.1 + j.1) - (m >>> 1)) % n, sorry_proof⟩
invFun i := ⟨((n + i.1 + (m >>> 1)) - j.1) % n, sorry_proof⟩
left_inv := sorry_proof
right_inv := sorry_proof


-- This is probably broken when overflow happens
def idx2Shift (j : Idx m × Idx m') : Idx n × Idx n' ≃ Idx n × Idx n' where
toFun i := (idxShift j.1 i.1, idxShift j.2 i.2)
Expand All @@ -38,13 +54,45 @@ def idx2Shift (j : Idx m × Idx m') : Idx n × Idx n' ≃ Idx n × Idx n' where
right_inv := sorry_proof


-- This is probably broken when overflow happens
def idx2Shift' (j : Idx' m₁ m₂ × Idx' m₁' m₂') : Idx n × Idx n' ≃ Idx n × Idx n' where
toFun i := (idxShift j.1 i.1, idxShift j.2 i.2)
invFun i := ((idxShift j.1).invFun i.1, (idxShift j.2).invFun i.2)
left_inv := sorry_proof
right_inv := sorry_proof
--------------------------------------------------------------------------------
-- Concrete convolutions over arrays -------------------------------------------
--------------------------------------------------------------------------------

def conv1d {m n} [Nonempty (Idx n)]
(weights : R ^ Idx m) (bias : R ^ Idx n) (x : R ^ Idx n)
: R ^ Idx n :=
introElem fun ij => convolutionLazy (Idx m) idxShift (fun i => weights[i]) (fun i => bias[i]) (fun i => x[i]) ij

#generate_revCDeriv' conv1d weights bias x
prop_by unfold conv1d; fprop
trans_by
unfold conv1d
autodiff

def conv2d {m₁ m₂ n₁ n₂} [Nonempty (Idx n₁)] [Nonempty (Idx n₂)]
(weights : R ^ (Idx m₁ × Idx m₂)) (bias : R ^ (Idx n₁ × Idx n₂)) (x : R ^ (Idx n₁ × Idx n₂))
: R ^ (Idx n₁ × Idx n₂) :=
introElem fun ij => convolutionLazy (Idx m₁ × Idx m₂) idx2Shift (fun i => weights[i]) (fun i => bias[i]) (fun i => x[i]) ij

#generate_revCDeriv' conv2d weights bias x
prop_by unfold conv2d; fprop
trans_by
unfold conv2d
autodiff

-- #check fun (n : Nat) ≃> n +ᵥ n

variable (α κ)


def x := ⊞ (i : Idx 10) => if i == 0 then 1.0 else 0.0
def w := ⊞ (i : Idx 3) => if i == 0 then 0.25 else if i == 1 then 0.5 else 0.25

instance : CoeDep (Array Float) a (Float ^ (Idx (no_index a.size.toUSize))) := sorry

#eval conv1d ⊞[0.25,0.5,0.25] 0 ⊞[0.0,0,0,1,1,1,1,0,0,0]

-- #eval conv1d ⊞[[0.0,0.125,0.0],[0.125,0.5,0.125],[0.0,0.125,0.0]] 0 ⊞[0.0,0,0,1,1,1,1,0,0,0]


def idxSplit2 (h : n % 2 = 0) : Idx n ≃ Idx (n/2) × Idx 2 where
Expand All @@ -65,55 +113,9 @@ def idx2Split2 (h : n % 2 = 0 ∧ n' % 2 = 0) : Idx n × Idx n' ≃ (Idx (n/2)
left_inv := sorry_proof
right_inv := sorry_proof

#check Function.invFun (fun (i : Idx 10) => idxSplit2 sorry i)
rewrite_by ftrans


variable [Nonempty (Idx n)]
example : Function.invFun (fun (i : Idx n) => idxSplit2 sorry i)
=
fun j => (idxSplit2 sorry).invFun j := by ftrans


variable (κ)

/--
@param α indexing set of
-/
def convolutionLazy [Nonempty ι]
(indexAction : κ → ι≃ι)
(weights : κ → R) (bias : ι → R) (x : ι → R)
(i : ι) : R :=
∑ j, weights j * x (indexAction j i) + bias i

variable {κ}


#generate_revCDeriv' convolutionLazy weights bias x | i
prop_by unfold convolutionLazy; fprop
trans_by
unfold convolutionLazy
autodiff

#eval 0

#check convolutionLazy.arg_weightsbiasx_i.revCDeriv

def conv2d {m₁ m₂ n₁ n₂} [Nonempty (Idx n₁)] [Nonempty (Idx n₂)]
(weights : R ^ (Idx m₁ × Idx m₂)) (bias : R ^ (Idx n₁ × Idx n₂)) (x : R ^ (Idx n₁ × Idx n₂))
: R ^ (Idx n₁ × Idx n₂) :=
introElem fun ij => convolutionLazy (Idx m₁ × Idx m₂) idx2Shift (fun i => weights[i]) (fun i => bias[i]) (fun i => x[i]) ij


#generate_revCDeriv' conv2d weights bias x
prop_by unfold conv2d; fprop
trans_by
unfold conv2d
autodiff

-- #check fun (n : Nat) ≃> n +ᵥ n

variable (α κ)
#exit
#eval ⊞ i : Idx 10 => i.toFloat

def convolution
(indexAction : κ → ι≃ι) (weights : DataArrayN R (α×κ))
Expand Down

0 comments on commit dd18c83

Please sign in to comment.