
Commit

softMax layer working now
lecopivo committed Dec 5, 2023
1 parent f0b8ba1 commit 574489f
Showing 2 changed files with 11 additions and 23 deletions.
SciLean/Core/Functions/Exp.lean: 2 changes (1 addition, 1 deletion)
@@ -44,5 +44,5 @@ by

#generate_revDeriv exp x
  prop_by unfold HasAdjDiff; constructor; fprop; ftrans; fprop
  trans_by unfold revDeriv; ftrans; ftrans
  abbrev trans_by unfold revDeriv; ftrans; ftrans
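
(Aside, a hand-written sketch rather than anything this commit states: the rule requested here rests on the identity d/dx exp x = exp x, so over a real scalar the generated reverse derivative should reduce to a pair of the value with a pullback that rescales the incoming cotangent by that same value, roughly

  revDeriv R (fun x => Scalar.exp x) x  =  (Scalar.exp x, fun dy => Scalar.exp x * dy)

with the exact statement and identifier names left to #generate_revDeriv.)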

SciLean/Modules/ML/SoftMax.lean: 32 changes (10 additions, 22 deletions)
@@ -1,4 +1,6 @@
import SciLean.Core
import SciLean.Core.Functions.Exp
import SciLean.Core.Meta.GenerateRevDeriv
import SciLean.Data.DataArray
import SciLean.Data.Prod
import Mathlib
@@ -12,28 +14,14 @@ set_default_scalar R

def softMax
  {ι} [Index ι] (r : R) (x : R^ι) : R^ι :=
  let wx := Function.repeatIdx (init:=((0:R),x))
    fun (i : ι) (w,x) =>
      let xi := x[i]
      let xi' := Scalar.exp (r*xi)
      (w + xi', setElem x i (xi * xi'))
  -- have : ∀ x :R, x ≠ 0 := by sorry_proof
  wx.2
  let x := ArrayType.map (fun xi => Scalar.exp (r*xi)) x
  let w := ∑ i, x[i]
  (1/w) • x
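
(Reading aid, my own summary rather than part of the diff: the rewritten softMax exponentiates every entry and then normalizes by the total,

  softMax r x i = Scalar.exp (r * x[i]) / ∑ j, Scalar.exp (r * x[j])

so r acts as a sharpening factor; larger r pushes the output closer to a one-hot vector.)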

@[fprop]
theorem Scalar.exp.arg_x.HasAdjDiff_rule
  {R K} [Scalar R K] {W} [SemiInnerProductSpace K W]
  (x : W → K) (hx : HasAdjDiff K x)
  : HasAdjDiff K (fun w => Scalar.exp (x w)) := by sorry_proof
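
(Hedged usage sketch, mine rather than the commit's: registering this HasAdjDiff rule under @[fprop] is what lets fprop discharge the differentiability side goals that arise when Scalar.exp sits inside a larger function, e.g. a goal of the shape

  example (r : R) : HasAdjDiff R (fun x : R => Scalar.exp (r * x)) := by fprop

should close by fprop, assuming the usual rules for multiplication are already registered; this is the kind of goal the prop_by step below produces after unfolding softMax.)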


-- set_option trace.Meta.Tactic.fprop.discharge true
-- set_option trace.Meta.Tactic.fprop.step true
-- set_option trace.Meta.Tactic.fprop.apply true
-- set_option trace.Meta.Tactic.fprop.rewrite true
-- set_option trace.Meta.Tactic.fprop.unify true
set_option trace.Meta.Tactic.ftrans.step true
set_option trace.Meta.Tactic.simp.discharge true
-- #generate_revDeriv softMax r x
-- prop_by unfold softMax; sorry_proof --fprop
-- trans_by unfold softMax; ftrans
set_option trace.Meta.Tactic.simp.unify true
#generate_revDeriv softMax x
  prop_by unfold softMax; fprop
  trans_by unfold softMax; ftrans
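
(For reference, a hand calculation rather than a claim about the exact term ftrans produces: writing s = softMax r x, so s[i] = Scalar.exp (r * x[i]) / ∑ j, Scalar.exp (r * x[j]), the Jacobian is ∂s[i]/∂x[k] = r * (s[i] * δ i k - s[i] * s[k]), and the pullback of a cotangent dy that the generated reverse derivative should agree with is

  (pullback dy)[k] = r * s[k] * (dy[k] - ∑ i, s[i] * dy[i])

up to how SciLean chooses to share the intermediate sums.)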
