diff --git a/SciLean/Modules/ML/Doodle.lean b/SciLean/Modules/ML/Doodle.lean index c0a11f1c..f35bb993 100644 --- a/SciLean/Modules/ML/Doodle.lean +++ b/SciLean/Modules/ML/Doodle.lean @@ -1,65 +1,67 @@ import SciLean import SciLean.Core.Meta.GenerateRevCDeriv' +import SciLean.Modules.ML.DenseLayer -namespace SciLean +open SciLean ML + +-- #profile_this_file variable - {K : Type} [RealScalar K] - {W : Type} [Vec K W] + {R : Type} [RealScalar R] + [PlainDataType R] + +set_default_scalar R -set_default_scalar K -variable {α β κ ι : Type} [Index.{0,0,0} α] [Index.{0,0,0} β] [Index.{0,0,0} κ] [Index.{0,0,0} ι] [PlainDataType K] [PlainDataType R] +variable (x : R^(Idx 20)) -variable (κ) -def denseLazy (weights : κ → ι → K) (bias : κ → K) (x : ι → K) (j : κ) : K := - ∑ i, weights j i * x i + bias j -variable {κ} +-- set_option profiler true in -#generate_revCDeriv' denseLazy weights bias x | j - prop_by unfold denseLazy; fprop - trans_by - unfold denseLazy +#check (revCDeriv R fun (w,w',w'',b,b',b'') => + x |> dense (Idx 5) w' b' + |> dense (Idx 10) w b + |> dense (Idx 20) w'' b'') + rewrite_by autodiff -variable (κ) -def dense (weights : DataArrayN K (κ×ι)) (bias : K^κ) (x : K^ι) : K^κ := - -- ⊞ j => ∑ i, weights[(j,i)] * x[i] + bias[j] - ⊞ j => denseLazy κ (fun j i => weights[(j,i)]) (fun j => bias[j]) (fun i => x[i]) j -variable {κ} -#generate_revCDeriv' dense weights bias x - prop_by unfold dense; fprop - trans_by unfold dense; autodiff +#check (revCDeriv R fun (w,w',w'',w''',w'''',b,b',b'',b''',b'''') => + x |> dense (Idx 5) w' b' + |> dense (Idx 10) w b + |> dense (Idx 20) w'' b'' + |> dense (Idx 20) w''' b''' + |> dense (Idx 30) w'''' b'''') + rewrite_by + ftrans only + ftrans + -- lsimp (config := {zeta:=false, singlePass:=true}) only + + + + + + -#eval 0 -#exit -set_option synthInstance.maxSize 20000 -#check denseLazy.arg_weightsbiasx_j.revCDeriv -#check dense.arg_weightsbiasx.revCDeriv -variable (x : K^(Idx 20)) -set_option profiler true in -#check (revCDeriv K fun (w,w',w'',w''',w'''',b,b',b'',b''',b'''') => - x |> dense (Idx 5) w' b' - |> dense (Idx 10) w b - |> dense (Idx 20) w'' b'' - |> dense (Idx 20) w''' b''' - |> dense (Idx 30) w'''' b'''') - rewrite_by - ftrans only - ftrans only - ftrans only - -- lsimp (config := {zeta:=false, singlePass:=true}) only + + + + + + + + +#exit + example (a : Nat) (b : Nat) : (b,a,a,0) + (b,a,a,a,a,a,a,a,a,b,b) + 0 = (b+b,a+a,a+a,a,a,a,a,a,a,b,b) := by simp