diff --git a/pyproject.toml b/pyproject.toml
index 5a3db74e..c5435eb5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,20 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.5.8"
+version = "2.5.9"
 description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
 authors = ["Zeta Team "]
 license = "MIT"
 readme = "README.md"
 homepage = "https://github.com/kyegomez/zeta"
-keywords = ["Transformers", "zeta scale"]
+keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
 classifiers = [
-    "Programming Language :: Python :: 3",
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3.9"
 ]
+
 packages = [
     { include = "zeta" },
     { include = "zeta/**/*.py" },
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index a5cd6e0c..442bab74 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -221,7 +221,7 @@
 from zeta.nn.modules.simple_rnn import SimpleRNN
 from zeta.nn.modules.cope import CoPE
 from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention
-
+from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention
 
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -445,4 +445,6 @@
     "SimpleRNN",
     "CoPE",
     "MultiLayerKeyValueAttention",
+    "GatedMoECrossAttn",
+    "GatedXAttention",
 ]
diff --git a/zeta/nn/modules/evlm_xattn.py b/zeta/nn/modules/evlm_xattn.py
new file mode 100644
index 00000000..987e27a6
--- /dev/null
+++ b/zeta/nn/modules/evlm_xattn.py
@@ -0,0 +1,185 @@
+from zeta.nn.attention.cross_attention import CrossAttention
+from torch import nn, Tensor
+from zeta.nn.modules.feedforward import FeedForward
+from zeta.nn.modules.sparse_moe import NormalSparseMoE
+
+
+class GatedXAttention(nn.Module):
+    """
+    GatedXAttention module applies cross attention between text and image embeddings,
+    followed by activation functions and feed-forward neural network (FFN) layers.
+
+    Args:
+        dim (int): The input dimension of the text embeddings.
+        heads (int, optional): The number of attention heads. Defaults to 8.
+        dim_head (int, optional): The dimension of each attention head. Defaults to 64.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+        *args: Variable length argument list.
+        **kwargs: Arbitrary keyword arguments.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.1,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.cross_attention = CrossAttention(
+            dim,
+            dim_head=dim_head,
+            heads=heads,
+            dropout=dropout,
+            *args,
+            **kwargs,
+        )
+
+        # ACT
+        self.act = nn.Tanh()
+
+        # FFN
+        self.ffn = FeedForward(
+            dim,
+            dim,
+            swish=True,
+        )
+
+    def forward(self, text: Tensor, img: Tensor, mask: Tensor = None) -> Tensor:
+        """
+        Forward pass of the GatedXAttention module.
+
+        Args:
+            text (Tensor): The input text embeddings. Shape: (batch_size, sequence_length, dim).
+            img (Tensor): The input image embeddings.
+            mask (Tensor, optional): The attention mask. Defaults to None.
+
+        Returns:
+            Tensor: The output tensor after applying cross attention, activation functions, and FFN layers.
+        """
+        # KV are image, Q is text
+        b, s, d = text.shape
+        residual = text
+
+        # Cross Attention
+        x = self.cross_attention(text, img, mask)
+
+        # Tanh gate
+        feeded = self.act(x)
+
+        # First residual connection
+        out = feeded + residual
+
+        # Second residual
+        second_residual = out
+
+        # FFN
+        ffn_response = self.ffn(out)
+
+        # Tanh, then add the second residual
+        out = self.act(ffn_response) + second_residual
+
+        return out
+
+
+# x = torch.randn(1, 10, 512)
+# img = torch.randn(1, 10, 512)
+
+# model = GatedXAttention(512)
+
+# out = model(x, img)
+# print(out)
+
+
+class GatedMoECrossAttn(nn.Module):
+    """
+    GatedMoECrossAttn is a module that performs gated Mixture-of-Experts cross attention on text and image inputs.
+
+    Args:
+        dim (int): The input dimension.
+        heads (int, optional): The number of attention heads. Defaults to 8.
+        dim_head (int, optional): The dimension of each attention head. Defaults to 64.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+        experts (int, optional): The number of experts for the MoE. Defaults to 4.
+
+    Attributes:
+        dim (int): The input dimension.
+        heads (int): The number of attention heads.
+        dim_head (int): The dimension of each attention head.
+        cross_attention (CrossAttention): The cross attention module.
+        moe (NormalSparseMoE): The MoE module.
+        act (Tanh): The activation function.
+
+    Methods:
+        forward(text, img, mask=None): Performs forward pass of the module.
+
+    Returns:
+        Tensor: The output tensor after the forward pass.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.1,
+        experts: int = 4,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.heads = heads
+        self.dim_head = dim_head
+
+        self.cross_attention = CrossAttention(
+            dim,
+            dim_head=dim_head,
+            heads=heads,
+            dropout=dropout,
+            *args,
+            **kwargs,
+        )
+
+        # MoE
+        self.moe = NormalSparseMoE(
+            dim,
+            experts,
+        )
+
+        self.act = nn.Tanh()
+
+    def forward(self, text: Tensor, img: Tensor, mask: Tensor = None) -> Tensor:
+        residual = text
+
+        # Cross Attention
+        attended = self.cross_attention(text, img, mask)
+
+        # Tanh gate plus residual
+        activated = self.act(attended) + residual
+
+        # Second Residual
+        second_residual = activated
+
+        # MoE
+        moe_response, loss = self.moe(activated)
+
+        # Add residual
+        out = moe_response + second_residual
+
+        return self.act(out)
+
+
+# x = torch.randn(1, 10, 512)
+# img = torch.randn(1, 10, 512)
+
+# model = GatedMoECrossAttn(512)
+
+# out = model(x, img)
+# print(out.shape)
diff --git a/zeta/nn/modules/multi_layer_key_cache.py b/zeta/nn/modules/multi_layer_key_cache.py
index 08f9e1ea..b9df0a9f 100644
--- a/zeta/nn/modules/multi_layer_key_cache.py
+++ b/zeta/nn/modules/multi_layer_key_cache.py
@@ -3,6 +3,29 @@
 
 
 class MultiLayerKeyValueAttention(nn.Module):
+    """
+    Multi-layer key-value attention module.
+
+    Args:
+        embed_size (int): The size of the input embeddings.
+        num_heads (int): The number of attention heads.
+        num_layers (int): The number of layers.
+        kv_layers (int): The number of key-value layers.
+
+    Attributes:
+        num_heads (int): The number of attention heads.
+        num_layers (int): The number of layers.
+        kv_layers (int): The number of key-value layers.
+        embed_size (int): The size of the input embeddings.
+        head_dim (int): The dimension of each attention head.
+
+        values (nn.ModuleList): List of value projection layers for each key-value layer.
+        keys (nn.ModuleList): List of key projection layers for each key-value layer.
+        queries (nn.ModuleList): List of query projection layers for each layer.
+        fc_out (nn.Linear): Output linear layer.
+
+    """
+
     def __init__(self, embed_size, num_heads, num_layers, kv_layers):
         super(MultiLayerKeyValueAttention, self).__init__()
         self.num_heads = num_heads
@@ -40,6 +63,18 @@ def __init__(self, embed_size, num_heads, num_layers, kv_layers):
         self.fc_out = nn.Linear(embed_size, embed_size)
 
     def forward(self, values, keys, queries):
+        """
+        Forward pass of the multi-layer key-value attention module.
+
+        Args:
+            values (torch.Tensor): The values tensor of shape (N, value_len, embed_size).
+            keys (torch.Tensor): The keys tensor of shape (N, key_len, embed_size).
+            queries (torch.Tensor): The queries tensor of shape (N, query_len, embed_size).
+
+        Returns:
+            torch.Tensor: The output tensor of shape (N, query_len, embed_size).
+
+        """
         N = queries.shape[0]
         value_len, key_len, query_len = (
             values.shape[1],
@@ -78,18 +113,16 @@ def forward(self, values, keys, queries):
         return out
 
 
-# Example usage
-embed_size = 256
-num_heads = 8
-num_layers = 4
-kv_layers = 2  # Number of layers with their own KV heads
+# # Example usage
+# embed_size = 256
+# num_heads = 8
+# num_layers = 4
+# kv_layers = 2  # Number of layers with their own KV heads
 
-mlkv_attention = MultiLayerKeyValueAttention(
-    embed_size, num_heads, num_layers, kv_layers
-)
-values = torch.rand(32, 10, embed_size)  # batch size 32, sequence length 10
-keys = torch.rand(32, 10, embed_size)
-queries = torch.rand(32, 10, embed_size)
+# mlkv_attention = MultiLayerKeyValueAttention(embed_size, num_heads, num_layers, kv_layers)
+# values = torch.rand(32, 10, embed_size)  # batch size 32, sequence length 10
+# keys = torch.rand(32, 10, embed_size)
+# queries = torch.rand(32, 10, embed_size)
 
-output = mlkv_attention(values, keys, queries)
-print(output.shape)
+# output = mlkv_attention(values, keys, queries)
+# print(output.shape)
diff --git a/zeta/nn/modules/sparse_moe.py b/zeta/nn/modules/sparse_moe.py
index e0652244..85dd96c1 100644
--- a/zeta/nn/modules/sparse_moe.py
+++ b/zeta/nn/modules/sparse_moe.py
@@ -260,6 +260,31 @@ def forward(self, x, importance=None):
 
 
 class NormalSparseMoE(nn.Module):
+    """
+    NormalSparseMoE is a module that implements the Normal Sparse Mixture of Experts.
+
+    Args:
+        dim (int): The input dimension.
+        num_experts (int, optional): The number of experts in the mixture. Defaults to 16.
+        hidden_dim (int, optional): The dimension of the hidden layer in the experts. Defaults to None.
+        activation (torch.nn.Module, optional): The activation function to use in the experts. Defaults to torch.nn.ReLU.
+        second_policy_train (str, optional): The policy for selecting the second expert during training. Defaults to "random".
+        second_policy_eval (str, optional): The policy for selecting the second expert during evaluation. Defaults to "random".
+        second_threshold_train (float, optional): The threshold for selecting the second expert during training. Defaults to 0.2.
+        second_threshold_eval (float, optional): The threshold for selecting the second expert during evaluation. Defaults to 0.2.
+        capacity_factor_train (float, optional): The capacity factor for the gating mechanism during training. Defaults to 1.25.
+        capacity_factor_eval (float, optional): The capacity factor for the gating mechanism during evaluation. Defaults to 2.0.
+        loss_coef (float, optional): The coefficient for the loss term. Defaults to 1e-2.
+        experts (torch.nn.Module, optional): The module that implements the experts. Defaults to None.
+
+    Attributes:
+        num_experts (int): The number of experts in the mixture.
+        gate (Top2Gating): The gating mechanism for selecting the experts.
+        experts (torch.nn.Module): The module that implements the experts.
+        loss_coef (float): The coefficient for the loss term.
+
+    """
+
     def __init__(
         self,
         dim,
@@ -300,6 +325,17 @@ def __init__(
         self.loss_coef = loss_coef
 
     def forward(self, inputs, **kwargs):
+        """
+        Forward pass of the NormalSparseMoE module.
+
+        Args:
+            inputs (torch.Tensor): The input tensor.
+
+        Returns:
+            output (torch.Tensor): The output tensor.
+            loss (torch.Tensor): The loss tensor.
+
+        """
         _b, _n, d, e = *inputs.shape, self.num_experts
         dispatch_tensor, combine_tensor, loss = self.gate(inputs)
         expert_inputs = torch.einsum("bnd,bnec->ebcd", inputs, dispatch_tensor)
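
Usage sketch (not part of the patch): the snippet below mirrors the commented-out demos inside the new evlm_xattn.py and assumes the exports added to zeta/nn/modules/__init__.py by this diff; the batch size, sequence length, dim=512, and expert count are illustrative values only.

import torch

from zeta.nn.modules import GatedMoECrossAttn, GatedXAttention

# Toy text and image embeddings that share the model dimension (dim=512).
text = torch.randn(1, 10, 512)
img = torch.randn(1, 10, 512)

# Gated cross attention: text queries attend over image keys/values,
# followed by a tanh gate, a feed-forward block, and residual additions.
xattn = GatedXAttention(512)
out = xattn(text, img)
print(out.shape)  # expected: torch.Size([1, 10, 512])

# MoE variant: the feed-forward block is swapped for a NormalSparseMoE
# with `experts` experts (4 by default).
moe_xattn = GatedMoECrossAttn(512, experts=4)
moe_out = moe_xattn(text, img)
print(moe_out.shape)  # expected: torch.Size([1, 10, 512])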