Commit

[CLEANUP]

Kye committed Jan 10, 2024
1 parent bcfd999 commit 34005a1
Showing 8 changed files with 603 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "1.6.3"
version = "1.6.5"
description = "Transformers at zeta scales"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
@@ -44,6 +44,7 @@ tqdm = "4.66.1"
rich = "13.7.0"
argparse = "^1.4.0"
skypilot = "0.4.1"
numexpr = "*"


[build-system]
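The only new dependency is numexpr, added without a version pin. A quick sanity check of that dependency (a minimal sketch, assuming numexpr here is the standard PyPI package rather than anything zeta-specific):

import numexpr as ne
import numpy as np

a = np.arange(1_000_000)
b = np.arange(1_000_000)
# numexpr compiles the string expression and evaluates it in its own
# multithreaded virtual machine, reading a and b from the local scope.
result = ne.evaluate("2*a + 3*b")
print(result[:5])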
8 changes: 6 additions & 2 deletions zeta/nn/modules/__init__.py
@@ -41,7 +41,7 @@
from zeta.nn.modules.simple_feedforward import SimpleFeedForward
from zeta.nn.modules.simple_res_block import SimpleResBlock
from zeta.nn.modules.skipconnection import SkipConnection
-from zeta.nn.modules.spacial_transformer import SpacialTransformer
+from zeta.nn.modules.spacial_transformer import SpatialTransformer
from zeta.nn.modules.subln import SubLN
from zeta.nn.modules.super_resolution import SuperResolutionNet
from zeta.nn.modules.time_up_sample import TimeUpSample2x
@@ -98,6 +98,8 @@
from zeta.nn.modules.proj_then_softmax import FusedProjSoftmax
from zeta.nn.modules.top_n_gating import TopNGating
from zeta.nn.modules.moe_router import MoERouter
+from zeta.nn.modules.perceiver_layer import PerceiverLayer
+from zeta.nn.modules.u_mamba import UMambaBlock

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -133,7 +135,7 @@
"RNNL",
"ShuffleNet",
"simple_attention",
"SpacialTransformer",
"SpatialTransformer",
"SubLN",
"SuperResolutionNet",
"TokenLearner",
@@ -209,4 +211,6 @@
"FusedProjSoftmax",
"TopNGating",
"MoERouter",
"PerceiverLayer",
"UMambaBlock"
]
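These re-exports change the public API surface: PerceiverLayer and UMambaBlock become importable from the package root, and the correctly spelled SpatialTransformer export replaces the old SpacialTransformer name. A minimal import check (a sketch, assuming the package is installed and exposes zeta.nn.modules as shown above):

# Both new blocks and the renamed class resolve from the package namespace;
# code importing the old misspelled name will need to be updated.
from zeta.nn.modules import PerceiverLayer, UMambaBlock, SpatialTransformer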
118 changes: 118 additions & 0 deletions zeta/nn/modules/perceiver_layer.py
@@ -0,0 +1,118 @@
from typing import Optional

import torch
from torch import Tensor, nn

from zeta.nn.attention.cross_attention import CrossAttention
from zeta.nn.attention.multiquery_attention import MultiQueryAttention


class PerceiverLayer(nn.Module):
"""
Perceiver layer: the query tensor is first passed through self attention,
and that output is used as the query of a cross attention over the key
and value tensors. The cross attention output is then passed through a
feed forward layer.
Args:
dim: dimension of the input tensor
heads: number of heads
depth: number of layers
dim_head: dimension of each head
dropout: dropout rate
ff_dropout: feed forward dropout rate
ff_mult: feed forward multiplier
Examples::
>>> q = torch.randn(1, 32, 512)
>>> k = torch.randn(1, 32, 512)
>>> v = torch.randn(1, 32, 512)
>>> layer = PerceiverLayer(512, 8, 6, 64)
>>> print(layer(q, k, v).shape)
torch.Size([1, 32, 512])
"""

def __init__(
self,
dim: int,
heads: int,
depth: int,
dim_head: int = 64,
dropout: float = 0.1,
ff_dropout: float = 0.1,
ff_mult: int = 4,
):
super().__init__()
self.dim = dim
self.heads = heads
self.depth = depth
self.dim_head = dim_head
self.dropout = dropout
self.ff_dropout = ff_dropout
self.ff_mult = ff_mult

# Initialize layers for MultiQueryAttention, CrossAttention, and Feed Forward
self.self_attn = MultiQueryAttention(
dim,
heads,
# qk_ln=True,
)

# CrossAttention initialization
self.cross_attn = CrossAttention(
dim,
context_dim=dim,
dim_head=dim_head,
heads=heads,
dropout=dropout,
)

# Feed Forward initialization
self.ffn = nn.Sequential(
nn.Linear(dim, dim * ff_mult),
nn.GELU(),
nn.Dropout(ff_dropout),
nn.Linear(dim * ff_mult, dim),
nn.Dropout(ff_dropout),
)

# Projection layers for x to -> q, k, v
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)

def forward(
self,
q: Tensor,
k: Tensor,
v: Tensor,
mask: Optional[Tensor] = None,
):
"""
Args:
q: query tensor
k: key tensor
v: value tensor
mask: mask tensor
Shape:
q: (batch_size, seq_len_q, dim)
k: (batch_size, seq_len_k, dim)
v: (batch_size, seq_len_v, dim)
mask: (batch_size, seq_len_q, seq_len_k)
"""
q, _, _ = self.self_attn(q)

# Concatenate k and v
kv = torch.concat((k, v), dim=1)

# Cross attention over the concatenated k/v, with the self-attended q as the context
x = self.cross_attn(kv, q)

# Apply feed forward layer to output of cross attention
x = self.ffn(x)

# Return output
return x
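A short usage sketch for the new layer, mirroring the shapes from the class docstring (the keyword names match the __init__ signature above, and the output shape is the one the docstring example reports):

import torch
from zeta.nn.modules.perceiver_layer import PerceiverLayer

layer = PerceiverLayer(dim=512, heads=8, depth=6, dim_head=64)

# Latent queries self-attend first, then act as the context for cross
# attention over the concatenated key/value sequence.
q = torch.randn(1, 32, 512)
k = torch.randn(1, 32, 512)
v = torch.randn(1, 32, 512)

out = layer(q, k, v)
print(out.shape)  # the docstring example reports torch.Size([1, 32, 512])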
6 changes: 3 additions & 3 deletions zeta/nn/modules/spacial_transformer.py
@@ -4,20 +4,20 @@
import torch.nn.functional as F


-class SpacialTransformer(nn.Module):
+class SpatialTransformer(nn.Module):
"""
Spacial Transformer Network
https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
Usage:
->>> stn = SpacialTransformer()
+>>> stn = SpatialTransformer()
>>> stn.stn(x)
"""

def __init__(self):
-super(SpacialTransformer, self).__init__()
+super(SpatialTransformer, self).__init__()

# spatial transformer localization-network
linear = nn.Linear(32, 3 * 2)
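Because this class is re-exported from zeta.nn.modules (see the __init__.py changes above), existing imports of the misspelled name stop working after this rename. A hedged usage sketch of the renamed class (the input shape is an assumption based on the linked PyTorch STN tutorial, since the localization network is truncated in this diff):

import torch
from zeta.nn.modules.spacial_transformer import SpatialTransformer

stn = SpatialTransformer()
# Assumed MNIST-style input of shape (batch, channels, height, width);
# adjust to whatever the localization network actually expects.
x = torch.randn(1, 1, 28, 28)
warped = stn.stn(x)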
