From 574489f7cf044a415ea11f6190f81a143cae27f5 Mon Sep 17 00:00:00 2001 From: lecopivo Date: Tue, 5 Dec 2023 13:22:02 -0500 Subject: [PATCH] softMax layer working now --- SciLean/Core/Functions/Exp.lean | 2 +- SciLean/Modules/ML/SoftMax.lean | 32 ++++++++++---------------------- 2 files changed, 11 insertions(+), 23 deletions(-) diff --git a/SciLean/Core/Functions/Exp.lean b/SciLean/Core/Functions/Exp.lean index 1f4103b4..bebf3f63 100644 --- a/SciLean/Core/Functions/Exp.lean +++ b/SciLean/Core/Functions/Exp.lean @@ -44,5 +44,5 @@ by #generate_revDeriv exp x prop_by unfold HasAdjDiff; constructor; fprop; ftrans; fprop - trans_by unfold revDeriv; ftrans; ftrans + abbrev trans_by unfold revDeriv; ftrans; ftrans diff --git a/SciLean/Modules/ML/SoftMax.lean b/SciLean/Modules/ML/SoftMax.lean index 25ef61e6..3b89e6ad 100644 --- a/SciLean/Modules/ML/SoftMax.lean +++ b/SciLean/Modules/ML/SoftMax.lean @@ -1,4 +1,6 @@ import SciLean.Core +import SciLean.Core.Functions.Exp +import SciLean.Core.Meta.GenerateRevDeriv import SciLean.Data.DataArray import SciLean.Data.Prod import Mathlib @@ -12,28 +14,14 @@ set_default_scalar R def softMax {ι} [Index ι] (r : R) (x : R^ι) : R^ι := - let wx := Function.repeatIdx (init:=((0:R),x)) - fun (i : ι) (w,x) => - let xi := x[i] - let xi' := Scalar.exp (r*xi) - (w + xi', setElem x i (xi * xi')) - -- have : ∀ x :R, x ≠ 0 := by sorry_proof - wx.2 + let x := ArrayType.map (fun xi => Scalar.exp (r*xi)) x + let w := ∑ i, x[i] + (1/w) • x -@[fprop] -theorem Scalar.exp.arg_x.HasAdjDiff_rule - {R K} [Scalar R K] {W} [SemiInnerProductSpace K W] - (x : W → K) (hx : HasAdjDiff K x) - : HasAdjDiff K (fun w => Scalar.exp (x w)) := by sorry_proof - --- set_option trace.Meta.Tactic.fprop.discharge true --- set_option trace.Meta.Tactic.fprop.step true --- set_option trace.Meta.Tactic.fprop.apply true --- set_option trace.Meta.Tactic.fprop.rewrite true --- set_option trace.Meta.Tactic.fprop.unify true -set_option trace.Meta.Tactic.ftrans.step true set_option trace.Meta.Tactic.simp.discharge true --- #generate_revDeriv softMax r x --- prop_by unfold softMax; sorry_proof --fprop --- trans_by unfold softMax; ftrans +set_option trace.Meta.Tactic.simp.unify true +#generate_revDeriv softMax x + prop_by unfold softMax; fprop + trans_by unfold softMax; ftrans +