Skip to content

Commit

Permalink
clean up of ML doodle file
Browse files Browse the repository at this point in the history
  • Loading branch information
lecopivo committed Nov 23, 2023
1 parent 6f77b13 commit 65fa816
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions SciLean/Modules/ML/Doodle.lean
Original file line number Diff line number Diff line change
@@ -1,65 +1,67 @@
import SciLean
import SciLean.Core.Meta.GenerateRevCDeriv'
import SciLean.Modules.ML.DenseLayer

namespace SciLean
open SciLean ML

-- #profile_this_file

variable
{K : Type} [RealScalar K]
{W : Type} [Vec K W]
{R : Type} [RealScalar R]
[PlainDataType R]

set_default_scalar R

set_default_scalar K

variable {α β κ ι : Type} [Index.{0,0,0} α] [Index.{0,0,0} β] [Index.{0,0,0} κ] [Index.{0,0,0} ι] [PlainDataType K] [PlainDataType R]
variable (x : R^(Idx 20))

variable (κ)
def denseLazy (weights : κ → ι → K) (bias : κ → K) (x : ι → K) (j : κ) : K :=
∑ i, weights j i * x i + bias j
variable {κ}
-- set_option profiler true in

#generate_revCDeriv' denseLazy weights bias x | j
prop_by unfold denseLazy; fprop
trans_by
unfold denseLazy
#check (revCDeriv R fun (w,w',w'',b,b',b'') =>
x |> dense (Idx 5) w' b'
|> dense (Idx 10) w b
|> dense (Idx 20) w'' b'')
rewrite_by
autodiff

variable (κ)
def dense (weights : DataArrayN K (κ×ι)) (bias : K^κ) (x : K^ι) : K^κ :=
-- ⊞ j => ∑ i, weights[(j,i)] * x[i] + bias[j]
⊞ j => denseLazy κ (fun j i => weights[(j,i)]) (fun j => bias[j]) (fun i => x[i]) j
variable {κ}

#generate_revCDeriv' dense weights bias x
prop_by unfold dense; fprop
trans_by unfold dense; autodiff
#check (revCDeriv R fun (w,w',w'',w''',w'''',b,b',b'',b''',b'''') =>
x |> dense (Idx 5) w' b'
|> dense (Idx 10) w b
|> dense (Idx 20) w'' b''
|> dense (Idx 20) w''' b'''
|> dense (Idx 30) w'''' b'''')
rewrite_by
ftrans only
ftrans
-- lsimp (config := {zeta:=false, singlePass:=true}) only









#eval 0

#exit
set_option synthInstance.maxSize 20000


#check denseLazy.arg_weightsbiasx_j.revCDeriv
#check dense.arg_weightsbiasx.revCDeriv

variable (x : K^(Idx 20))

set_option profiler true in

#check (revCDeriv K fun (w,w',w'',w''',w'''',b,b',b'',b''',b'''') =>
x |> dense (Idx 5) w' b'
|> dense (Idx 10) w b
|> dense (Idx 20) w'' b''
|> dense (Idx 20) w''' b'''
|> dense (Idx 30) w'''' b'''')
rewrite_by
ftrans only
ftrans only
ftrans only
-- lsimp (config := {zeta:=false, singlePass:=true}) only










#exit

example (a : Nat) (b : Nat)
: (b,a,a,0) + (b,a,a,a,a,a,a,a,a,b,b) + 0 = (b+b,a+a,a+a,a,a,a,a,a,a,b,b) := by simp

Expand Down

0 comments on commit 65fa816

Please sign in to comment.