
Commit

[CLEANUP]
Kye committed Jan 4, 2024
1 parent 4f48770 commit 03b27f7
Showing 4 changed files with 32 additions and 138 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -47,8 +47,8 @@ print(output.shape)
### `SwiGLU`
- Powers Transformer models
```python
from zeta.nn import SwiGLUStacked
import torch
from zeta.nn import SwiGLUStacked

x = torch.randn(5, 10)
swiglu = SwiGLUStacked(10, 20)
@@ -59,8 +59,8 @@ swiglu(x).shape
### ```RelativePositionBias```
- ```RelativePositionBias``` quantizes the distance between two positions into a fixed number of buckets and uses an embedding to look up the relative position bias. This aids the attention mechanism by providing biases based on the relative positions of the query and key, rather than relying solely on their absolute positions.
```python
from zeta.nn import RelativePositionBias
import torch
from zeta.nn import RelativePositionBias

# Initialize the RelativePositionBias module
rel_pos_bias = RelativePositionBias()
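
# NOTE: the README example is truncated by the diff view above. A minimal,
# illustrative continuation (not part of this commit), assuming the module's
# forward pass takes a batch size and query/key lengths and returns a bias
# tensor that can be added to attention logits:
bias_matrix = rel_pos_bias(1, 10, 10)  # assumed signature: (batch, q_len, k_len)
print(bias_matrix.shape)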
@@ -380,7 +380,7 @@ print(output.shape) # Expected: torch.Size([1, 512])

```python
import torch
from zeta.nn.modules.simple_mamba import MambaBlock
from zeta.nn import MambaBlock

# Initialize Mamba
block = MambaBlock(dim=64, depth=1)
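
# Illustrative continuation of the truncated example above (not part of this
# commit): a sketch of a forward pass, assuming MambaBlock expects input of
# shape (batch, seq_len, dim) and preserves that shape.
x = torch.randn(1, 10, 64)

y = block(x)
print(y.shape)  # expected: torch.Size([1, 10, 64])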
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "1.3.9"
version = "1.4.0"
description = "Transformers at zeta scales"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
126 changes: 0 additions & 126 deletions zeta/models/multimodal_mamba.py

This file was deleted.

36 changes: 28 additions & 8 deletions zeta/nn/modules/simple_mamba.py
@@ -8,7 +8,7 @@
from torch import Tensor, nn

from zeta.nn.modules.rms_norm import RMSNorm

from zeta.utils import exists

class MambaBlock(nn.Module):
"""
@@ -76,7 +76,9 @@ def __init__(
)

# x_proj takes in `x` and outputs the input-specific Δ, B, C
self.x_proj = nn.Linear(dim_inner, dt_rank + self.d_state * 2, bias=False)
self.x_proj = nn.Linear(
dim_inner, dt_rank + self.d_state * 2, bias=False
)

# dt_proj projects Δ from dt_rank to d_in
self.dt_proj = nn.Linear(dt_rank, dim_inner, bias=True)
@@ -221,19 +223,21 @@ class Mamba(nn.Module):
expand (int): The expansion factor. Default is 2.
dt_rank (Union[int, str]): The rank of the temporal difference (Δ) tensor. Default is "auto".
d_conv (int): The dimension of the convolutional kernel. Default is 4.
Examples:
x = torch.randint(0, 16, (1, 64))
model = Mamba(16, 64, 5, 16)
out = model(x)
print(out)
"""

def __init__(
self,
vocab_size: int = None,
dim: int = None,
depth: int = 5,
d_state: int = 16,
img_dim: int = 64,
*args,
**kwargs,
):
@@ -242,14 +246,21 @@ def __init__(

self.embedding = nn.Embedding(vocab_size, dim)
self.norm_f = RMSNorm(dim)

self.lm_head = nn.Linear(dim, vocab_size, bias=False)
self.lm_head.weight = self.embedding.weight
self.mamba_layers = nn.ModuleList([
MambaBlock(dim=dim, depth=depth, d_state=d_state, *args, **kwargs) for _ in range(depth)
])
self.mamba_layers = nn.ModuleList(
[
MambaBlock(
dim=dim, depth=depth, d_state=d_state, *args, **kwargs
)
for _ in range(depth)
]
)

# Projection for img
self.img_proj = nn.Linear(img_dim, dim)

def forward(self, x: Tensor):
def forward(self, x: Tensor, context: Tensor = None,):
"""
Args:
x (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)
@@ -262,6 +273,13 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss
"""
x = self.embedding(x)

if exists(context):
# Project the image
projected_img = self.img_proj(context)

# Concatenate the image and text
x = torch.cat([x, projected_img], dim=1)

for layer in self.mamba_layers:
x = layer(self.norm_f(x)) + x
@@ -271,3 +289,5 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss

return logits
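
The diff above gives `Mamba.forward` an optional `context` tensor: image features are passed through `img_proj` (a `Linear` from `img_dim` to `dim`) and concatenated with the token embeddings along the sequence dimension before the Mamba layers. A minimal usage sketch based only on the signatures visible in this diff; the image-feature shape (49 "patches" of size 64) is an illustrative assumption, not part of the commit:

```python
import torch

from zeta.nn.modules.simple_mamba import Mamba

# Constructor arguments mirror the docstring example: vocab_size, dim, depth, d_state
model = Mamba(16, 64, 5, 16)

# Token ids of shape (batch, seq_len)
x = torch.randint(0, 16, (1, 64))

# Text-only forward pass
logits = model(x)

# Multimodal forward pass: context's last dim must equal img_dim (default 64);
# it is projected to `dim` and concatenated with the embedded tokens.
img_features = torch.randn(1, 49, 64)  # 49 patches is an arbitrary illustrative choice
logits_with_img = model(x, context=img_features)

print(logits.shape, logits_with_img.shape)
```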


