Skip to content

Commit

Permalink
[FEAT][GatedXAttention][GatedMoECrossAttn]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye Gomez authored and Kye Gomez committed Jul 22, 2024
1 parent 5e51098 commit 4ff5d90
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 17 deletions.
11 changes: 8 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
[tool.poetry]
name = "zetascale"
version = "2.5.8"
version = "2.5.9"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
readme = "README.md"
homepage = "https://github.com/kyegomez/zeta"
keywords = ["Transformers", "zeta scale"]
keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
classifiers = [
"Programming Language :: Python :: 3",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9"
]

packages = [
{ include = "zeta" },
{ include = "zeta/**/*.py" },
Expand Down
4 changes: 3 additions & 1 deletion zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@
from zeta.nn.modules.simple_rnn import SimpleRNN
from zeta.nn.modules.cope import CoPE
from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention

from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
Expand Down Expand Up @@ -445,4 +445,6 @@
"SimpleRNN",
"CoPE",
"MultiLayerKeyValueAttention",
"GatedMoECrossAttn",
"GatedXAttention",
]
185 changes: 185 additions & 0 deletions zeta/nn/modules/evlm_xattn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from zeta.nn.attention.cross_attention import CrossAttention
from torch import nn, Tensor
from zeta.nn.modules.feedforward import FeedForward
from zeta.nn.modules.sparse_moe import NormalSparseMoE


class GatedXAttention(nn.Module):
"""
GatedXAttention module applies cross attention between text and image embeddings,
followed by activation functions and feed-forward neural network (FFN) layers.
Args:
dim (int): The input dimension of the text embeddings.
heads (int, optional): The number of attention heads. Defaults to 8.
dim_head (int, optional): The dimension of each attention head. Defaults to 64.
dropout (float, optional): The dropout rate. Defaults to 0.1.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""

def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.1,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.heads = heads
self.dim_head = dim_head

self.cross_attention = CrossAttention(
dim,
dim_head=dim_head,
heads=heads,
dropout=dropout,
*args,
**kwargs,
)

# ACT
self.act = nn.Tanh()

# FFN
self.ffn = FeedForward(
dim,
dim,
swish=True,
)

def forward(self, text: Tensor, img: Tensor, mask: Tensor = None) -> Tensor:
"""
Forward pass of the GatedXAttention module.
Args:
text (Tensor): The input text embeddings. Shape: (batch_size, sequence_length, dim).
img (Tensor): The input image embeddings.
mask (Tensor, optional): The attention mask. Defaults to None.
Returns:
Tensor: The output tensor after applying cross attention, activation functions, and FFN layers.
"""
# KV are image, Q is text
b, s, d = text.shape
residual = text

# Cross Attention
x = self.cross_attention(text, img, mask)

# Tanh
feeded = self.act(x)

# 2nd loop
out = feeded + residual

# Second residual
second_residual = out

# FFN
ffn_response = self.ffn(out)

# Tanded
out = self.act(ffn_response) + second_residual

return out


# x = torch.randn(1, 10, 512)
# img = torch.randn(1, 10, 512)

# model = GatedXAttention(512)

# out = model(x, img)
# print(out)


class GatedMoECrossAttn(nn.Module):
"""
GatedMoECrossAttn is a module that performs gated multi-expert cross attention on text and image inputs.
Args:
dim (int): The input dimension.
heads (int, optional): The number of attention heads. Defaults to 8.
dim_head (int, optional): The dimension of each attention head. Defaults to 64.
dropout (float, optional): The dropout rate. Defaults to 0.1.
experts (int, optional): The number of experts for the MoE. Defaults to 4.
Attributes:
dim (int): The input dimension.
heads (int): The number of attention heads.
dim_head (int): The dimension of each attention head.
cross_attention (CrossAttention): The cross attention module.
moe (NormalSparseMoE): The MoE module.
act (Tanh): The activation function.
Methods:
forward(text, img, mask=None): Performs forward pass of the module.
Returns:
Tensor: The output tensor after the forward pass.
"""

def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.1,
experts: int = 4,
*args,
**kwargs,
):
super().__init__()
self.dim = dim
self.heads = heads
self.dim_head = dim_head

self.cross_attention = CrossAttention(
dim,
dim_head=dim_head,
heads=heads,
dropout=dropout,
*args,
**kwargs,
)

# MoE
self.moe = NormalSparseMoE(
dim,
experts,
)

self.act = nn.Tanh()

def forward(self, text: Tensor, img: Tensor, mask: Tensor = None) -> Tensor:
residual = text

# Cross Attention
attended = self.cross_attention(text, img, mask)

# Tanh
activated = self.act(attended) + residual

# Second Residual
second_residual = activated

# MoE
moe_response, loss = self.moe(activated)

# Add residual
out = moe_response + second_residual

return self.act(out)


# x = torch.randn(1, 10, 512)
# img = torch.randn(1, 10, 512)

# model = GatedMoECrossAttn(512)

# out = model(x, img)
# print(out.shape)
59 changes: 46 additions & 13 deletions zeta/nn/modules/multi_layer_key_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,29 @@


class MultiLayerKeyValueAttention(nn.Module):
"""
Multi-layer key-value attention module.
Args:
embed_size (int): The size of the input embeddings.
num_heads (int): The number of attention heads.
num_layers (int): The number of layers.
kv_layers (int): The number of key-value layers.
Attributes:
num_heads (int): The number of attention heads.
num_layers (int): The number of layers.
kv_layers (int): The number of key-value layers.
embed_size (int): The size of the input embeddings.
head_dim (int): The dimension of each attention head.
values (nn.ModuleList): List of value projection layers for each key-value layer.
keys (nn.ModuleList): List of key projection layers for each key-value layer.
queries (nn.ModuleList): List of query projection layers for each layer.
fc_out (nn.Linear): Output linear layer.
"""

def __init__(self, embed_size, num_heads, num_layers, kv_layers):
super(MultiLayerKeyValueAttention, self).__init__()
self.num_heads = num_heads
Expand Down Expand Up @@ -40,6 +63,18 @@ def __init__(self, embed_size, num_heads, num_layers, kv_layers):
self.fc_out = nn.Linear(embed_size, embed_size)

def forward(self, values, keys, queries):
"""
Forward pass of the multi-layer key-value attention module.
Args:
values (torch.Tensor): The values tensor of shape (N, value_len, embed_size).
keys (torch.Tensor): The keys tensor of shape (N, key_len, embed_size).
queries (torch.Tensor): The queries tensor of shape (N, query_len, embed_size).
Returns:
torch.Tensor: The output tensor of shape (N, query_len, embed_size).
"""
N = queries.shape[0]
value_len, key_len, query_len = (
values.shape[1],
Expand Down Expand Up @@ -78,18 +113,16 @@ def forward(self, values, keys, queries):
return out


# Example usage
embed_size = 256
num_heads = 8
num_layers = 4
kv_layers = 2 # Number of layers with their own KV heads
# # Example usage
# embed_size = 256
# num_heads = 8
# num_layers = 4
# kv_layers = 2 # Number of layers with their own KV heads

mlkv_attention = MultiLayerKeyValueAttention(
embed_size, num_heads, num_layers, kv_layers
)
values = torch.rand(32, 10, embed_size) # batch size 32, sequence length 10
keys = torch.rand(32, 10, embed_size)
queries = torch.rand(32, 10, embed_size)
# mlkv_attention = MultiLayerKeyValueAttention(embed_size, num_heads, num_layers, kv_layers)
# values = torch.rand(32, 10, embed_size) # batch size 32, sequence length 10
# keys = torch.rand(32, 10, embed_size)
# queries = torch.rand(32, 10, embed_size)

output = mlkv_attention(values, keys, queries)
print(output.shape)
# output = mlkv_attention(values, keys, queries)
# print(output.shape)
36 changes: 36 additions & 0 deletions zeta/nn/modules/sparse_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,31 @@ def forward(self, x, importance=None):


class NormalSparseMoE(nn.Module):
"""
NormalSparseMoE is a module that implements the Normal Sparse Mixture of Experts.
Args:
dim (int): The input dimension.
num_experts (int, optional): The number of experts in the mixture. Defaults to 16.
hidden_dim (int, optional): The dimension of the hidden layer in the experts. Defaults to None.
activation (torch.nn.Module, optional): The activation function to use in the experts. Defaults to torch.nn.ReLU.
second_policy_train (str, optional): The policy for selecting the second expert during training. Defaults to "random".
second_policy_eval (str, optional): The policy for selecting the second expert during evaluation. Defaults to "random".
second_threshold_train (float, optional): The threshold for selecting the second expert during training. Defaults to 0.2.
second_threshold_eval (float, optional): The threshold for selecting the second expert during evaluation. Defaults to 0.2.
capacity_factor_train (float, optional): The capacity factor for the gating mechanism during training. Defaults to 1.25.
capacity_factor_eval (float, optional): The capacity factor for the gating mechanism during evaluation. Defaults to 2.0.
loss_coef (float, optional): The coefficient for the loss term. Defaults to 1e-2.
experts (torch.nn.Module, optional): The module that implements the experts. Defaults to None.
Attributes:
num_experts (int): The number of experts in the mixture.
gate (Top2Gating): The gating mechanism for selecting the experts.
experts (torch.nn.Module): The module that implements the experts.
loss_coef (float): The coefficient for the loss term.
"""

def __init__(
self,
dim,
Expand Down Expand Up @@ -300,6 +325,17 @@ def __init__(
self.loss_coef = loss_coef

def forward(self, inputs, **kwargs):
"""
Forward pass of the NormalSparseMoE module.
Args:
inputs (torch.Tensor): The input tensor.
Returns:
output (torch.Tensor): The output tensor.
loss (torch.Tensor): The loss tensor.
"""
_b, _n, d, e = *inputs.shape, self.num_experts
dispatch_tensor, combine_tensor, loss = self.gate(inputs)
expert_inputs = torch.einsum("bnd,bnec->ebcd", inputs, dispatch_tensor)
Expand Down

0 comments on commit 4ff5d90

Please sign in to comment.