diff --git a/tests/nn/modules/test_adaptive_rmsnorm.py b/tests/nn/modules/test_adaptive_rmsnorm.py
new file mode 100644
index 00000000..7670bcd6
--- /dev/null
+++ b/tests/nn/modules/test_adaptive_rmsnorm.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from zeta.nn.modules.adaptive_rmsnorm import AdaptiveRMSNorm
+
+
+def test_adaptive_rmsnorm_init():
+    arn = AdaptiveRMSNorm(10, dim_cond=5)
+    assert isinstance(arn, AdaptiveRMSNorm)
+    assert arn.dim_cond == 5
+    assert arn.channel_first == False
+    assert arn.scale == 10**0.5
+    assert isinstance(arn.to_gamma, nn.Linear)
+    assert arn.to_bias is None
+
+
+def test_adaptive_rmsnorm_init_with_bias():
+    arn = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
+    assert isinstance(arn.to_bias, nn.Linear)
+
+
+def test_adaptive_rmsnorm_forward():
+    arn = AdaptiveRMSNorm(10, dim_cond=5)
+    x = torch.randn(2, 10)
+    cond = torch.randn(2, 5)
+    output = arn.forward(x, cond=cond)
+    assert output.shape == (2, 10)
+
+
+def test_adaptive_rmsnorm_forward_with_bias():
+    arn = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
+    x = torch.randn(2, 10)
+    cond = torch.randn(2, 5)
+    output = arn.forward(x, cond=cond)
+    assert output.shape == (2, 10)
+
+
+def test_adaptive_rmsnorm_forward_channel_first():
+    arn = AdaptiveRMSNorm(10, dim_cond=5, channel_first=True)
+    x = torch.randn(2, 10, 3, 3)
+    cond = torch.randn(2, 5)
+    output = arn.forward(x, cond=cond)
+    assert output.shape == (2, 10, 3, 3)
diff --git a/tests/structs/test_simple_vision_encoder.py b/tests/structs/test_simple_vision_encoder.py
new file mode 100644
index 00000000..344698db
--- /dev/null
+++ b/tests/structs/test_simple_vision_encoder.py
@@ -0,0 +1,27 @@
+import torch
+from zeta.structs.simple_vision_encoder import VisionEncoder
+
+
+def test_simple_vision_encoder_init():
+    sve = VisionEncoder()
+    assert sve.size == (384, 384)
+    assert sve.model_name == "vikhyatk/moondream0"
+    assert sve.return_shape == False
+    assert isinstance(sve.model, torch.jit.ScriptModule)
+    assert sve.preprocess.transforms[-2].scale == True
+    assert sve.preprocess.transforms[-2].dtype == torch.float32
+
+
+def test_simple_vision_encoder_init_custom_size():
+    sve = VisionEncoder(size=(512, 512))
+    assert sve.size == (512, 512)
+
+
+def test_simple_vision_encoder_init_custom_model_name():
+    sve = VisionEncoder(model_name="custom/model")
+    assert sve.model_name == "custom/model"
+
+
+def test_simple_vision_encoder_init_return_shape():
+    sve = VisionEncoder(return_shape=True)
+    assert sve.return_shape == True
diff --git a/zeta/nn/attention/linear_attention.py b/zeta/nn/attention/linear_attention.py
index a01bf345..61747283 100644
--- a/zeta/nn/attention/linear_attention.py
+++ b/zeta/nn/attention/linear_attention.py
@@ -69,4 +69,3 @@ def forward(self, fmap):
 
         out = self.nonlin(out)
         return self.to_out(out)
-
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index 84f1ecad..22004883 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -77,7 +77,7 @@
 from zeta.nn.modules.quantized_layernorm import QuantizedLN
 from zeta.nn.modules.slerp_model_merger import SLERPModelMerger
 from zeta.nn.modules.avg_model_merger import AverageModelMerger
-
+from zeta.nn.modules.adaptive_rmsnorm import AdaptiveRMSNorm
 
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -168,4 +168,5 @@
     "QuantizedLN",
     "SLERPModelMerger",
     "AverageModelMerger",
+    "AdaptiveRMSNorm",
 ]
diff --git a/zeta/nn/modules/adaptive_rmsnorm.py b/zeta/nn/modules/adaptive_rmsnorm.py
new file mode 100644
index 00000000..8960e313
--- /dev/null
+++ b/zeta/nn/modules/adaptive_rmsnorm.py
@@ -0,0 +1,77 @@
+from torch import nn, Tensor
+from beartype import beartype
+import torch.nn.functional as F
+
+
+def exists(val):
+    return val is not None
+
+
+def append_dims(t, ndims: int):
+    return t.reshape(*t.shape, *((1,) * ndims))
+
+
+class AdaptiveRMSNorm(nn.Module):
+    """
+    Adaptive Root Mean Square Normalization (RMSNorm) module.
+
+    Args:
+        dim (int): The input dimension.
+        dim_cond (int): The dimension of the conditioning tensor.
+        channel_first (bool, optional): Whether the input has channels as the first dimension. Defaults to False.
+        images (bool, optional): Whether the input represents images. Defaults to False.
+        bias (bool, optional): Whether to include a bias term. Defaults to False.
+    """
+
+    def __init__(
+        self, dim, *, dim_cond, channel_first=False, images=False, bias=False
+    ):
+        super().__init__()
+
+        self.dim_cond = dim_cond
+        self.channel_first = channel_first
+        self.scale = dim**0.5
+
+        self.to_gamma = nn.Linear(dim_cond, dim)
+        self.to_bias = nn.Linear(dim_cond, dim) if bias else None
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.ones_(self.to_gamma.bias)
+
+        if bias:
+            nn.init.zeros_(self.to_bias.weight)
+            nn.init.zeros_(self.to_bias.bias)
+
+    @beartype
+    def forward(self, x: Tensor, *, cond: Tensor):
+        """
+        Forward pass of the AdaptiveRMSNorm module.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+            cond (torch.Tensor): The conditioning tensor.
+
+        Returns:
+            torch.Tensor: The normalized and conditioned output tensor.
+        """
+        batch = x.shape[0]
+        assert cond.shape == (batch, self.dim_cond)
+
+        gamma = self.to_gamma(cond)
+
+        bias = 0.0
+        if exists(self.to_bias):
+            bias = self.to_bias(cond)
+
+        if self.channel_first:
+            gamma = append_dims(gamma, x.ndim - 2)
+
+            if exists(self.to_bias):
+                bias = append_dims(bias, x.ndim - 2)
+
+        return (
+            F.normalize(x, dim=(1 if self.channel_first else -1))
+            * self.scale
+            * gamma
+            + bias
+        )
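Below is a minimal usage sketch (not part of the diff) for the new `AdaptiveRMSNorm` module, mirroring the shapes exercised by the tests above; the tensor sizes are illustrative only.

```python
import torch
from zeta.nn.modules import AdaptiveRMSNorm

# Flat features: x is (batch, dim), cond is (batch, dim_cond),
# matching the shape assert in forward().
norm = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
x = torch.randn(2, 10)
cond = torch.randn(2, 5)
out = norm(x, cond=cond)  # RMS-normalized over the last dim, scaled by gamma(cond)
print(out.shape)          # torch.Size([2, 10])

# Channel-first feature maps: x is (batch, dim, h, w); gamma and bias
# are broadcast over the spatial dims via append_dims.
norm_cf = AdaptiveRMSNorm(10, dim_cond=5, channel_first=True)
fmap = torch.randn(2, 10, 8, 8)
out_cf = norm_cf(fmap, cond=cond)
print(out_cf.shape)       # torch.Size([2, 10, 8, 8])
```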
diff --git a/zeta/structs/__init__.py b/zeta/structs/__init__.py
index 34e55212..41a1b353 100644
--- a/zeta/structs/__init__.py
+++ b/zeta/structs/__init__.py
@@ -20,8 +20,7 @@
     ViTransformerWrapper,
 )
 from zeta.structs.transformer_block import TransformerBlock
-
-# from zeta.structs.efficient_net import EfficientNet
+from zeta.structs.simple_vision_encoder import VisionEncoder
 
 __all__ = [
     "AutoregressiveWrapper",
@@ -41,4 +40,5 @@
     "CLIPVisionTower",
     "build_vision_tower",
     "build_vision_projector",
+    "VisionEncoder",
 ]
diff --git a/zeta/structs/simple_vision_encoder.py b/zeta/structs/simple_vision_encoder.py
new file mode 100644
index 00000000..007efa5e
--- /dev/null
+++ b/zeta/structs/simple_vision_encoder.py
@@ -0,0 +1,83 @@
+import torch
+from PIL import Image
+from torchvision.transforms.v2 import (
+    Compose,
+    Resize,
+    InterpolationMode,
+    ToImage,
+    ToDtype,
+    Normalize,
+)
+from typing import Tuple
+from torch import nn
+from huggingface_hub import snapshot_download
+
+
+class VisionEncoder(nn.Module):
+    """
+    A simple vision encoder that embeds images with a pre-trained TorchScript vision model.
+
+    Args:
+        size (Tuple, optional): The size of the input image. Defaults to (384, 384).
+        model_name (str, optional): The Hugging Face repository of the pre-trained vision model. Defaults to "vikhyatk/moondream0".
+        return_shape (bool, optional): Whether to print the shape of the embedding. Defaults to False.
+        *args: Variable length argument list.
+        **kwargs: Arbitrary keyword arguments.
+
+    Examples::
+        >>> from zeta.structs import VisionEncoder
+        >>> encoder = VisionEncoder()
+        >>> embeds = encoder("image.jpg")
+        >>> embeds.shape
+        torch.Size([1, 512])
+    """
+
+    def __init__(
+        self,
+        size: Tuple = (384, 384),
+        model_name: str = "vikhyatk/moondream0",
+        return_shape: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.size = size
+        self.model_name = model_name
+        self.return_shape = return_shape
+        model_path = snapshot_download(model_name)
+
+        self.model = torch.jit.load(f"{model_path}/vision.pt").to(
+            dtype=torch.float32
+        )
+
+        self.preprocess = Compose(
+            [
+                Resize(size=size, interpolation=InterpolationMode.BICUBIC),
+                ToImage(),
+                ToDtype(torch.float32, scale=True),
+                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+                *args,
+            ]
+        )
+
+    def __call__(self, image: str, *args, **kwargs) -> torch.Tensor:
+        """
+        Processes an input image and returns its embedding.
+
+        Args:
+            image (str): The path to the input image.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            torch.Tensor: The embedding of the input image.
+        """
+        image = Image.open(image)
+        with torch.no_grad():
+            image_vec = self.preprocess(image.convert("RGB")).unsqueeze(0)
+            embeds = self.model(image_vec, *args, **kwargs)
+
+            if self.return_shape:
+                print(f"Embedding shape: {embeds.shape}")
+
+            return embeds
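For completeness, a sketch of how the new `VisionEncoder` might be called, assuming the `vikhyatk/moondream0` checkpoint is downloadable and a local image file exists; the file name below is a hypothetical placeholder.

```python
import torch
from zeta.structs import VisionEncoder

# Downloads the TorchScript vision model from the Hugging Face Hub on first use.
encoder = VisionEncoder(size=(384, 384), return_shape=True)

# __call__ takes an image path, opens it with PIL, preprocesses it, and
# runs the jit-compiled vision model under torch.no_grad().
embeds = encoder("photo.jpg")  # hypothetical local image path
assert isinstance(embeds, torch.Tensor)
print(embeds.shape)
```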