Commit: [FEAT][VisionEncoder]
Kye committed Dec 30, 2023
1 parent 57f1e82 commit 5fbe1c9
Showing 7 changed files with 233 additions and 4 deletions.
42 changes: 42 additions & 0 deletions tests/nn/modules/test_adaptive_rmsnorm.py
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
from zeta.nn.modules.adaptive_rmsnorm import AdaptiveRMSNorm


def test_adaptive_rmsnorm_init():
    arn = AdaptiveRMSNorm(10, dim_cond=5)
    assert isinstance(arn, AdaptiveRMSNorm)
    assert arn.dim_cond == 5
    assert arn.channel_first is False
    assert arn.scale == 10**0.5
    assert isinstance(arn.to_gamma, nn.Linear)
    assert arn.to_bias is None


def test_adaptive_rmsnorm_init_with_bias():
    arn = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
    assert isinstance(arn.to_bias, nn.Linear)


def test_adaptive_rmsnorm_forward():
    arn = AdaptiveRMSNorm(10, dim_cond=5)
    x = torch.randn(2, 10)
    cond = torch.randn(2, 5)
    output = arn.forward(x, cond=cond)
    assert output.shape == (2, 10)


def test_adaptive_rmsnorm_forward_with_bias():
    arn = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
    x = torch.randn(2, 10)
    cond = torch.randn(2, 5)
    output = arn.forward(x, cond=cond)
    assert output.shape == (2, 10)


def test_adaptive_rmsnorm_forward_channel_first():
    arn = AdaptiveRMSNorm(10, dim_cond=5, channel_first=True)
    x = torch.randn(2, 10, 3, 3)
    cond = torch.randn(2, 5)
    output = arn.forward(x, cond=cond)
    assert output.shape == (2, 10, 3, 3)
27 changes: 27 additions & 0 deletions tests/structs/test_simple_vision_encoder.py
@@ -0,0 +1,27 @@
import torch
from zeta.structs.simple_vision_encoder import VisionEncoder


def test_simple_vision_encoder_init():
    sve = VisionEncoder()
    assert sve.size == (384, 384)
    assert sve.model_name == "vikhyatk/moondream0"
    assert sve.return_shape is False
    assert isinstance(sve.model, torch.jit.ScriptModule)
    # ToDtype is the third transform (after Resize and ToImage); the last is Normalize.
    assert sve.preprocess.transforms[2].scale is True
    assert sve.preprocess.transforms[2].dtype == torch.float32


def test_simple_vision_encoder_init_custom_size():
    sve = VisionEncoder(size=(512, 512))
    assert sve.size == (512, 512)


def test_simple_vision_encoder_init_custom_model_name():
    sve = VisionEncoder(model_name="custom/model")
    assert sve.model_name == "custom/model"


def test_simple_vision_encoder_init_return_shape():
    sve = VisionEncoder(return_shape=True)
    assert sve.return_shape is True
1 change: 0 additions & 1 deletion zeta/nn/attention/linear_attention.py
@@ -69,4 +69,3 @@ def forward(self, fmap):

        out = self.nonlin(out)
        return self.to_out(out)

3 changes: 2 additions & 1 deletion zeta/nn/modules/__init__.py
@@ -77,7 +77,7 @@
from zeta.nn.modules.quantized_layernorm import QuantizedLN
from zeta.nn.modules.slerp_model_merger import SLERPModelMerger
from zeta.nn.modules.avg_model_merger import AverageModelMerger

from zeta.nn.modules.adaptive_rmsnorm import AdaptiveRMSNorm

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -168,4 +168,5 @@
"QuantizedLN",
"SLERPModelMerger",
"AverageModelMerger",
"AdaptiveRMSNorm",
]
77 changes: 77 additions & 0 deletions zeta/nn/modules/adaptive_rmsnorm.py
@@ -0,0 +1,77 @@
from torch import nn, Tensor
from beartype import beartype
import torch.nn.functional as F


def exists(val):
    """Return True if val is not None."""
    return val is not None


def append_dims(t, ndims: int):
    """Append `ndims` trailing singleton dimensions so `t` broadcasts over spatial axes."""
    return t.reshape(*t.shape, *((1,) * ndims))


class AdaptiveRMSNorm(nn.Module):
    """
    Adaptive Root Mean Square Normalization (RMSNorm) module.

    Args:
        dim (int): The input dimension.
        dim_cond (int): The dimension of the conditioning tensor.
        channel_first (bool, optional): Whether the input has channels as the first dimension. Defaults to False.
        images (bool, optional): Whether the input represents images (unused here). Defaults to False.
        bias (bool, optional): Whether to include a bias term. Defaults to False.
    """

    def __init__(
        self, dim, *, dim_cond, channel_first=False, images=False, bias=False
    ):
        super().__init__()

        self.dim_cond = dim_cond
        self.channel_first = channel_first
        self.scale = dim**0.5

        self.to_gamma = nn.Linear(dim_cond, dim)
        self.to_bias = nn.Linear(dim_cond, dim) if bias else None

        # Initialize the gamma projection to output ones, so the norm starts as identity scaling.
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.ones_(self.to_gamma.bias)

        if bias:
            # Initialize the bias projection to output zeros.
            nn.init.zeros_(self.to_bias.weight)
            nn.init.zeros_(self.to_bias.bias)

    @beartype
    def forward(self, x: Tensor, *, cond: Tensor):
        """
        Forward pass of the AdaptiveRMSNorm module.

        L2-normalizes x along the channel dimension and rescales by sqrt(dim)
        (equivalent to dividing by the RMS of x), then applies a per-channel
        gamma, and optional bias, projected from the conditioning tensor.

        Args:
            x (torch.Tensor): The input tensor.
            cond (torch.Tensor): The conditioning tensor of shape (batch, dim_cond).

        Returns:
            torch.Tensor: The normalized and conditioned output tensor.
        """
        batch = x.shape[0]
        assert cond.shape == (batch, self.dim_cond)

        gamma = self.to_gamma(cond)

        bias = 0.0
        if exists(self.to_bias):
            bias = self.to_bias(cond)

        if self.channel_first:
            # Reshape (batch, dim) -> (batch, dim, 1, ..., 1) so gamma/bias
            # broadcast over the trailing spatial dimensions.
            gamma = append_dims(gamma, x.ndim - 2)

            if exists(self.to_bias):
                bias = append_dims(bias, x.ndim - 2)

        return (
            F.normalize(x, dim=(1 if self.channel_first else -1))
            * self.scale
            * gamma
            + bias
        )
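
Below is a minimal usage sketch of AdaptiveRMSNorm (not part of the commit), assuming the re-export from zeta.nn.modules added in the __init__.py change above; shapes mirror the tests above.

import torch
from zeta.nn.modules import AdaptiveRMSNorm

# Vector features: the last dim is normalized, conditioned on a 5-dim code.
norm = AdaptiveRMSNorm(10, dim_cond=5)
x = torch.randn(2, 10)
cond = torch.randn(2, 5)
out = norm(x, cond=cond)  # shape (2, 10)

# Feature maps: channel_first=True normalizes dim 1; gamma/bias are
# broadcast over the trailing spatial dims via append_dims.
norm_cf = AdaptiveRMSNorm(10, dim_cond=5, channel_first=True, bias=True)
fmap = torch.randn(2, 10, 3, 3)
out_cf = norm_cf(fmap, cond=cond)  # shape (2, 10, 3, 3)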
4 changes: 2 additions & 2 deletions zeta/structs/__init__.py
@@ -20,8 +20,7 @@
    ViTransformerWrapper,
)
from zeta.structs.transformer_block import TransformerBlock

# from zeta.structs.efficient_net import EfficientNet
from zeta.structs.simple_vision_encoder import VisionEncoder

__all__ = [
"AutoregressiveWrapper",
@@ -41,4 +40,5 @@
"CLIPVisionTower",
"build_vision_tower",
"build_vision_projector",
"VisionEncoder",
]
83 changes: 83 additions & 0 deletions zeta/structs/simple_vision_encoder.py
@@ -0,0 +1,83 @@
import torch
from PIL import Image
from torchvision.transforms.v2 import (
    Compose,
    Resize,
    InterpolationMode,
    ToImage,
    ToDtype,
    Normalize,
)
from typing import Tuple
from torch import nn
from huggingface_hub import snapshot_download


class VisionEncoder(nn.Module):
    """
    A simple vision encoder that downloads a pre-trained TorchScript vision
    model from the Hugging Face Hub and embeds images with it.

    Args:
        size (Tuple[int, int], optional): The size the input image is resized to. Defaults to (384, 384).
        model_name (str, optional): The Hugging Face repository of the pre-trained vision model. Defaults to "vikhyatk/moondream0".
        return_shape (bool, optional): Whether to print the shape of the embedding. Defaults to False.
        *args: Extra transforms appended to the preprocessing pipeline.
        **kwargs: Arbitrary keyword arguments.

    Examples::
        >>> from zeta.structs import VisionEncoder
        >>> encoder = VisionEncoder()
        >>> embeds = encoder("image.jpg")
        >>> embeds.shape
        torch.Size([1, 512])
    """

    def __init__(
        self,
        size: Tuple[int, int] = (384, 384),
        model_name: str = "vikhyatk/moondream0",
        return_shape: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.size = size
        self.model_name = model_name
        self.return_shape = return_shape
        model_path = snapshot_download(model_name)

        self.model = torch.jit.load(f"{model_path}/vision.pt").to(
            dtype=torch.float32
        )

        self.preprocess = Compose(
            [
                Resize(size=size, interpolation=InterpolationMode.BICUBIC),
                ToImage(),
                ToDtype(torch.float32, scale=True),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                *args,
            ]
        )

    def __call__(self, image: str, *args, **kwargs) -> torch.Tensor:
        """
        Processes an input image and returns its embedding.

        Args:
            image (str): Path to the input image file.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            torch.Tensor: The embedding of the input image.
        """
        image = Image.open(image)
        with torch.no_grad():
            image_vec = self.preprocess(image.convert("RGB")).unsqueeze(0)
            embeds = self.model(image_vec, *args, **kwargs)

        if self.return_shape:
            print(f"Embedding shape: {embeds.shape}")

        return embeds
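
A short usage sketch for VisionEncoder (not part of the commit); the image path is hypothetical, and the first call downloads the vikhyatk/moondream0 snapshot from the Hugging Face Hub, so network access is assumed.

from zeta.structs import VisionEncoder

encoder = VisionEncoder(return_shape=True)
embeds = encoder("photo.jpg")  # "photo.jpg" is a placeholder local file
# return_shape=True prints e.g. "Embedding shape: torch.Size([1, 512])"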
