Skip to content

Commit

Permalink
use tanh approximation for gelu
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Nov 28, 2023
1 parent e61c1d9 commit 91039e0
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from dataclasses import dataclass
from functools import partial
import math
from typing import Tuple

Expand Down Expand Up @@ -213,7 +214,8 @@ def forward(self, inputs, padding_mask):
if self.config.activation_function_name == 'swish':
activation_fn = F.silu
elif self.config.activation_function_name == 'gelu':
activation_fn = F.gelu
# Use the tanh approximation of GELU, which is the default in JAX (PyTorch's F.gelu defaults to the exact erf form).
activation_fn = partial(F.gelu, approximate='tanh')
else:
raise ValueError(
'Only "swish" and "gelu" are supported '
Expand Down

0 comments on commit 91039e0

Please sign in to comment.