diff --git a/.gitignore b/.gitignore
index 267e47d8..f3d7b117 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,7 @@ data
 # Distribution / packaging
 .Python
 .DS_Store
+agent_workspace
 build/
 .ruff_cache
 .vscode
diff --git a/audio_encoder.py b/audio_encoder.py
new file mode 100644
index 00000000..51d9cb5d
--- /dev/null
+++ b/audio_encoder.py
@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+from torchaudio.transforms import MelSpectrogram
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class AudioEncoder(nn.Module):
+    """
+    Audio encoder that processes audio input through a mel filter bank, CNN
+    downsampling layers, and a Transformer encoder. A two-layer MLP then
+    projects each encoded position to an output_dim-dimensional token.
+
+    Args:
+        n_mels (int): Number of mel frequency bins. Default is 128.
+        cnn_channels (int): Number of channels in the CNN layers. Default is 64.
+        transformer_layers (int): Number of layers in the Transformer. Default is 24.
+        nhead (int): Number of heads in the multi-head attention layers. Default is 8.
+        dim_feedforward (int): Dimension of the feedforward network in nn.TransformerEncoder. Default is 1024.
+        audio_length (int): Length of the input audio in seconds. Default is 2.
+        mlp_hidden_dim (int): Dimension of the hidden layer in the MLP. Default is 256.
+        output_dim (int): Dimension of the output tokens. Default is 25.
+    """
+
+    def __init__(
+        self,
+        n_mels: int = 128,
+        cnn_channels: int = 64,
+        transformer_layers: int = 24,
+        nhead: int = 8,
+        dim_feedforward: int = 1024,
+        audio_length: int = 2,
+        mlp_hidden_dim: int = 256,
+        output_dim: int = 25,
+    ):
+        super(AudioEncoder, self).__init__()
+
+        self.mel_spectrogram = MelSpectrogram(sample_rate=16000, n_mels=n_mels)
+
+        self.cnn = nn.Sequential(
+            nn.Conv2d(1, cnn_channels, kernel_size=3, stride=2, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(
+                cnn_channels,
+                cnn_channels * 2,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                cnn_channels * 2,
+                cnn_channels * 4,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                cnn_channels * 4,
+                cnn_channels * 8,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+            ),
+            nn.ReLU(),
+        )
+
+        transformer_encoder_layer = TransformerEncoderLayer(
+            d_model=cnn_channels * 8,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            batch_first=True,
+        )
+        self.transformer_encoder = TransformerEncoder(
+            transformer_encoder_layer, num_layers=transformer_layers
+        )
+
+        self.mlp = nn.Sequential(
+            nn.Linear(cnn_channels * 8, mlp_hidden_dim),
+            nn.ReLU(),
+            nn.Linear(mlp_hidden_dim, output_dim),
+        )
+
+    def forward(self, audio: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the AudioEncoder.
+
+        Args:
+            audio (torch.Tensor): Input audio tensor of shape (batch_size, num_samples).
+
+        Returns:
+            torch.Tensor: Encoded audio tensor of shape (batch_size, num_tokens, output_dim).
+        """
+        # Convert raw audio to a mel spectrogram and add a channel dimension
+        mel_spec = self.mel_spectrogram(audio).unsqueeze(1)
+
+        # Pass through CNN downsampling layers
+        cnn_out = self.cnn(mel_spec)
+
+        # Flatten CNN output into a (batch, sequence, channels) token sequence
+        batch_size, channels, height, width = cnn_out.size()
+        cnn_out = cnn_out.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)
+
+        # Pass through Transformer
+        transformer_out = self.transformer_encoder(cnn_out)
+
+        # Project each position down to output_dim with the MLP
+        output = self.mlp(transformer_out)
+
+        return output
+
+
+# Example usage:
+if __name__ == "__main__":
+    # Assume 2 seconds of audio with 16kHz sample rate
+    audio_input = torch.randn(8, 32000)  # batch_size = 8, num_samples = 32000
+
+    model = AudioEncoder()
+    output = model(audio_input)
+    print(output.shape)  # (batch_size, num_tokens, output_dim)
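For reviewers, a quick smoke test of the new encoder. This is a sketch, not part of the diff; the expected token count assumes the defaults above plus torchaudio's default STFT settings (n_fft=400, hop_length=200), so the numbers below are illustrative:

    import torch
    from audio_encoder import AudioEncoder

    encoder = AudioEncoder()
    clips = torch.randn(4, 32000)  # 4 two-second clips at 16 kHz
    tokens = encoder(clips)
    # 32000 samples -> 161 mel frames; four stride-2 convs shrink
    # (128 mels, 161 frames) to (8, 11), so num_tokens = 8 * 11 = 88.
    print(tokens.shape)  # torch.Size([4, 88, 25])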
+ """ + # Convert raw audio to Mel Spectrogram + mel_spec = self.mel_spectrogram(audio).unsqueeze( + 1 + ) # Add channel dimension + + # Pass through CNN layers + cnn_out = self.cnn(mel_spec) + + # Flatten CNN output for transformer + batch_size, channels, height, width = cnn_out.size() + cnn_out = cnn_out.permute(0, 2, 3, 1).reshape(batch_size, -1, channels) + + # Pass through Transformer + transformer_out = self.transformer_encoder(cnn_out) + + # Pass through MLP + output = self.mlp(transformer_out) + + return output + + +# Example usage: +if __name__ == "__main__": + # Assume 2 seconds of audio with 16kHz sample rate + audio_input = torch.randn(8, 32000) # batch_size = 8, num_samples = 32000 + + model = AudioEncoder() + output = model(audio_input) + print(output) # Should output (batch_size, num_tokens, output_dim) diff --git a/playground/models/cobra.py b/examples/models/cobra.py similarity index 100% rename from playground/models/cobra.py rename to examples/models/cobra.py diff --git a/playground/models/gpt4.py b/examples/models/gpt4.py similarity index 100% rename from playground/models/gpt4.py rename to examples/models/gpt4.py diff --git a/playground/models/gpt4_multimodal.py b/examples/models/gpt4_multimodal.py similarity index 100% rename from playground/models/gpt4_multimodal.py rename to examples/models/gpt4_multimodal.py diff --git a/playground/models/nirvana.py b/examples/models/nirvana.py similarity index 100% rename from playground/models/nirvana.py rename to examples/models/nirvana.py diff --git a/playground/models/simple_transformer.py b/examples/models/simple_transformer.py similarity index 100% rename from playground/models/simple_transformer.py rename to examples/models/simple_transformer.py diff --git a/playground/models/toka_master_gpt.py b/examples/models/toka_master_gpt.py similarity index 100% rename from playground/models/toka_master_gpt.py rename to examples/models/toka_master_gpt.py diff --git a/playground/models/videos/spectra.py b/examples/models/videos/spectra.py similarity index 100% rename from playground/models/videos/spectra.py rename to examples/models/videos/spectra.py diff --git a/playground/modules/cross_attend.py b/examples/modules/cross_attend.py similarity index 100% rename from playground/modules/cross_attend.py rename to examples/modules/cross_attend.py diff --git a/playground/modules/flash_attention.py b/examples/modules/flash_attention.py similarity index 100% rename from playground/modules/flash_attention.py rename to examples/modules/flash_attention.py diff --git a/playground/modules/fractoral_norm.py b/examples/modules/fractoral_norm.py similarity index 100% rename from playground/modules/fractoral_norm.py rename to examples/modules/fractoral_norm.py diff --git a/playground/modules/viusal_expert_example.py b/examples/modules/viusal_expert_example.py similarity index 100% rename from playground/modules/viusal_expert_example.py rename to examples/modules/viusal_expert_example.py diff --git a/playground/ops/laplace.py b/examples/ops/laplace.py similarity index 100% rename from playground/ops/laplace.py rename to examples/ops/laplace.py diff --git a/playground/structs/transformer.py b/examples/structs/transformer.py similarity index 100% rename from playground/structs/transformer.py rename to examples/structs/transformer.py diff --git a/playground/todo/dit_block.py b/examples/todo/dit_block.py similarity index 100% rename from playground/todo/dit_block.py rename to examples/todo/dit_block.py diff --git a/playground/todo/hyper_attention.py 
similarity index 100%
rename from playground/todo/hyper_attention.py
rename to examples/todo/hyper_attention.py
diff --git a/playground/todo/multi_head_latent_attention.py b/examples/todo/multi_head_latent_attention.py
similarity index 100%
rename from playground/todo/multi_head_latent_attention.py
rename to examples/todo/multi_head_latent_attention.py
diff --git a/playground/tokenizers/token_monster.py b/examples/tokenizers/token_monster.py
similarity index 100%
rename from playground/tokenizers/token_monster.py
rename to examples/tokenizers/token_monster.py
diff --git a/playground/training/fsdp.py b/examples/training/fsdp.py
similarity index 100%
rename from playground/training/fsdp.py
rename to examples/training/fsdp.py
diff --git a/pyproject.toml b/pyproject.toml
index 2a1cf597..cef23def 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.6.2"
+version = "2.6.7"
 description = "Rapidly Build, Optimize, and Train SOTA AI Models"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index 16a591fe..1aac607c 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -224,6 +224,7 @@
 from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention
 from zeta.nn.modules.snake_act import Snake
 from zeta.nn.modules.adaptive_gating import AdaptiveGating
+from zeta.nn.modules.crome_adapter import CROMEAdapter
 
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -451,4 +452,5 @@
     "GatedXAttention",
     "Snake",
     "AdaptiveGating",
+    "CROMEAdapter",
 ]
diff --git a/zeta/nn/modules/crome_adapter.py b/zeta/nn/modules/crome_adapter.py
new file mode 100644
index 00000000..01c5fe87
--- /dev/null
+++ b/zeta/nn/modules/crome_adapter.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+from typing import Tuple
+
+
+class CROMEAdapter(nn.Module):
+    def __init__(self, input_dim: int, bottleneck_dim: int):
+        """
+        Initialize the CROMEAdapter module.
+
+        Args:
+            input_dim (int): The dimension of the input features.
+            bottleneck_dim (int): The dimension of the bottleneck layer.
+        """
+        super(CROMEAdapter, self).__init__()
+
+        self.Wd_text = nn.Linear(input_dim, bottleneck_dim)
+        self.Wg_text = nn.Linear(input_dim, bottleneck_dim)
+        self.Wd_image = nn.Linear(input_dim, bottleneck_dim)
+        self.Wg_image = nn.Linear(input_dim, bottleneck_dim)
+
+        self.Wu = nn.Linear(bottleneck_dim, input_dim)
+
+        self.silu = nn.SiLU()
+
+    def forward(
+        self, text_features: torch.Tensor, image_features: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Perform forward pass of the CROMEAdapter module.
+
+        Args:
+            text_features (torch.Tensor): The input text features.
+            image_features (torch.Tensor): The input image features.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The output text and image features.
+ """ + text_down = self.silu(self.Wd_text(text_features)) * self.Wg_text( + text_features + ) + image_down = self.silu(self.Wd_image(image_features)) * self.Wg_image( + image_features + ) + text_up = self.Wu(text_down) + image_up = self.Wu(image_down) + text_output = text_features + text_up + image_output = image_features + image_up + + return text_output, image_output + + +# model = CROMEAdapter(512, 256) +# text_features = torch.randn(1, 2, 512) +# image_features = torch.randn(1, 2, 512) +# output_text, output_image = model(text_features, image_features) +# print(output_text.shape, output_image.shape)