[CLEANUP][Sky]
Kye committed Apr 15, 2024
2 parents e73dae3 + 516c82f commit be1c7e1
Showing 9 changed files with 438 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Zeta-specific
experimental_tests.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
8 changes: 4 additions & 4 deletions requirements.txt
@@ -13,10 +13,10 @@ beartype>=0.15.0,<0.16.0
vector-quantize-pytorch>=1.12.0,<1.13.0
scipy>=1.9.3,<1.10.0
loguru
-rich>=13.7.0,<13.8.0
-tiktoken>=0.6.0,<0.7.0
-transformers>=4.36.0,<4.37.0
-tqdm>=4.66.2,<4.67.0
+rich==13.7.1
+tiktoken==0.6.0
+transformers==4.36.0
+tqdm==4.66.2
mkdocs
mkdocs-material
mkdocs-glightbox
1 change: 1 addition & 0 deletions zeta/__init__.py
@@ -13,3 +13,4 @@
# from zeta.tokenizers import * # noqa: F403, E402
from zeta.training import * # noqa: F403, E402
from zeta.utils import * # noqa: F403, E402
from zeta.experimental import * # noqa: F403, E402
Empty file added zeta/experimental/__init__.py
Empty file.
43 changes: 43 additions & 0 deletions zeta/experimental/triton/activations/__init__.py
@@ -0,0 +1,43 @@
from zeta.experimental.triton.activations.activations import tanh_activation
from zeta.experimental.triton.activations.activations import (
hard_tanh_activation,
)
from zeta.experimental.triton.activations.activations import relu_activation
from zeta.experimental.triton.activations.activations import relu6_activation
from zeta.experimental.triton.activations.activations import (
leaky_relu_activation,
)
from zeta.experimental.triton.activations.activations import (
smooth_relu_activation,
)
from zeta.experimental.triton.activations.activations import softsign_activation
from zeta.experimental.triton.activations.activations import softplus_activation
from zeta.experimental.triton.activations.activations import sigmoid_activation
from zeta.experimental.triton.activations.activations import (
hard_sigmoid_activation,
)
from zeta.experimental.triton.activations.activations import silu_activation
from zeta.experimental.triton.activations.activations import (
hard_silu_activation,
)
from zeta.experimental.triton.activations.activations import softmax_activation
from zeta.experimental.triton.activations.activations import gelu_activation
from zeta.experimental.triton.activations.activations import swiglu_activation

__all__ = [
"tanh_activation",
"hard_tanh_activation",
"relu_activation",
"relu6_activation",
"leaky_relu_activation",
"smooth_relu_activation",
"softsign_activation",
"softplus_activation",
"sigmoid_activation",
"hard_sigmoid_activation",
"silu_activation",
"hard_silu_activation",
"softmax_activation",
"gelu_activation",
"swiglu_activation",
]
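
For reference, a minimal usage sketch of the wrappers exported above; it assumes a CUDA device is available, since apply_activation in activations.py (below) rejects CPU tensors.

import torch

from zeta.experimental.triton.activations import (
    gelu_activation,
    relu_activation,
)

x = torch.randn(4, 1024, device="cuda")  # kernels only accept CUDA tensors

y = relu_activation(x)                    # elementwise ReLU via the Triton kernel
z = gelu_activation(x, approximate=True)  # approximate=True is the wrapper's default
assert y.shape == x.shape and z.shape == x.shape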
96 changes: 96 additions & 0 deletions zeta/experimental/triton/activations/activations.py
@@ -0,0 +1,96 @@
import torch
import triton
import triton.language as tl

from typing import Callable
from zeta.experimental.triton.activations.functions import Functions

BLOCK_SIZE = 1024


def apply_activation(
x: torch.Tensor, activation_fn: Callable[..., torch.Tensor], *args, **kwargs
):
if not x.is_cuda:
raise ValueError("Input tensor must be on CUDA.")

output = torch.empty_like(x)
n_elements = output.numel()
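    # Launch one program instance per BLOCK_SIZE-element chunk of the flattened output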
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

activation_args = [x, output] + list(args)

if "n_elements" not in kwargs:
kwargs["n_elements"] = n_elements

    activation_fn[grid](*activation_args, BLOCK_SIZE=BLOCK_SIZE, **kwargs)

return output


def tanh_activation(x: torch.Tensor):
return apply_activation(x, Functions.tanh_activation_kernel)


def hard_tanh_activation(x: torch.Tensor):
return apply_activation(x, Functions.hard_tanh_activation_kernel)


def relu_activation(x: torch.Tensor):
return apply_activation(x, Functions.relu_activation_kernel)


def relu6_activation(x: torch.Tensor):
return apply_activation(x, Functions.relu6_activation_kernel)


def leaky_relu_activation(x: torch.Tensor, alpha: float = 0.2):
return apply_activation(
x, Functions.leaky_relu_activation_kernel, alpha=alpha
)


def smooth_relu_activation(x: torch.Tensor, beta: float = 2.0):
# Make input tensor contiguous if needed
if not x.is_contiguous():
x = x.contiguous()

return apply_activation(
x, Functions.smooth_relu_activation_kernel, beta=beta
)


def softsign_activation(x: torch.Tensor):
return apply_activation(x, Functions.softsign_activation_kernel)


def softplus_activation(x: torch.Tensor):
return apply_activation(x, Functions.softplus_activation_kernel)


def sigmoid_activation(x: torch.Tensor):
return apply_activation(x, Functions.sigmoid_activation_kernel)


def hard_sigmoid_activation(x: torch.Tensor):
return apply_activation(x, Functions.hard_sigmoid_activation_kernel)


def silu_activation(x: torch.Tensor):
return apply_activation(x, Functions.silu_activation_kernel)


def hard_silu_activation(x: torch.Tensor):
return apply_activation(x, Functions.hard_silu_activation_kernel)


def softmax_activation(x: torch.Tensor):
return apply_activation(x, Functions.softmax_activation_kernel)


def gelu_activation(x: torch.Tensor, approximate: bool = True):
return apply_activation(x, Functions.gelu_activation_kernel, approximate)


def swiglu_activation(x: torch.Tensor):
return apply_activation(x, Functions.swiglu_activation_kernel)
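
The functions.py module that defines the Functions kernels is among the files not shown in this view. For orientation only, a sketch of what one such elementwise Triton kernel typically looks like, matching the launch signature used by apply_activation (x and output passed positionally, n_elements as a keyword, BLOCK_SIZE as a constexpr); the kernel name here is hypothetical, not the commit's actual code.

import triton
import triton.language as tl


@triton.jit
def example_relu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program handles one contiguous BLOCK_SIZE slice of the flattened input
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail block

    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, tl.maximum(x, 0.0), mask=mask)

In the commit these kernels are accessed as attributes of Functions (for example Functions.relu_activation_kernel), so the real definitions are grouped under that name rather than standing alone at module level.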
