Commit 4063def

gelu activation function

lecopivo committed Dec 5, 2023
1 parent fd7935c commit 4063def
Showing 7 changed files with 116 additions and 5 deletions.
9 changes: 9 additions & 0 deletions SciLean/Core/FloatAsReal.lean
@@ -162,6 +162,15 @@ instance : RealScalar Float where
tan x := x.tan
tan_def := sorry_proof

asin x := x.asin
asin_def := sorry_proof

acos x := x.acos
acos_def := sorry_proof

atan x := x.atan
atan_def := sorry_proof

exp x := x.exp
exp_def := sorry_proof

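Not part of the commit — a quick sanity check, assuming the `RealScalar Float` instance stays computable: the new `acos` field makes `RealScalar.pi` (added below in `SciLean/Core/Objects/Scalar.lean` as `acos (-1)`) evaluable on `Float`.

import SciLean.Core.FloatAsReal

open SciLean

-- `RealScalar.pi` unfolds to `acos (-1)`, which on `Float` calls `Float.acos`,
-- so this should print a value close to 3.141592653589793.
#eval (RealScalar.pi : Float)
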
48 changes: 48 additions & 0 deletions SciLean/Core/Functions/Exp.lean
@@ -0,0 +1,48 @@
import SciLean.Core.FunctionTransformations
import SciLean.Core.Meta.GenerateRevDeriv

open ComplexConjugate

namespace SciLean.Scalar

variable
{R C} [Scalar R C]
{W} [Vec C W]
{U} [SemiInnerProductSpace C U]


--------------------------------------------------------------------------------
-- Exp -------------------------------------------------------------------------
--------------------------------------------------------------------------------

@[fprop]
theorem exp.arg_x.IsDifferentiable_rule
(x : W → C) (hx : IsDifferentiable C x)
: IsDifferentiable C fun w => exp (x w) := sorry_proof

@[ftrans]
theorem exp.arg_x.cderiv_rule
(x : W → C) (hx : IsDifferentiable C x)
: cderiv C (fun w => exp (x w))
=
fun w dw =>
let xdx := fwdCDeriv C x w dw
let e := exp xdx.1
xdx.2 * e := sorry_proof

@[ftrans]
theorem exp.arg_x.fwdCDeriv_rule
(x : W → C) (hx : IsDifferentiable C x)
: fwdCDeriv C (fun w => exp (x w))
=
fun w dw =>
let xdx := fwdCDeriv C x w dw
let e := exp xdx.1
(e, xdx.2 * e) :=
by
unfold fwdCDeriv; ftrans; rfl

#generate_revDeriv exp x
prop_by unfold HasAdjDiff; constructor; fprop; ftrans; fprop
trans_by unfold revDeriv; ftrans; ftrans
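
Not part of the commit — a minimal usage sketch of the rules above, specialized to `x := id` over `Float`; the stated normal form and the closing proof script are assumptions (they mirror `exp.arg_x.fwdCDeriv_rule` and its proof), not something the commit itself verifies.

import SciLean.Core.Functions.Exp
import SciLean.Core.FloatAsReal

open SciLean Scalar

-- At `x := id`, the forward derivative should pair the value `exp w`
-- with the tangent `dw * exp w`.
example :
    fwdCDeriv Float (fun w : Float => exp w)
    =
    fun w dw =>
      let e := exp w
      (e, dw * e) := by
  unfold fwdCDeriv; ftrans; rfl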

13 changes: 11 additions & 2 deletions SciLean/Core/Functions/Trigonometric.lean
@@ -45,7 +45,7 @@ by

#generate_revDeriv sin x
prop_by unfold HasAdjDiff; constructor; fprop; ftrans; fprop
trans_by unfold revDeriv; ftrans; ftrans
abbrev trans_by unfold revDeriv; ftrans; ftrans


--------------------------------------------------------------------------------
@@ -120,4 +120,13 @@ by

#generate_revDeriv tanh x
prop_by unfold HasAdjDiff; constructor; fprop; ftrans; fprop
trans_by unfold revDeriv; ftrans; ftrans
abbrev trans_by
unfold revDeriv; ftrans; ftrans
enter [x]
-- We just need to replace `tanh x` with `t`; there should be a tactic for this,
-- or common subexpression elimination should handle it.
equals (let t := tanh x;
let dt := 1 - t ^ 2;
(t, fun y => (starRingEnd K) dt * y)) => rfl
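
The `equals` step above rewrites the generated reverse derivative so that `tanh x` is evaluated only once: since d/dx tanh x = 1 - tanh x ^ 2, binding `t := tanh x` lets the primal value and the cotangent factor `dt := 1 - t ^ 2` share that single evaluation, and `rfl` closes the `equals` goal because the two forms agree definitionally.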


21 changes: 21 additions & 0 deletions SciLean/Core/Objects/Scalar.lean
@@ -4,6 +4,8 @@ import Mathlib.Data.Complex.Exponential
import Mathlib.Analysis.Complex.Basic
import Mathlib.Analysis.SpecialFunctions.Pow.Complex
import Mathlib.Analysis.SpecialFunctions.Pow.Real
import Mathlib.Analysis.SpecialFunctions.Trigonometric.Inverse
import Mathlib.Analysis.SpecialFunctions.Trigonometric.Arctan

import SciLean.Util.SorryProof
import SciLean.Tactic.FTrans.Init
@@ -91,6 +93,16 @@ See `Scalar` for motivation for this class.
class RealScalar (R : semiOutParam (Type _)) extends Scalar R R where
is_real : ∀ x : R, im x = 0

asin (x : R) : R
asin_def : ∀ x, toReal (asin x) = Real.arcsin (toReal x)

acos (x : R) : R
acos_def : ∀ x, toReal (acos x) = Real.arccos (toReal x)

atan (x : R) : R
atan_def : ∀ x, toReal (atan x) = Real.arctan (toReal x)
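
-- `acos (-1)` recovers π for any `RealScalar`, since `Real.arccos (-1) = π`.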

def RealScalar.pi [RealScalar R] : R := RealScalar.acos (-1)

instance {R K} [Scalar R K] : HPow K K K := ⟨fun x y => Scalar.pow x y⟩

@@ -166,6 +178,15 @@ noncomputable instance : RealScalar ℝ where
tan x := x.tan
tan_def := by intros; simp[Real.tan]; sorry_proof

asin x := x.arcsin
asin_def := by intros; simp

acos x := x.arccos
acos_def := by intros; simp

atan x := x.arctan
atan_def := by intros; simp

exp x := x.exp
exp_def := by intros; simp[Real.exp]; sorry_proof

20 changes: 20 additions & 0 deletions SciLean/Modules/ML/Activation.lean
@@ -0,0 +1,20 @@
import SciLean.Core
import SciLean.Core.Functions.Trigonometric
import SciLean.Core.FloatAsReal
import SciLean.Core.Meta.GenerateRevDeriv

namespace SciLean.ML

variable {R : Type} [RealScalar R]

open Scalar RealScalar

def gelu (x : R) : R :=
let c := sqrt (2/pi)
x * (1 + tanh (c * x * (1 + 0.044715 * x^2)))

#generate_revDeriv gelu x
prop_by unfold gelu; fprop
trans_by
unfold gelu
ftrans
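
For context, not part of the commit: `gelu` above is the tanh-based approximation of GELU, using `c * x * (1 + 0.044715 * x^2) = c * (x + 0.044715 * x^3)`; the commonly cited Hendrycks–Gimpel form also carries a leading factor of 0.5. A hypothetical reference definition for comparison — the name `geluRef` and this standalone setup are assumptions, not part of the commit:

import SciLean.Core
import SciLean.Core.Functions.Trigonometric
import SciLean.Core.FloatAsReal

namespace SciLean.ML

variable {R : Type} [RealScalar R]

open Scalar RealScalar

-- Hypothetical reference form: GELU(x) ≈ 0.5 * x * (1 + tanh (√(2/π) * (x + 0.044715 * x³))).
def geluRef (x : R) : R :=
  let c := sqrt (2/pi)
  0.5 * x * (1 + tanh (c * (x + 0.044715 * x^3)))
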
5 changes: 4 additions & 1 deletion SciLean/Tactic/FTrans/Simp.lean
@@ -6,7 +6,10 @@ import Mathlib.Algebra.SMulWithZero
namespace SciLean

-- basic algebraic operations
attribute [ftrans_simp] add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_zero zero_mul zero_smul smul_zero smul_eq_mul smul_neg eq_self iff_self mul_one one_mul one_smul
attribute [ftrans_simp] add_zero zero_add sub_zero zero_sub sub_self neg_zero mul_zero zero_mul zero_smul smul_zero smul_eq_mul smul_neg eq_self iff_self mul_one one_mul one_smul tsub_zero pow_one

-- simp theorems for `Nat`
attribute [ftrans_simp] Nat.succ_sub_succ_eq_sub

-- simp theorems for `Prod`
attribute [ftrans_simp] Prod.mk.eta Prod.fst_zero Prod.snd_zero Prod.mk_add_mk Prod.mk_mul_mk Prod.smul_mk Prod.mk_sub_mk Prod.neg_mk Prod.vadd_mk
5 changes: 3 additions & 2 deletions examples/MNISTClassifier/Model.lean
@@ -4,6 +4,7 @@ import SciLean.Core.FloatAsReal
import SciLean.Modules.ML.Dense
import SciLean.Modules.ML.Convolution
import SciLean.Modules.ML.Pool
import SciLean.Modules.ML.Activation

open SciLean
open IO FS System
@@ -16,10 +17,10 @@ open ML ArrayType in
def model (w x) :=
(fun ((w₁,b₁),(w₂,b₂),(w₃,b₃)) (x : Float^[1,28,28]) =>
x |> conv2d 32 1 w₁ b₁
|> map (fun x => x^2)
|> map gelu
|> avgPool
|> dense 100 w₂ b₂
|> map (fun x => x^2)
|> map gelu
|> dense 10 w₃ b₃) w x
-- |> softMax

