From 161e48b85eb8793310040c8355e5f9b79237a174 Mon Sep 17 00:00:00 2001
From: dogukan uraz tuna <156364766+simudt@users.noreply.github.com>
Date: Thu, 11 Apr 2024 22:25:56 +0300
Subject: [PATCH] init func & activation

---
 .../triton/activations/__init__.py    |  4 ++++
 .../triton/activations/activations.py | 10 +++++++++
 .../triton/activations/functions.py   | 21 +++++++++++++++++++
 3 files changed, 35 insertions(+)

diff --git a/zeta/experimental/triton/activations/__init__.py b/zeta/experimental/triton/activations/__init__.py
index 6ec4e4d0..e49bb32d 100644
--- a/zeta/experimental/triton/activations/__init__.py
+++ b/zeta/experimental/triton/activations/__init__.py
@@ -7,6 +7,9 @@
 from zeta.experimental.triton.activations.activations import (
     leaky_relu_activation,
 )
+from zeta.experimental.triton.activations.activations import (
+    smooth_relu_activation,
+)
 from zeta.experimental.triton.activations.activations import softsign_activation
 from zeta.experimental.triton.activations.activations import softplus_activation
 from zeta.experimental.triton.activations.activations import sigmoid_activation
@@ -27,6 +30,7 @@
     "relu_activation",
     "relu6_activation",
     "leaky_relu_activation",
+    "smooth_relu_activation",
     "softsign_activation",
     "softplus_activation",
     "sigmoid_activation",
diff --git a/zeta/experimental/triton/activations/activations.py b/zeta/experimental/triton/activations/activations.py
index 4351696b..fbfa11d5 100644
--- a/zeta/experimental/triton/activations/activations.py
+++ b/zeta/experimental/triton/activations/activations.py
@@ -50,6 +50,16 @@ def leaky_relu_activation(x: torch.Tensor, alpha: float = 0.2):
     )
 
 
+def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0):
+    # The Triton kernel expects a contiguous buffer; copy only if needed
+    if not x.is_contiguous():
+        x = x.contiguous()
+
+    return apply_activation(
+        x, Functions.smooth_relu_activation_kernel, beta=beta
+    )
+
+
 def softsign_activation(x: torch.Tensor):
     return apply_activation(x, Functions.softsign_activation_kernel)
 
diff --git a/zeta/experimental/triton/activations/functions.py b/zeta/experimental/triton/activations/functions.py
index 2e0621e1..9fadc5d6 100644
--- a/zeta/experimental/triton/activations/functions.py
+++ b/zeta/experimental/triton/activations/functions.py
@@ -93,6 +93,27 @@ def leaky_relu_activation_kernel(
         output = tl.maximum(x, alpha * x)
         tl.store(output_ptr + offsets, output, mask=mask)
 
+    @staticmethod
+    @triton.jit
+    def smooth_relu_activation_kernel(
+        x_ptr, output_ptr, n_elements, beta, BLOCK_SIZE: tl.constexpr
+    ):
+        """
+        Smooth ReLU (SmeLU): ReLU convolved with a box of width 2 * beta; a larger beta widens the quadratic transition region and smooths the loss surface
+        """
+        idx = tl.program_id(0)
+        block_st = idx * BLOCK_SIZE
+        offsets = block_st + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(x_ptr + offsets, mask=mask)
+
+        output = tl.where(x >= beta, x, 0.0)
+        output = tl.where(
+            tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), output
+        )
+
+        tl.store(output_ptr + offsets, output, mask=mask)
+
     @staticmethod
     @triton.jit
     def softsign_activation_kernel(
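
A minimal usage sketch of the new activation (not part of the patch). It assumes the zeta package is importable, a CUDA device is available, and that apply_activation launches the kernel over the whole tensor as the other activations do; the smooth_relu_reference helper below is a hypothetical pure-PyTorch SmeLU used only to sanity-check the kernel output.

import torch

from zeta.experimental.triton.activations import smooth_relu_activation


def smooth_relu_reference(x: torch.Tensor, beta: float = 2.0) -> torch.Tensor:
    # Hypothetical reference SmeLU, written directly in PyTorch:
    # 0 for x <= -beta, x for x >= beta, (x + beta)^2 / (4 * beta) in between.
    quad = (x + beta) * (x + beta) / (4.0 * beta)
    return torch.where(x >= beta, x, torch.where(x <= -beta, torch.zeros_like(x), quad))


if __name__ == "__main__":
    x = torch.randn(4096, device="cuda")
    out = smooth_relu_activation(x, beta=2.0)
    ref = smooth_relu_reference(x, beta=2.0)
    # Elementwise comparison between the Triton kernel and the reference
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
    print("smooth_relu_activation matches the PyTorch reference")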