From dd18c83a782bb945ab8600539be68514eb7597da Mon Sep 17 00:00:00 2001 From: lecopivo Date: Fri, 10 Nov 2023 17:06:10 -0500 Subject: [PATCH] messing around with convolution --- SciLean/Modules/ML/Convolution.lean | 126 ++++++++++++++-------------- 1 file changed, 64 insertions(+), 62 deletions(-) diff --git a/SciLean/Modules/ML/Convolution.lean b/SciLean/Modules/ML/Convolution.lean index c20c2177..303819ea 100644 --- a/SciLean/Modules/ML/Convolution.lean +++ b/SciLean/Modules/ML/Convolution.lean @@ -3,6 +3,8 @@ import SciLean.Data.DataArray import SciLean.Data.Prod import SciLean.Core.Meta.GenerateRevCDeriv' +import SciLean.Core.FloatAsReal + namespace SciLean @@ -15,21 +17,35 @@ variable {α β κ ι} [Index α] [Index κ] [Index β] [Index ι] [PlainDataTyp #check AddAction --- This is probably broken when overflow happens -def idxShift (j : Idx m) : Idx n ≃ Idx n where - toFun i := ⟨((n + i.1 + j.1) - (m >>> 1)) % n, sorry_proof⟩ - invFun i := ⟨((n + i.1 + (m >>> 1)) - j.1) % n, sorry_proof⟩ - left_inv := sorry_proof - right_inv := sorry_proof +variable (κ) +/-- + @param α indexing set of + -/ +def convolutionLazy [Nonempty ι] + (indexAction : κ → ι≃ι) + (weights : κ → R) (bias : ι → R) (x : ι → R) + (i : ι) : R := + ∑ j, weights j * x (indexAction j i) + bias i + +variable {κ} + +#generate_revCDeriv' convolutionLazy weights bias x | i + prop_by unfold convolutionLazy; fprop + trans_by + unfold convolutionLazy + autodiff + +-------------------------------------------------------------------------------- +-- Index actions used for concrete convolutions -------------------------------- +-------------------------------------------------------------------------------- -- This is probably broken when overflow happens -def idxShift' (j : Idx' m₁ m₂) : Idx n ≃ Idx n where +def idxShift (j : Idx m) : Idx n ≃ Idx n where toFun i := ⟨((n + i.1 + j.1) - (m >>> 1)) % n, sorry_proof⟩ invFun i := ⟨((n + i.1 + (m >>> 1)) - j.1) % n, sorry_proof⟩ left_inv := sorry_proof right_inv := sorry_proof - -- This is probably broken when overflow happens def idx2Shift (j : Idx m × Idx m') : Idx n × Idx n' ≃ Idx n × Idx n' where toFun i := (idxShift j.1 i.1, idxShift j.2 i.2) @@ -38,13 +54,45 @@ def idx2Shift (j : Idx m × Idx m') : Idx n × Idx n' ≃ Idx n × Idx n' where right_inv := sorry_proof --- This is probably broken when overflow happens -def idx2Shift' (j : Idx' m₁ m₂ × Idx' m₁' m₂') : Idx n × Idx n' ≃ Idx n × Idx n' where - toFun i := (idxShift j.1 i.1, idxShift j.2 i.2) - invFun i := ((idxShift j.1).invFun i.1, (idxShift j.2).invFun i.2) - left_inv := sorry_proof - right_inv := sorry_proof +-------------------------------------------------------------------------------- +-- Concrete convolutions over arrays ------------------------------------------- +-------------------------------------------------------------------------------- + +def conv1d {m n} [Nonempty (Idx n)] + (weights : R ^ Idx m) (bias : R ^ Idx n) (x : R ^ Idx n) + : R ^ Idx n := + introElem fun ij => convolutionLazy (Idx m) idxShift (fun i => weights[i]) (fun i => bias[i]) (fun i => x[i]) ij + +#generate_revCDeriv' conv1d weights bias x + prop_by unfold conv1d; fprop + trans_by + unfold conv1d + autodiff +def conv2d {m₁ m₂ n₁ n₂} [Nonempty (Idx n₁)] [Nonempty (Idx n₂)] + (weights : R ^ (Idx m₁ × Idx m₂)) (bias : R ^ (Idx n₁ × Idx n₂)) (x : R ^ (Idx n₁ × Idx n₂)) + : R ^ (Idx n₁ × Idx n₂) := + introElem fun ij => convolutionLazy (Idx m₁ × Idx m₂) idx2Shift (fun i => weights[i]) (fun i => bias[i]) (fun i => x[i]) ij + +#generate_revCDeriv' conv2d weights bias x + prop_by unfold conv2d; fprop + trans_by + unfold conv2d + autodiff + +-- #check fun (n : Nat) ≃> n +ᵥ n + +variable (α κ) + + +def x := ⊞ (i : Idx 10) => if i == 0 then 1.0 else 0.0 +def w := ⊞ (i : Idx 3) => if i == 0 then 0.25 else if i == 1 then 0.5 else 0.25 + +instance : CoeDep (Array Float) a (Float ^ (Idx (no_index a.size.toUSize))) := sorry + +#eval conv1d ⊞[0.25,0.5,0.25] 0 ⊞[0.0,0,0,1,1,1,1,0,0,0] + +-- #eval conv1d ⊞[[0.0,0.125,0.0],[0.125,0.5,0.125],[0.0,0.125,0.0]] 0 ⊞[0.0,0,0,1,1,1,1,0,0,0] def idxSplit2 (h : n % 2 = 0) : Idx n ≃ Idx (n/2) × Idx 2 where @@ -65,55 +113,9 @@ def idx2Split2 (h : n % 2 = 0 ∧ n' % 2 = 0) : Idx n × Idx n' ≃ (Idx (n/2) left_inv := sorry_proof right_inv := sorry_proof -#check Function.invFun (fun (i : Idx 10) => idxSplit2 sorry i) - rewrite_by ftrans - - -variable [Nonempty (Idx n)] -example : Function.invFun (fun (i : Idx n) => idxSplit2 sorry i) - = - fun j => (idxSplit2 sorry).invFun j := by ftrans - - -variable (κ) - -/-- - @param α indexing set of - -/ -def convolutionLazy [Nonempty ι] - (indexAction : κ → ι≃ι) - (weights : κ → R) (bias : ι → R) (x : ι → R) - (i : ι) : R := - ∑ j, weights j * x (indexAction j i) + bias i - -variable {κ} - -#generate_revCDeriv' convolutionLazy weights bias x | i - prop_by unfold convolutionLazy; fprop - trans_by - unfold convolutionLazy - autodiff - -#eval 0 - -#check convolutionLazy.arg_weightsbiasx_i.revCDeriv - -def conv2d {m₁ m₂ n₁ n₂} [Nonempty (Idx n₁)] [Nonempty (Idx n₂)] - (weights : R ^ (Idx m₁ × Idx m₂)) (bias : R ^ (Idx n₁ × Idx n₂)) (x : R ^ (Idx n₁ × Idx n₂)) - : R ^ (Idx n₁ × Idx n₂) := - introElem fun ij => convolutionLazy (Idx m₁ × Idx m₂) idx2Shift (fun i => weights[i]) (fun i => bias[i]) (fun i => x[i]) ij - - -#generate_revCDeriv' conv2d weights bias x - prop_by unfold conv2d; fprop - trans_by - unfold conv2d - autodiff - --- #check fun (n : Nat) ≃> n +ᵥ n - -variable (α κ) +#exit +#eval ⊞ i : Idx 10 => i.toFloat def convolution (indexAction : κ → ι≃ι) (weights : DataArrayN R (α×κ))