diff --git a/docs/zeta/quant/bitlinear.md b/docs/zeta/quant/bitlinear.md
index 482f74b9..02c3daa4 100644
--- a/docs/zeta/quant/bitlinear.md
+++ b/docs/zeta/quant/bitlinear.md
@@ -61,7 +61,7 @@ Performs the forward pass of the `BitLinear` module.
 ```python
 import torch
 
-from zeta.quant import BitLinear
+from zeta.nn.quant import BitLinear
 
 # Initialize the BitLinear module
 linear = BitLinear(10, 20)
@@ -82,7 +82,7 @@ print(output.size())  # torch.Size([128, 20])
 ```python
 import torch
 
-from zeta.quant import BitLinear
+from zeta.nn.quant import BitLinear
 
 # Initialize the BitLinear module with 2 groups
 linear = BitLinear(10, 20, groups=2)
@@ -103,7 +103,7 @@ print(output.size())  # torch.Size([128, 20])
 import torch
 from torch import nn
 
-from zeta.quant import BitLinear
+from zeta.nn.quant import BitLinear
 
 
 class MyModel(nn.Module):
diff --git a/docs/zeta/quant/qlora.md b/docs/zeta/quant/qlora.md
index 087bed04..417c3930 100644
--- a/docs/zeta/quant/qlora.md
+++ b/docs/zeta/quant/qlora.md
@@ -54,7 +54,7 @@ To instantiate a QloraLinear layer:
 ```python
 import torch.nn as nn
 
-from zeta.quant.qlora import QloraLinear
+from zeta.nn.quant.qlora import QloraLinear
 
 in_features = 20
 out_features = 30
diff --git a/docs/zeta/quant/quik.md b/docs/zeta/quant/quik.md
index 16c898bc..e7cf8593 100644
--- a/docs/zeta/quant/quik.md
+++ b/docs/zeta/quant/quik.md
@@ -124,7 +124,7 @@ In this example, we'll initialize the QUIK layer.
 ```python
 import torch
 
-from zeta.quant import QUIK
+from zeta.nn.quant import QUIK
 
 # Initialize the QUIK module
 quik = QUIK(in_features=784, out_features=10)
diff --git a/pyproject.toml b/pyproject.toml
index 7cc1063d..d7a58b36 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,20 +1,49 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.5.9"
-description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
+version = "2.6.1"
+description = "Rapidly Build, Optimize, and Train SOTA AI Models"
 authors = ["Zeta Team "]
 license = "MIT"
 readme = "README.md"
 homepage = "https://github.com/kyegomez/zeta"
-keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
+keywords = [
+    "artificial intelligence",
+    "deep learning",
+    "optimizers",
+    "Prompt Engineering",
+    "swarms",
+    "agents",
+    "llms",
+    "transformers",
+    "multi-agent",
+    "swarms of agents",
+    "Enterprise-Grade Agents",
+    "Production-Grade Agents",
+    "Agents",
+    "Multi-Grade-Agents",
+    "Swarms",
+    "Transformers",
+    "LLMs",
+    "Generative Agents",
+    "Generative AI",
+    "Agent Marketplace",
+    "Agent Store",
+    "LSTMS",
+    "GRUs",
+    "RNNs",
+    "CNNs",
+    "MLPs",
+    "DNNs",
+]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Developers",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "License :: OSI Approved :: MIT License",
-    "Programming Language :: Python :: 3.9"
+    "Programming Language :: Python :: 3.10",
 ]
+
 packages = [
     { include = "zeta" },
     { include = "zeta/**/*.py" },
 ]
@@ -65,7 +94,7 @@ target-version = ['py38']
 preview = true
 
 
-[tool.poetry.scripts]
-zeta = 'zeta.cli.main:main'
+# [tool.poetry.scripts]
+# zeta = 'zeta.cli.main:main'
 
 
diff --git a/tests/quant/test_bitlinear.py b/tests/quant/test_bitlinear.py
index c64c8602..00518a2d 100644
--- a/tests/quant/test_bitlinear.py
+++ b/tests/quant/test_bitlinear.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from zeta.quant.bitlinear import BitLinear,
absmax_quantize +from zeta.nn.quant.bitlinear import BitLinear, absmax_quantize def test_bitlinear_reset_parameters(): diff --git a/tests/quant/test_half_bit_linear.py b/tests/quant/test_half_bit_linear.py index 403bf567..3ca3b9c7 100644 --- a/tests/quant/test_half_bit_linear.py +++ b/tests/quant/test_half_bit_linear.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from zeta.quant.half_bit_linear import HalfBitLinear +from zeta.nn.quant.half_bit_linear import HalfBitLinear def test_half_bit_linear_init(): diff --git a/tests/quant/test_lfq.py b/tests/quant/test_lfq.py index af31c9fd..99fd4bd2 100644 --- a/tests/quant/test_lfq.py +++ b/tests/quant/test_lfq.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from zeta.quant.lfq import LFQ +from zeta.nn.quant.lfq import LFQ def test_lfg_init(): diff --git a/tests/quant/test_niva.py b/tests/quant/test_niva.py index 71bee69a..36882fe8 100644 --- a/tests/quant/test_niva.py +++ b/tests/quant/test_niva.py @@ -5,7 +5,7 @@ import torch.nn as nn from zeta.nn import QFTSPEmbedding -from zeta.quant.niva import niva +from zeta.nn.quant.niva import niva def test_niva_model_type(): diff --git a/tests/quant/test_qlora.py b/tests/quant/test_qlora.py index e6a8bdf7..bd90b4ee 100644 --- a/tests/quant/test_qlora.py +++ b/tests/quant/test_qlora.py @@ -2,7 +2,7 @@ import torch from torch.testing import assert_allclose -from zeta.quant.qlora import QloraLinear +from zeta.nn.quant.qlora import QloraLinear # Sample instantiation values in_features = 20 diff --git a/tests/quant/test_quik.py b/tests/quant/test_quik.py index 8784127b..7c8142eb 100644 --- a/tests/quant/test_quik.py +++ b/tests/quant/test_quik.py @@ -1,6 +1,6 @@ import torch -from zeta.quant.quick import QUIK +from zeta.nn.quant.quick import QUIK def test_quik_initialization(): diff --git a/tests/quant/test_resudual_vq.py b/tests/quant/test_resudual_vq.py index f46cff0f..c2fe98b7 100644 --- a/tests/quant/test_resudual_vq.py +++ b/tests/quant/test_resudual_vq.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from zeta.quant.residual_vq import ResidualVectorQuantizer +from zeta.nn.quant.residual_vq import ResidualVectorQuantizer def test_residual_vector_quantizer_init(): diff --git a/tests/utils/test_absmax.py b/tests/utils/test_absmax.py index b40adef7..2690c547 100644 --- a/tests/utils/test_absmax.py +++ b/tests/utils/test_absmax.py @@ -1,6 +1,6 @@ import torch -from zeta.quant.absmax import absmax_quantize +from zeta.nn.quant.absmax import absmax_quantize def test_absmax_quantize_default_bits(): diff --git a/todo/dit_block.py b/todo/dit_block.py new file mode 100644 index 00000000..2bb31109 --- /dev/null +++ b/todo/dit_block.py @@ -0,0 +1,107 @@ +import torch +from torch import nn, Tensor +from zeta.nn.attention.multiquery_attention import MultiQueryAttention +from zeta.nn.modules.feedforward import FeedForward +from zeta.nn.modules.scale import Scale +from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm + + +def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class AdaLN(nn.Module): + """ + Adaptive Layer Normalization (AdaLN) module. + + Args: + dim (int): The input dimension. + eps (float): A small value added to the denominator for numerical stability. + scale (int): The scale factor for the linear layer. + bias (bool): Whether to include a bias term in the linear layer. 
+ """ + + def __init__( + self, + dim: int = None, + eps: float = 1e-5, + scale: int = 4, + bias: bool = True, + ): + super().__init__() + self.eps = eps + self.scale = scale + self.bias = bias + + self.norm = nn.Sequential( + nn.SiLU(), + nn.Linear(dim, dim * scale, bias=bias), + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the AdaLN module. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: The normalized output tensor. + """ + return self.norm(x) + + +class DitBlock(nn.Module): + def __init__( + self, + dim: int, + dim_head: int = None, + dropout: float = 0.1, + heads: int = 8, + ): + super().__init__() + self.dim = dim + self.dim_head = dim_head + self.dropout = dropout + self.heads = heads + + # Attention + self.attn = MultiQueryAttention( + dim, + heads, + ) + + # FFN + self.input_ffn = FeedForward(dim, dim, 4, swish=True) + + # Conditioning mlp + self.conditioning_mlp = FeedForward(dim, dim, 4, swish=True) + + # Shift + # self.shift_op = ShiftTokens() + + # Norm + self.norm = AdaptiveLayerNorm(dim) + + def forward(self, x: Tensor, conditioning: Tensor) -> Tensor: + + # Norm + self.norm(x) + + # Scale + # scaled = modulate( + # x, + # normalize, + # normalize + # ) + + # return scaled + scaled = Scale(fn=self.norm)(x) + return scaled + + +input = torch.randn(1, 10, 512) +conditioning = torch.randn(1, 10, 512) +dit_block = DitBlock(512) +output = dit_block(input, conditioning) +print(output.shape) diff --git a/todo/hyper_attention.py b/todo/hyper_attention.py new file mode 100644 index 00000000..dce7bddb --- /dev/null +++ b/todo/hyper_attention.py @@ -0,0 +1,173 @@ +from torch import nn, Tensor +from typing import Optional +from zeta import FeedForward +from zeta import MultiModalCrossAttention, MultiQueryAttention +from zeta.nn.embeddings.mi_rope import MIRoPE +import torch +from torch.nn import functional as F + + +def exists(val): + return val is not None + + +class AdaptiveGating(nn.Module): + def __init__(self, hidden_dim: int): + """ + Initializes an instance of the AdaptiveGating class. + + Args: + hidden_dim (int): The dimension of the hidden state. + + """ + super().__init__() + self.hidden_dim = hidden_dim + self.sigmoid = nn.Sigmoid() + + def forward( + self, + hat_text: Tensor, + bar_text: Tensor, + ) -> Tensor: + """ + Performs the forward pass of the AdaptiveGating module. + + Args: + hat_text (Tensor): The input tensor representing the hat text. + bar_text (Tensor): The input tensor representing the bar text. + + Returns: + Tensor: The fused hidden state tensor. + + """ + g = self.sigmoid(hat_text) + + # Step 2 + h_fused = bar_text * g + hat_text * (1 - g) + + return h_fused + + +class HyperAttentionmPLUGOwlBlock(nn.Module): + """ + HyperAttentionmPLUGOwlBlock is a module that performs hyper attention between image and text inputs. + + Args: + dim (int): The dimension of the input. + heads (int): The number of attention heads. + dim_head (int, optional): The dimension of each attention head. Defaults to 64. + mi_rope_on (bool, optional): Whether to use mutual information rope. Defaults to True. + max_seq_len (int, optional): The maximum sequence length. Defaults to 100. 
+ """ + + def __init__( + self, + dim: int, + heads: int, + dim_head: int = 64, + mi_rope_on: bool = True, + max_seq_len: int = 100, + ): + super().__init__() + self.dim = dim + self.heads = heads + self.dim_head = dim_head + self.mi_rope_on = mi_rope_on + self.max_seq_len = max_seq_len + + self.norm = nn.LayerNorm(dim) + self.ffn = FeedForward( + dim, + dim, + 4, + swiglu=True, + ) + + # Projections + self.w_kv_img = nn.Linear(dim, dim * 2) + self.w_q_text = nn.Linear(dim, dim) + + # Attention + self.attn = MultiModalCrossAttention( + dim, + heads, + context_dim=dim, + qk=True, + post_attn_norm=True, + ) + + self.attn_op = MultiQueryAttention( + dim, + heads, + ) + + # Rotary Position Embedding + self.rotary_embeddings = MIRoPE(dim) + self.proj = nn.Linear(dim, dim) + self.text_proj = nn.Linear(dim, dim) + self.final_proj = nn.Linear(dim, dim) + self.gate = AdaptiveGating(dim) + + def forward(self, img: Tensor, text: Tensor, mask: Optional[Tensor] = None): + """ + Forward pass of the HyperAttentionmPLUGOwlBlock module. + + Args: + img (Tensor): The input image tensor. + text (Tensor): The input text tensor. + mask (Optional[Tensor], optional): The attention mask tensor. Defaults to None. + + Returns: + Tensor: The output tensor. + """ + n = img.shape[1] + img = self.norm(img) + text = self.norm(text) + + # Rotary Position Embedding + # positions, scale = self.get_rotary_embedding(n, img.device) + + # Apply rotary position embedding + + w_img_k = self.proj(img) + w_img_v = self.proj(img) + + w_q_text = self.text_proj(text) + w_k_text = self.text_proj(w_q_text) + w_v_text = self.text_proj(w_q_text) + + # Attn op + with torch.backends.cuda.sdp_kernel(enable_math=True): + img_attn = F.scaled_dot_product_attention( + w_q_text, + w_img_k, + w_img_v, + ) + + with torch.backends.cuda.sdp_kernel(enable_math=True): + text_attn = F.scaled_dot_product_attention( + w_q_text, + w_k_text, + w_v_text, + ) + + output_gate = self.gate(img_attn, text_attn) + + return self.final_proj(output_gate) + + def get_rotary_embedding(self, n, device): + if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n: + return self.pos_emb[:n], self.pos_emb_scale[:n] + + pos_emb, scale = self.rotary_emb(n, device=device) + self.register_buffer("pos_emb", pos_emb, persistent=False) + self.register_buffer("pos_emb_scale", scale, persistent=False) + return pos_emb, scale + + +input = torch.randn(1, 10, 512) + +conditioning = torch.randn(1, 10, 512) +model = HyperAttentionmPLUGOwlBlock(512, 8) +output = model(input, conditioning) +print(output.shape) diff --git a/multi_head_latent_attention.py b/todo/multi_head_latent_attention.py similarity index 100% rename from multi_head_latent_attention.py rename to todo/multi_head_latent_attention.py diff --git a/zeta/__init__.py b/zeta/__init__.py index dc752fd4..ffe526ec 100644 --- a/zeta/__init__.py +++ b/zeta/__init__.py @@ -7,7 +7,7 @@ from zeta.nn import * # noqa: F403, E402 from zeta.ops import * # noqa: F403, E402 from zeta.optim import * # noqa: F403, E402 -from zeta.quant import * # noqa: F403, E402 +from zeta.nn.quant import * # noqa: F403, E402 from zeta.rl import * # noqa: F403, E402 from zeta.training import * # noqa: F403, E402 from zeta.utils import * # noqa: F403, E402 diff --git a/zeta/cli/__init__.py b/zeta/cli/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/zeta/cli/main.py b/zeta/cli/main.py deleted file mode 100644 index f10f4bd1..00000000 --- a/zeta/cli/main.py +++ /dev/null @@ -1,67 +0,0 @@ -import argparse - -from zeta.cloud.main import 
zetacloud - - -def main(): - """Main function for the CLI - - Args: - task_name (str, optional): _description_. Defaults to None. - cluster_name (str, optional): _description_. Defaults to "[ZetaTrainingRun]". - cloud (Any, optional): _description_. Defaults to AWS(). - gpus (str, optional): _description_. Defaults to None. - - Examples: - $ zetacloud -t "test" -c "[ZetaTrainingRun]" -cl AWS -g "1 V100" - - - """ - parser = argparse.ArgumentParser(description="Zetacloud CLI") - parser.add_argument("-t", "--task_name", type=str, help="Task name") - parser.add_argument( - "-c", - "--cluster_name", - type=str, - default="[ZetaTrainingRun]", - help="Cluster name", - ) - parser.add_argument( - "-cl", "--cloud", type=str, default="AWS", help="Cloud provider" - ) - parser.add_argument("-g", "--gpus", type=str, help="GPUs") - parser.add_argument( - "-f", "--filename", type=str, default="train.py", help="Filename" - ) - parser.add_argument("-s", "--stop", action="store_true", help="Stop flag") - parser.add_argument("-d", "--down", action="store_true", help="Down flag") - parser.add_argument( - "-sr", "--status_report", action="store_true", help="Status report flag" - ) - - # Generate API key - # parser.add_argument( - # "-k", "--generate_api_key", action="store_true", help="Generate key flag" - # ) - - # Sign In - # parser.add_argument( - # "-si", "--sign_in", action="store_true", help="Sign in flag" - # ) - - args = parser.parse_args() - - zetacloud( - task_name=args.task_name, - cluster_name=args.cluster_name, - cloud=args.cloud, - gpus=args.gpus, - filename=args.filename, - stop=args.stop, - down=args.down, - status_report=args.status_report, - ) - - -# if __name__ == "__main__": -# main() diff --git a/zeta/cloud/__init__.py b/zeta/cloud/__init__.py deleted file mode 100644 index fbdf0635..00000000 --- a/zeta/cloud/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""init file for cloud module""" - -from zeta.cloud.main import zetacloud -from zeta.cloud.sky_api import SkyInterface - -__all__ = ["zetacloud", "SkyInterface"] diff --git a/zeta/cloud/main.py b/zeta/cloud/main.py deleted file mode 100644 index f2c223d2..00000000 --- a/zeta/cloud/main.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Cloud""" - -import logging -from typing import Any - -from sky import AWS, Resources - -from zeta.cloud.sky_api import SkyInterface - -skyapi = SkyInterface(stream_logs_enabled=True) - - -# Logger -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -def zetacloud( - task_name: str = None, - cluster_name: str = "ZetaTrainingRun", - setup: str = "pip install -r requirements.txt", - cloud: Any = AWS(), - gpus: str = "V100:4", - filename: str = "train.py", - stop: bool = False, - down: bool = False, - status_report: bool = False, - *args, - **kwargs, -): - """zetacloud - - Args: - task_name (str, optional): _description_. Defaults to None. - cluster_name (str, optional): _description_. Defaults to "[ZetaTrainingRun]". - cloud (Any, optional): _description_. Defaults to AWS(). - gpus (str, optional): _description_. Defaults to None. 
- """ - try: - task = skyapi.create_task( - name=task_name, - setup=setup, - run=f"python {filename}", - workdir=".", - ) - logger.info(f"Task: {task} has been created") - - # Set the resources - task.set_resources(Resources(accelerators=gpus)) - # logger.info(f"Resources: {task.resources} have been set") - - # Execute the task on the cluster - execution = skyapi.launch(task, cluster_name) - print(execution) - logger.info( - f"Task: {task} has been launched on cluster: {cluster_name}" - ) - - if stop: - skyapi.stop(cluster_name) - logger.info(f"Cluster: {cluster_name} has been stopped") - - if down: - skyapi.down(cluster_name) - logger.info(f"Cluster: {cluster_name} has been deleted") - - if status_report: - skyapi.status(cluster_names=[cluster_name]) - logger.info(f"Cluster: {cluster_name} has been reported on") - - except Exception as error: - print( - f"There has been an error: {error} the root cause is:" - f" {error.__cause__}" - ) diff --git a/zeta/cloud/sky_api.py b/zeta/cloud/sky_api.py deleted file mode 100644 index b5e71ae1..00000000 --- a/zeta/cloud/sky_api.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import List - -import sky -from sky import Task - - -class SkyInterface: - """ - - SkyInterface is a wrapper around the sky Python API. It provides a - simplified interface for launching, executing, stopping, starting, and - tearing down clusters. - - Attributes: - clusters (dict): A dictionary of clusters that have been launched. - The keys are the names of the clusters and the values are the handles - to the clusters. - - Methods: - launch: Launch a cluster - execute: Execute a task on a cluster - stop: Stop a cluster - start: Start a cluster - down: Tear down a cluster - status: Get the status of a cluster - autostop: Set the autostop of a cluster - - Example: - >>> sky_interface = SkyInterface() - >>> job_id = sky_interface.launch("task", "cluster_name") - >>> sky_interface.execute("task", "cluster_name") - >>> sky_interface.stop("cluster_name") - >>> sky_interface.start("cluster_name") - >>> sky_interface.down("cluster_name") - >>> sky_interface.status() - >>> sky_interface.autostop("cluster_name") - - - """ - - def __init__( - self, - task_name: str = None, - cluster_name: str = None, - gpus: str = "T4:1", - stream_logs_enabled: bool = False, - *args, - **kwargs, - ): - self.task_name = task_name - self.cluster_name = cluster_name - self.gpus = gpus - self.stream_logs_enabled = stream_logs_enabled - self.clusters = {} - - def launch(self, task: Task = None, cluster_name: str = None, **kwargs): - """Launch a task on a cluster - - Args: - task (str): code to execute on the cluster - cluster_name (_type_, optional): _description_. Defaults to None. 
- - Returns: - _type_: _description_ - """ - cluster = None - try: - cluster = sky.launch( - task=task, - cluster_name=cluster_name, - stream_logs=self.stream_logs_enabled, - **kwargs, - ) - print(f"Launched job {cluster} on cluster {cluster_name}") - return cluster - except Exception as error: - # Deep error logging - print( - f"Error launching job {cluster} on cluster {cluster_name} with" - f" error {error}" - ) - raise error - - def execute(self, task: Task = None, cluster_name: str = None, **kwargs): - """Execute a task on a cluster - - Args: - task (_type_): _description_ - cluster_name (_type_): _description_ - - Raises: - ValueError: _description_ - - Returns: - _type_: _description_ - """ - if cluster_name not in self.clusters: - raise ValueError(f"Cluster {cluster_name} does not exist") - try: - return sky.exec( - task=task, - cluster_name=cluster_name, - stream_logs=self.stream_logs_enabled, - **kwargs, - ) - except Exception as e: - print("Error executing on cluster:", e) - - def stop(self, cluster_name: str = None, **kwargs): - """Stop a cluster - - Args: - cluster_name (str): name of the cluster to stop - """ - try: - sky.stop(cluster_name, **kwargs) - except (ValueError, RuntimeError) as e: - print("Error stopping cluster:", e) - - def start(self, cluster_name: str = None, **kwargs): - """start a cluster - - Args: - cluster_name (str): name of the cluster to start - """ - try: - sky.start(cluster_name, **kwargs) - except Exception as e: - print("Error starting cluster:", e) - - def down(self, cluster_name: str = None, **kwargs): - """Down a cluster - - Args: - cluster_name (str): name of the cluster to tear down - """ - try: - sky.down(cluster_name, **kwargs) - if cluster_name in self.clusters: - del self.clusters[cluster_name] - except (ValueError, RuntimeError) as e: - print("Error tearing down cluster:", e) - - def status(self, cluster_names: List[str] = None, **kwargs): - """Save a cluster - - Returns: - r: the status of the cluster - """ - try: - return sky.status(cluster_names, **kwargs) - except Exception as e: - print("Error getting status:", e) - - def autostop(self, cluster_name: str = None, **kwargs): - """Autostop a cluster - - Args: - cluster_name (str): name of the cluster to autostop - """ - try: - sky.autostop(cluster_name, **kwargs) - except Exception as e: - print("Error setting autostop:", e) - - def create_task( - self, - name: str = None, - setup: str = None, - run: str = None, - workdir: str = None, - task: str = None, - *args, - **kwargs, - ): - """_summary_ - - Args: - name (str, optional): _description_. Defaults to None. - setup (str, optional): _description_. Defaults to None. - run (str, optional): _description_. Defaults to None. - workdir (str, optional): _description_. Defaults to None. - task (str, optional): _description_. Defaults to None. - - Returns: - _type_: _description_ - - # A Task that will sync up local workdir '.', containing - # requirements.txt and train.py. - sky.Task(setup='pip install requirements.txt', - run='python train.py', - workdir='.') - - # An empty Task for provisioning a cluster. - task = sky.Task(num_nodes=n).set_resources(...) - - # Chaining setters. - sky.Task().set_resources(...).set_file_mounts(...) 
- """ - return Task( - name=name, setup=setup, run=run, workdir=workdir, *args, **kwargs - ) diff --git a/zeta/nn/embeddings/__init__.py b/zeta/nn/embeddings/__init__.py index 2f754087..432a9a2c 100644 --- a/zeta/nn/embeddings/__init__.py +++ b/zeta/nn/embeddings/__init__.py @@ -30,6 +30,7 @@ from zeta.nn.embeddings.scaled_sinusoidal_embeddings import ( ScaledSinusoidalEmbedding, ) +from zeta.nn.embeddings.mi_rope import MIRoPE __all__ = [ @@ -61,4 +62,5 @@ "duplicate_interleave", "VisionEmbedding", "ScaledSinusoidalEmbedding", + "MIRoPE", ] diff --git a/zeta/nn/embeddings/mi_rope.py b/zeta/nn/embeddings/mi_rope.py new file mode 100644 index 00000000..2cbd293f --- /dev/null +++ b/zeta/nn/embeddings/mi_rope.py @@ -0,0 +1,100 @@ +import torch +import torch.nn as nn +from typing import List + + +class MIRoPE(nn.Module): + def __init__(self, dim: int): + """ + Initializes the MI-RoPE module. + + Args: + dim (int): The dimension of the model's hidden states. + """ + super(MIRoPE, self).__init__() + self.dim = dim + + def forward( + self, + visual_features: List[torch.Tensor], + sequence_positions: List[int], + max_seq_len: int, + ) -> List[torch.Tensor]: + """ + Applies the Multimodal-Interleaved Rotary Position Embedding to visual features. + + Args: + visual_features (List[torch.Tensor]): A list of tensors containing the visual features for each image. + sequence_positions (List[int]): The positions of the images in the interleaved sequence. + max_seq_len (int): The maximum sequence length for the rotary position embedding. + + Returns: + List[torch.Tensor]: The visual features with applied rotary position embeddings. + """ + assert len(visual_features) == len( + sequence_positions + ), "Each image must have a corresponding position." + + # Generate the rotary position embedding + position_ids = torch.arange( + 0, max_seq_len, dtype=torch.float + ).unsqueeze(1) + half_dim = self.dim // 2 + + # Correct calculation for the embeddings + emb = torch.exp( + torch.arange(0, half_dim, dtype=torch.float) + * -(torch.log(torch.tensor(10000.0)) / half_dim) + ) + sin_emb = torch.sin(position_ids * emb) + cos_emb = torch.cos(position_ids * emb) + + # Concatenate sin and cos embeddings properly + rotary_emb = torch.cat( + [sin_emb, cos_emb], dim=1 + ) # This should have shape [max_seq_len, dim] + + embedded_visuals = [] + for i, visual in enumerate(visual_features): + position = sequence_positions[i] + # Apply the rotary position embedding based on the sequence position of the image + visual = self.apply_rotary_embedding(visual, rotary_emb[position]) + embedded_visuals.append(visual) + + return embedded_visuals + + def apply_rotary_embedding( + self, visual: torch.Tensor, rotary_embedding: torch.Tensor + ) -> torch.Tensor: + """ + Applies the rotary position embedding to a visual feature. + + Args: + visual (torch.Tensor): The visual feature tensor of shape (num_patches, dim). + rotary_embedding (torch.Tensor): The rotary embedding corresponding to the position in the sequence. + + Returns: + torch.Tensor: The visual feature tensor with rotary position embedding applied. 
+ """ + return (visual * rotary_embedding.cos()) + ( + torch.roll(visual, shifts=1, dims=-1) * rotary_embedding.sin() + ) + + +# # Assuming batch size of 1 for simplicity, you can generalize this as needed +# batch_size = 1 +# dim = 512 +# max_seq_len = 100 + +# # Example inputs +# visual_features = [ +# torch.rand(batch_size, 10, dim), +# torch.rand(batch_size, 10, dim), +# ] # 10 patches per image +# sequence_positions = [5, 10] # Image positions in the interleaved sequence + +# # Initialize modules +# mi_rope = MIRoPE(dim=dim) +# # Apply MI-RoPE to each image's visual features +# embedded_visuals = mi_rope(visual_features, sequence_positions, max_seq_len) +# print(embedded_visuals[0].shape) # Output shape: torch.Size([1, 10, 512]) diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index 727afdd8..16a591fe 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -223,6 +223,7 @@ from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention from zeta.nn.modules.snake_act import Snake +from zeta.nn.modules.adaptive_gating import AdaptiveGating # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features @@ -449,4 +450,5 @@ "GatedMoECrossAttn", "GatedXAttention", "Snake", + "AdaptiveGating", ] diff --git a/zeta/nn/modules/adaptive_gating.py b/zeta/nn/modules/adaptive_gating.py new file mode 100644 index 00000000..267aa9af --- /dev/null +++ b/zeta/nn/modules/adaptive_gating.py @@ -0,0 +1,42 @@ +from torch import nn, Tensor + + +def exists(val): + return val is not None + + +class AdaptiveGating(nn.Module): + def __init__(self, hidden_dim: int): + """ + Initializes an instance of the AdaptiveGating class. + + Args: + hidden_dim (int): The dimension of the hidden state. + + """ + super().__init__() + self.hidden_dim = hidden_dim + self.sigmoid = nn.Sigmoid() + + def forward( + self, + hat_text: Tensor, + bar_text: Tensor, + ) -> Tensor: + """ + Performs the forward pass of the AdaptiveGating module. + + Args: + hat_text (Tensor): The input tensor representing the hat text. + bar_text (Tensor): The input tensor representing the bar text. + + Returns: + Tensor: The fused hidden state tensor. 
+
+        """
+        g = self.sigmoid(hat_text)
+
+        # Step 2
+        h_fused = bar_text * g + hat_text * (1 - g)
+
+        return h_fused
diff --git a/zeta/nn/modules/quantized_layernorm.py b/zeta/nn/modules/quantized_layernorm.py
index adfe1aed..7e3fb235 100644
--- a/zeta/nn/modules/quantized_layernorm.py
+++ b/zeta/nn/modules/quantized_layernorm.py
@@ -1,6 +1,6 @@
 from torch import Tensor, nn
 
-from zeta.quant.bitlinear import absmax_quantize
+from zeta.nn.quant.bitlinear import absmax_quantize
 
 
 class QuantizedLN(nn.Module):
diff --git a/zeta/nn/modules/scale.py b/zeta/nn/modules/scale.py
index 443ab49a..fc2c6f35 100644
--- a/zeta/nn/modules/scale.py
+++ b/zeta/nn/modules/scale.py
@@ -23,7 +23,7 @@ class Scale(nn.Module):
 
     """
 
-    def __init__(self, value, fn):
+    def __init__(self, value: float = 0.5, fn: callable = None):
         super().__init__()
         self.value = value
         self.fn = fn
diff --git a/zeta/nn/modules/shift_tokens.py b/zeta/nn/modules/shift_tokens.py
index 0293be87..ce5d71f4 100644
--- a/zeta/nn/modules/shift_tokens.py
+++ b/zeta/nn/modules/shift_tokens.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch import Tensor
 
 
 def pad_at_dim(t, pad, dim=-1, value=0.0):
@@ -47,12 +48,12 @@ class ShiftTokens(nn.Module):
 
     """
 
-    def __init__(self, shifts, fn):
+    def __init__(self, shifts: list = None, fn: callable = None):
         super().__init__()
         self.fn = fn
         self.shifts = tuple(shifts)
 
-    def forward(self, x, **kwargs):
+    def forward(self, x: Tensor, **kwargs):
         """Forward method of ShiftTokens"""
         mask = kwargs.get("mask", None)
         shifts = self.shifts
diff --git a/zeta/nn/quant/__init__.py b/zeta/nn/quant/__init__.py
new file mode 100644
index 00000000..2069ebb2
--- /dev/null
+++ b/zeta/nn/quant/__init__.py
@@ -0,0 +1,19 @@
+from zeta.nn.quant.absmax import absmax_quantize
+from zeta.nn.quant.bitlinear import BitLinear
+from zeta.nn.quant.half_bit_linear import HalfBitLinear
+from zeta.nn.quant.lfq import LFQ
+from zeta.nn.quant.niva import niva
+from zeta.nn.quant.qlora import QloraLinear
+from zeta.nn.quant.quick import QUIK
+from zeta.nn.quant.ste import STE
+
+__all__ = [
+    "QUIK",
+    "absmax_quantize",
+    "BitLinear",
+    "STE",
+    "QloraLinear",
+    "niva",
+    "HalfBitLinear",
+    "LFQ",
+]
diff --git a/zeta/quant/absmax.py b/zeta/nn/quant/absmax.py
similarity index 100%
rename from zeta/quant/absmax.py
rename to zeta/nn/quant/absmax.py
diff --git a/zeta/quant/bitlinear.py b/zeta/nn/quant/bitlinear.py
similarity index 97%
rename from zeta/quant/bitlinear.py
rename to zeta/nn/quant/bitlinear.py
index 66ba7f8e..24fea175 100644
--- a/zeta/quant/bitlinear.py
+++ b/zeta/nn/quant/bitlinear.py
@@ -35,7 +35,7 @@ class BitLinear(nn.Module):
 
     Usage:
     >>> import torch
-    >>> from zeta.quant.bitlinear import BitLinear
+    >>> from zeta.nn.quant.bitlinear import BitLinear
     >>> linear = BitLinear(10, 20)
     >>> input = torch.randn(128, 10)
     >>> output = linear(input)
diff --git a/zeta/quant/half_bit_linear.py b/zeta/nn/quant/half_bit_linear.py
similarity index 100%
rename from zeta/quant/half_bit_linear.py
rename to zeta/nn/quant/half_bit_linear.py
diff --git a/zeta/quant/lfq.py b/zeta/nn/quant/lfq.py
similarity index 100%
rename from zeta/quant/lfq.py
rename to zeta/nn/quant/lfq.py
diff --git a/zeta/quant/niva.py b/zeta/nn/quant/niva.py
similarity index 98%
rename from zeta/quant/niva.py
rename to zeta/nn/quant/niva.py
index 9f9dce0e..210f1840 100644
--- a/zeta/quant/niva.py
+++ b/zeta/nn/quant/niva.py
@@ -26,7 +26,7 @@ def niva(
 
     Examples:
     >>> import torch
-    >>> from zeta.quant import niva
+    >>> from zeta.nn.quant import niva
     >>> from zeta.nn import QFTSPEmbedding
     >>> model = QFTSPEmbedding(100, 100)
     >>> niva(
diff --git a/zeta/quant/qlora.py b/zeta/nn/quant/qlora.py
similarity index 99%
rename from zeta/quant/qlora.py
rename to zeta/nn/quant/qlora.py
index 203160c6..8d810ba0 100644
--- a/zeta/quant/qlora.py
+++ b/zeta/nn/quant/qlora.py
@@ -635,7 +635,7 @@ class QloraLinear(nn.Module):
 
     Example:
     import torch
-    from zeta.quant.qlora import QloraLinear
+    from zeta.nn.quant.qlora import QloraLinear
 
     # Convert the weight tensor to torch.bfloat16
     weight_bfloat16 = torch.rand(4096, 4096).to(torch.bfloat16)
diff --git a/zeta/quant/qmoe.py b/zeta/nn/quant/qmoe.py
similarity index 100%
rename from zeta/quant/qmoe.py
rename to zeta/nn/quant/qmoe.py
diff --git a/zeta/quant/quick.py b/zeta/nn/quant/quick.py
similarity index 100%
rename from zeta/quant/quick.py
rename to zeta/nn/quant/quick.py
diff --git a/zeta/quant/random_proj_quan.py b/zeta/nn/quant/random_proj_quan.py
similarity index 100%
rename from zeta/quant/random_proj_quan.py
rename to zeta/nn/quant/random_proj_quan.py
diff --git a/zeta/quant/residual_vq.py b/zeta/nn/quant/residual_vq.py
similarity index 100%
rename from zeta/quant/residual_vq.py
rename to zeta/nn/quant/residual_vq.py
diff --git a/zeta/quant/ste.py b/zeta/nn/quant/ste.py
similarity index 100%
rename from zeta/quant/ste.py
rename to zeta/nn/quant/ste.py
diff --git a/zeta/quant/__init__.py b/zeta/quant/__init__.py
deleted file mode 100644
index 7dbcc5aa..00000000
--- a/zeta/quant/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from zeta.quant.absmax import absmax_quantize
-from zeta.quant.bitlinear import BitLinear
-from zeta.quant.half_bit_linear import HalfBitLinear
-from zeta.quant.lfq import LFQ
-from zeta.quant.niva import niva
-from zeta.quant.qlora import QloraLinear
-from zeta.quant.quick import QUIK
-from zeta.quant.ste import STE
-
-__all__ = [
-    "QUIK",
-    "absmax_quantize",
-    "BitLinear",
-    "STE",
-    "QloraLinear",
-    "niva",
-    "HalfBitLinear",
-    "LFQ",
-]
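
Note on backwards compatibility: this patch relocates `zeta.quant` to `zeta.nn.quant` and deletes the old package outright, so third-party code that still does `from zeta.quant import BitLinear` will fail with `ModuleNotFoundError` as of 2.6.1. If a deprecation window is wanted, a shim could be restored at `zeta/quant/__init__.py`. The sketch below is illustrative and not part of this patch; the re-exported names are taken from the new `zeta/nn/quant/__init__.py`, everything else is an assumption.

```python
# zeta/quant/__init__.py -- hypothetical compatibility shim, not in this patch.
# Re-exports the public quant API from its new home and warns on import.
import warnings

warnings.warn(
    "zeta.quant has moved to zeta.nn.quant; "
    "update your imports, this alias may be removed in a future release.",
    DeprecationWarning,
    stacklevel=2,
)

from zeta.nn.quant import (  # noqa: E402
    LFQ,
    QUIK,
    STE,
    BitLinear,
    HalfBitLinear,
    QloraLinear,
    absmax_quantize,
    niva,
)

__all__ = [
    "QUIK",
    "absmax_quantize",
    "BitLinear",
    "STE",
    "QloraLinear",
    "niva",
    "HalfBitLinear",
    "LFQ",
]
```

Downstream users who prefer to migrate immediately only need to rewrite the import path, e.g. `from zeta.nn.quant import BitLinear` in place of `from zeta.quant import BitLinear`, exactly as the docs and tests in this patch already do.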