
Commit

[EXAMPLES]
Kye Gomez authored and Kye Gomez committed Aug 28, 2024
1 parent b803a3d commit e28fe95
Showing 23 changed files with 182 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ data
# Distribution / packaging
.Python
.DS_Store
agent_workspace
build/
.ruff_cache
.vscode
121 changes: 121 additions & 0 deletions audio_encoder.py
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
from torchaudio.transforms import MelSpectrogram
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class AudioEncoder(nn.Module):
    """
    Audio encoder that processes raw audio through a Mel filter bank, CNN downsampling
    layers, and a Transformer encoder. A simple two-layer MLP then projects each encoded
    frame to output_dim dimensions, producing a compact token sequence for every 2 seconds
    of audio input.

    Args:
        n_mels (int): Number of mel frequency bins. Default is 128.
        cnn_channels (int): Number of channels in the first CNN layer. Default is 64.
        transformer_layers (int): Number of layers in the Transformer encoder. Default is 24.
        nhead (int): Number of attention heads in each Transformer layer. Default is 8.
        dim_feedforward (int): Dimension of the feedforward network in each Transformer layer. Default is 1024.
        audio_length (int): Length of the input audio in seconds (not used by the current implementation). Default is 2.
        mlp_hidden_dim (int): Dimension of the hidden layer in the MLP. Default is 256.
        output_dim (int): Dimension of the output tokens. Default is 25.
    """

    def __init__(
        self,
        n_mels: int = 128,
        cnn_channels: int = 64,
        transformer_layers: int = 24,
        nhead: int = 8,
        dim_feedforward: int = 1024,
        audio_length: int = 2,
        mlp_hidden_dim: int = 256,
        output_dim: int = 25,
    ):
        super(AudioEncoder, self).__init__()

        self.mel_spectrogram = MelSpectrogram(sample_rate=16000, n_mels=n_mels)

        # Four stride-2 convolutions downsample the (mel, time) grid by 16x in each dimension
        self.cnn = nn.Sequential(
            nn.Conv2d(1, cnn_channels, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(
                cnn_channels,
                cnn_channels * 2,
                kernel_size=3,
                stride=2,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                cnn_channels * 2,
                cnn_channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                cnn_channels * 4,
                cnn_channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
            ),
            nn.ReLU(),
        )

        transformer_encoder_layer = TransformerEncoderLayer(
            d_model=cnn_channels * 8,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True,  # forward() feeds (batch, seq, feature) tensors
        )
        self.transformer_encoder = TransformerEncoder(
            transformer_encoder_layer, num_layers=transformer_layers
        )

        self.mlp = nn.Sequential(
            nn.Linear(cnn_channels * 8, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, output_dim),
        )

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the AudioEncoder.

        Args:
            audio (torch.Tensor): Input audio tensor of shape (batch_size, num_samples).

        Returns:
            torch.Tensor: Encoded audio tensor of shape (batch_size, num_tokens, output_dim).
        """
        # Convert raw audio to a Mel spectrogram and add a channel dimension
        mel_spec = self.mel_spectrogram(audio).unsqueeze(1)

        # Pass through the CNN downsampling layers
        cnn_out = self.cnn(mel_spec)

        # Flatten the spatial grid into a token sequence for the transformer
        batch_size, channels, height, width = cnn_out.size()
        cnn_out = cnn_out.permute(0, 2, 3, 1).reshape(batch_size, -1, channels)

        # Pass through the Transformer encoder
        transformer_out = self.transformer_encoder(cnn_out)

        # Project each token to output_dim with the MLP
        output = self.mlp(transformer_out)

        return output


# Example usage:
if __name__ == "__main__":
    # Assume 2 seconds of audio at a 16 kHz sample rate
    audio_input = torch.randn(8, 32000)  # batch_size = 8, num_samples = 32000

    model = AudioEncoder()
    output = model(audio_input)
    print(output.shape)  # (batch_size, num_tokens, output_dim)
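For reference, a shape-tracing sketch of where the token count comes from (a sketch that reuses the AudioEncoder above; the exact number of time frames depends on torchaudio's default MelSpectrogram hop length, so the printed shapes are indicative rather than guaranteed):

model = AudioEncoder()
audio = torch.randn(2, 32000)  # 2 seconds at 16 kHz, batch of 2

# Mel filter bank: (batch, n_mels, time_frames), plus the channel dim added as in forward()
mel = model.mel_spectrogram(audio).unsqueeze(1)
print(mel.shape)

# Four stride-2 convolutions shrink both grid dimensions by roughly 16x
feats = model.cnn(mel)
print(feats.shape)

# Full forward pass: each remaining grid position becomes one output_dim-sized token
tokens = model(audio)
print(tokens.shape)  # (2, num_tokens, output_dim)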
17 files renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.6.2"
version = "2.6.7"
description = "Rapidly Build, Optimize, and Train SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -224,6 +224,7 @@
from zeta.nn.modules.evlm_xattn import GatedMoECrossAttn, GatedXAttention
from zeta.nn.modules.snake_act import Snake
from zeta.nn.modules.adaptive_gating import AdaptiveGating
from zeta.nn.modules.crome_adapter import CROMEAdapter

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -451,4 +452,5 @@
"GatedXAttention",
"Snake",
"AdaptiveGating",
"CROMEAdapter",
]
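With the re-export added above, CROMEAdapter should be importable from the package namespace; a quick sketch, assuming the zetascale package is installed with this change applied (the 512/256 sizes and feature shapes are illustrative):

import torch
from zeta.nn.modules import CROMEAdapter  # now re-exported alongside the other modules

adapter = CROMEAdapter(input_dim=512, bottleneck_dim=256)
text, image = torch.randn(1, 8, 512), torch.randn(1, 8, 512)
text_out, image_out = adapter(text, image)
print(text_out.shape, image_out.shape)  # residual outputs keep the input shapes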
57 changes: 57 additions & 0 deletions zeta/nn/modules/crome_adapter.py
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
from typing import Tuple


class CROMEAdapter(nn.Module):
    def __init__(self, input_dim: int, bottleneck_dim: int):
        """
        Initialize the CROMEAdapter module.

        Args:
            input_dim (int): The dimension of the input features.
            bottleneck_dim (int): The dimension of the bottleneck layer.
        """
        super(CROMEAdapter, self).__init__()

        # Modality-specific gated down-projections into the bottleneck
        self.Wd_text = nn.Linear(input_dim, bottleneck_dim)
        self.Wg_text = nn.Linear(input_dim, bottleneck_dim)
        self.Wd_image = nn.Linear(input_dim, bottleneck_dim)
        self.Wg_image = nn.Linear(input_dim, bottleneck_dim)

        # Shared up-projection back to the input dimension
        self.Wu = nn.Linear(bottleneck_dim, input_dim)

        self.silu = nn.SiLU()

    def forward(
        self, text_features: torch.Tensor, image_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform the forward pass of the CROMEAdapter module.

        Args:
            text_features (torch.Tensor): The input text features.
            image_features (torch.Tensor): The input image features.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output text and image features.
        """
        # SiLU-gated down-projection into the bottleneck, per modality
        text_down = self.silu(self.Wd_text(text_features)) * self.Wg_text(
            text_features
        )
        image_down = self.silu(self.Wd_image(image_features)) * self.Wg_image(
            image_features
        )

        # Shared up-projection followed by a residual connection
        text_up = self.Wu(text_down)
        image_up = self.Wu(image_down)
        text_output = text_features + text_up
        image_output = image_features + image_up

        return text_output, image_output


# model = CROMEAdapter(512, 256)
# text_features = torch.randn(1, 2, 512)
# image_features = torch.randn(1, 2, 512)
# output_text, output_image = model(text_features, image_features)
# print(output_text.shape, output_image.shape)
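A runnable counterpart to the commented-out example above (a sketch; the 512/256 sizes and the (1, 2, 512) feature shapes are illustrative, not values fixed by the module):

if __name__ == "__main__":
    model = CROMEAdapter(input_dim=512, bottleneck_dim=256)
    text_features = torch.randn(1, 2, 512)
    image_features = torch.randn(1, 2, 512)
    output_text, output_image = model(text_features, image_features)
    print(output_text.shape, output_image.shape)  # residual outputs keep the (1, 2, 512) input shapes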
