[FEATS] [ScaledSinusoidalEmbedding] [ScaleNorm] [ReluSquared]
Kye committed Mar 2, 2024
1 parent 48991aa commit 0d75ec6
Showing 12 changed files with 129 additions and 11 deletions.
3 changes: 2 additions & 1 deletion docs/zeta/nn/attention/local.md
@@ -15,7 +15,8 @@ Key terms:
## Class Definition

```python
class LocalAttention(nn.Module): ...
class LocalAttention(nn.Module):
...
```

### Parameters
6 changes: 4 additions & 2 deletions docs/zeta/nn/attention/mixture_of_attention_ar.md
@@ -32,7 +32,8 @@ class MixtureOfAutoregressiveAttention(nn.Module):
prenorm: bool = True,
average_routed: bool = False,
**kwargs,
): ...
):
...
```

### Parameters:
@@ -62,7 +63,8 @@ def forward(
rotary_emb: Optional[torch.Tensor] = None,
num_routed_queries: Optional[int] = None,
num_routed_key_values: Optional[int] = None,
) -> torch.Tensor: ...
) -> torch.Tensor:
...
```

- `x` (torch.Tensor): Input tensor of shape `(batch_size, sequence_length, dim)`.
3 changes: 2 additions & 1 deletion docs/zeta/nn/biases/dynamic.md
@@ -15,7 +15,8 @@ Key concepts:

```python
class DynamicPositionBias(nn.Module):
def __init__(self, dim: int, heads: int): ...
def __init__(self, dim: int, heads: int):
...
```

### Parameters:
6 changes: 4 additions & 2 deletions docs/zeta/nn/embeddings/rope.md
@@ -14,7 +14,8 @@ class RotaryEmbedding(nn.Module):
interpolation_factor=1.0,
base=10000,
base_rescale_factor=1.0,
): ...
):
...
```

### Parameters
@@ -29,7 +30,8 @@ class RotaryEmbedding(nn.Module):
### Method: `forward`

```python
def forward(self, seq_len, device): ...
def forward(self, seq_len, device):
...
```

#### Parameters
6 changes: 4 additions & 2 deletions docs/zeta/nn/modules/token_learner.md
@@ -19,7 +19,8 @@ class TokenLearner(nn.Module):
ff_mult: int = 2,
num_output_tokens: int = 8,
num_layers: int = 2,
): ...
):
...
```

### Parameters:
@@ -43,7 +44,8 @@ The forward method of the `TokenLearner` class takes an input tensor `x` and per
### Method:

```python
def forward(self, x): ...
def forward(self, x):
...
```

### Parameters:
6 changes: 4 additions & 2 deletions docs/zeta/nn/modules/visual_expert.md
@@ -32,9 +32,11 @@ class VisualExpert:
hidden_dim: int,
dropout: float,
heads: int,
): ...
):
...

def __call__(self, x: torch.Tensor): ...
def __call__(self, x: torch.Tensor):
...
```

### Parameters <a name="parameters"></a>
2 changes: 1 addition & 1 deletion scripts/install_cuda.py
@@ -16,7 +16,7 @@
"120": "https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installers/cuda_12.0.1_525.85.12_linux.run",
"121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run",
"122": "https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run",
"123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
"123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.runbl",
}
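
For orientation, a hedged sketch of how a version-to-installer mapping like the one above could be consumed. The `CUDA_INSTALLERS` name, the `download_installer` helper, and the silent-install invocation are illustrative assumptions, not the script's actual API:

```python
import subprocess
import urllib.request

# Hypothetical mapping in the style of the dict in scripts/install_cuda.py.
CUDA_INSTALLERS = {
    "121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run",
}


def download_installer(version: str, dest: str = "cuda_installer.run") -> str:
    """Download the runfile installer for the requested CUDA version key."""
    urllib.request.urlretrieve(CUDA_INSTALLERS[version], dest)
    return dest


# Example: fetch the CUDA 12.1 installer, then run it unattended (toolkit only).
# installer = download_installer("121")
# subprocess.run(["bash", installer, "--silent", "--toolkit"], check=True)
```
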


5 changes: 5 additions & 0 deletions zeta/nn/embeddings/__init__.py
@@ -27,6 +27,10 @@
rotate_every_two,
)
from zeta.nn.embeddings.yarn import YarnEmbedding
from zeta.nn.embeddings.scaled_sinusoidal_embeddings import (
ScaledSinusoidalEmbedding,
)


__all__ = [
"AbsolutePositionalEmbedding",
@@ -56,4 +60,5 @@
"fixed_pos_embedding",
"duplicate_interleave",
"VisionEmbedding",
"ScaledSinusoidalEmbedding",
]
47 changes: 47 additions & 0 deletions zeta/nn/embeddings/scaled_sinusoidal_embeddings.py
@@ -0,0 +1,47 @@
import torch
from torch import nn, Tensor, einsum

from zeta.utils.main import divisible_by


class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim: int, theta: int = 10000):
"""
Initializes a ScaledSinusoidalEmbedding module.
Args:
dim (int): The dimension of the embedding.
theta (int, optional): The scaling factor for the sinusoidal frequencies. Defaults to 10000.
"""
super().__init__()
assert divisible_by(dim, 2)
self.scale = nn.Parameter(torch.ones(1) * dim**-0.5)

half_dim = dim // 2
freq_seq = torch.arange(half_dim).float() / half_dim
inv_freq = theta**-freq_seq
self.register_buffer("inv_freq", inv_freq, persistent=False)

def forward(self, x: Tensor, pos=None, seq_start_pos=None):
"""
Forward pass of the ScaledSinusoidalEmbedding module.
Args:
x (Tensor): The input tensor.
pos (Tensor, optional): The position tensor. Defaults to None.
seq_start_pos (Tensor, optional): The starting position tensor for sequences. Defaults to None.
Returns:
Tensor: The embedded tensor.
"""
sq, device = x.shape[1], x.device

        if pos is None:
            # No explicit positions given: default to 0..seq_len-1 on the input's device.
            pos = torch.arange(sq, device=device)

if seq_start_pos is not None:
pos = pos - seq_start_pos[..., None]

emb = einsum("i, j -> i j", pos, self.inv_freq)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb * self.scale
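
A minimal usage sketch of the new embedding, assuming the export added to `zeta/nn/embeddings/__init__.py` in this commit; tensor shapes are illustrative:

```python
import torch

from zeta.nn.embeddings import ScaledSinusoidalEmbedding

pos_emb = ScaledSinusoidalEmbedding(dim=512)

# Token embeddings: batch of 2 sequences, length 16, dim 512.
x = torch.randn(2, 16, 512)

# Returns a (16, 512) positional embedding that broadcasts over the batch.
pe = pos_emb(x)
out = x + pe
```
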
4 changes: 4 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -188,6 +188,8 @@
from zeta.nn.modules.ws_conv2d import WSConv2d
from zeta.nn.modules.yolo import yolo
from zeta.nn.modules.palo_ldp import PaloLDP
from zeta.nn.modules.relu_squared import ReluSquared
from zeta.nn.modules.scale_norm import ScaleNorm

# from zeta.nn.modules.g_shard_moe import (
# Top1Gate,
@@ -386,4 +388,6 @@
"DynamicInputChannels",
"OutputDecoders",
"PaloLDP",
"ReluSquared",
"ScaleNorm",
]
17 changes: 17 additions & 0 deletions zeta/nn/modules/relu_squared.py
@@ -0,0 +1,17 @@
from torch import nn
import torch.nn.functional as F


class ReluSquared(nn.Module):
"""
Applies the ReLU activation function and squares the output.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after applying ReLU and squaring the result.
"""

def forward(self, x):
return F.relu(x) ** 2
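
A quick sketch of the activation in isolation, assuming the export added to `zeta/nn/modules/__init__.py` in this commit; the input shape is illustrative:

```python
import torch

from zeta.nn.modules import ReluSquared

act = ReluSquared()
x = torch.randn(4, 8)

# Negative entries go to zero; positive entries are squared, i.e. relu(x) ** 2.
y = act(x)
```
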
35 changes: 35 additions & 0 deletions zeta/nn/modules/scale_norm.py
@@ -0,0 +1,35 @@
import torch
from torch import nn, Tensor


class ScaleNorm(nn.Module):
"""
Applies scale normalization to the input tensor along the last dimension.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-5.
"""

def __init__(
self,
dim: int,
eps: float = 1e-5,
):
super().__init__()
self.eps = eps

self.g = nn.Parameter(torch.ones(1) * (dim**-0.5))

def forward(self, x: Tensor):
"""
Applies scale normalization to the input tensor.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The scale-normalized tensor.
"""
        # Divide by the L2 norm along the last dimension, then rescale by the learned g.
        norm = torch.norm(x, dim=-1, keepdim=True)
        return x / norm.clamp(min=self.eps) * self.g
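
A minimal sketch of the normalization on a dummy activation, assuming the multiplicative `g` as corrected above; shapes are illustrative:

```python
import torch

from zeta.nn.modules import ScaleNorm

norm = ScaleNorm(dim=512)
x = torch.randn(2, 16, 512)

# Each vector along the last dim is divided by its L2 norm (clamped at eps),
# then rescaled by the learned scalar g (initialized to dim ** -0.5).
y = norm(x)
```
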
