
Commit

Merge pull request #196 from simudt/feat/smooth-relu-activation-experimental-layer
kyegomez authored Apr 12, 2024
2 parents 0dd3c09 + 161e48b commit 516c82f
Showing 3 changed files with 35 additions and 0 deletions.
4 changes: 4 additions & 0 deletions zeta/experimental/triton/activations/__init__.py
@@ -7,6 +7,9 @@
from zeta.experimental.triton.activations.activations import (
leaky_relu_activation,
)
from zeta.experimental.triton.activations.activations import (
smooth_relu_activation,
)
from zeta.experimental.triton.activations.activations import softsign_activation
from zeta.experimental.triton.activations.activations import softplus_activation
from zeta.experimental.triton.activations.activations import sigmoid_activation
@@ -27,6 +30,7 @@
"relu_activation",
"relu6_activation",
"leaky_relu_activation",
"smooth_relu_activation",
"softsign_activation",
"softplus_activation",
"sigmoid_activation",
10 changes: 10 additions & 0 deletions zeta/experimental/triton/activations/activations.py
@@ -50,6 +50,16 @@ def leaky_relu_activation(x: torch.Tensor, alpha: float = 0.2):
)


def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0):
# Make input tensor contiguous if needed
if not x.is_contiguous():
x = x.contiguous()

return apply_activation(
x, Functions.smooth_relu_activation_kernel, beta=beta
)


def softsign_activation(x: torch.Tensor):
return apply_activation(x, Functions.softsign_activation_kernel)

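Not part of the commit, but a minimal usage sketch for the new wrapper may help. It assumes a CUDA device is available (the Triton kernels in this module run on GPU tensors), and the tensor shape and `beta` value are illustrative only.

```python
import torch

from zeta.experimental.triton.activations import smooth_relu_activation

# Illustrative input; the Triton kernel is assumed to expect a GPU tensor.
x = torch.randn(1024, device="cuda")

# beta sets the half-width of the quadratic transition region around zero.
y = smooth_relu_activation(x, beta=2.0)
print(y.shape)  # torch.Size([1024])
```

The default `beta=2.0` matches the wrapper's signature above; smaller values shrink the transition region, so the activation approaches a plain ReLU.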
21 changes: 21 additions & 0 deletions zeta/experimental/triton/activations/functions.py
@@ -93,6 +93,27 @@ def leaky_relu_activation_kernel(
output = tl.maximum(x, alpha * x)
tl.store(output_ptr + offsets, output, mask=mask)

@staticmethod
@triton.jit
def smooth_relu_activation_kernel(
x_ptr, output_ptr, n_elements, beta, BLOCK_SIZE: tl.constexpr
):
"""
Convolution of ReLU with a box, transition region widens, the loss surface becomes smoother
"""
idx = tl.program_id(0)
block_st = idx * BLOCK_SIZE
offsets = block_st + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)

output = tl.where(x >= beta, x, 0.0)
output = tl.where(
    tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), output
)

tl.store(output_ptr + offsets, output, mask=mask)

@staticmethod
@triton.jit
def softsign_activation_kernel(
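For reference, a plain-PyTorch sketch of the piecewise function the kernel computes: zero below -beta, a quadratic blend on [-beta, beta], identity above beta. This is not part of the commit; it is a hedged cross-check that could back a unit test, and the function name is hypothetical.

```python
import torch


def smooth_relu_reference(x: torch.Tensor, beta: float = 2.0) -> torch.Tensor:
    # Piecewise definition mirrored from the Triton kernel:
    #   0                         for x <= -beta
    #   (x + beta)^2 / (4 * beta) for |x| <= beta
    #   x                         for x >= beta
    out = torch.where(x >= beta, x, torch.zeros_like(x))
    quad = (x + beta) ** 2 / (4.0 * beta)
    return torch.where(x.abs() <= beta, quad, out)
```

Comparing `smooth_relu_activation(x)` against `smooth_relu_reference(x)` with `torch.allclose` on a small CUDA tensor is one way to sanity-check the kernel.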
