-
-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Kye
committed
Dec 30, 2023
1 parent
57f1e82
commit 5fbe1c9
Showing
7 changed files
with
233 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import torch | ||
import torch.nn as nn | ||
from zeta.nn.modules.adaptive_rmsnorm import AdaptiveRMSNorm | ||
|
||
|
||
def test_adaptive_rmsnorm_init():
    """Check the attributes AdaptiveRMSNorm exposes when built with defaults."""
    arn = AdaptiveRMSNorm(10, dim_cond=5)
    assert isinstance(arn, AdaptiveRMSNorm)
    assert arn.dim_cond == 5
    # Identity check rather than `== False`: the flag must be the bool itself,
    # not merely something falsy-equal such as 0.
    assert arn.channel_first is False
    assert arn.scale == 10**0.5
    assert isinstance(arn.to_gamma, nn.Linear)
    # No bias projection is created unless bias=True is requested.
    assert arn.to_bias is None
|
||
|
||
def test_adaptive_rmsnorm_init_with_bias():
    """With bias=True the module must own a linear bias projection."""
    module = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
    bias_projection = module.to_bias
    assert isinstance(bias_projection, nn.Linear)
|
||
|
||
def test_adaptive_rmsnorm_forward():
    """The forward pass keeps the shape of a (batch, dim) input."""
    norm = AdaptiveRMSNorm(10, dim_cond=5)
    inputs = torch.randn(2, 10)
    conditioning = torch.randn(2, 5)
    result = norm.forward(inputs, cond=conditioning)
    expected_shape = (2, 10)
    assert result.shape == expected_shape
|
||
|
||
def test_adaptive_rmsnorm_forward_with_bias():
    """Enabling the bias projection must not change the output shape."""
    norm = AdaptiveRMSNorm(10, dim_cond=5, bias=True)
    inputs = torch.randn(2, 10)
    conditioning = torch.randn(2, 5)
    result = norm.forward(inputs, cond=conditioning)
    expected_shape = (2, 10)
    assert result.shape == expected_shape
|
||
|
||
def test_adaptive_rmsnorm_forward_channel_first():
    """A channel-first 4-D input comes back with the same shape."""
    norm = AdaptiveRMSNorm(10, dim_cond=5, channel_first=True)
    feature_map = torch.randn(2, 10, 3, 3)
    conditioning = torch.randn(2, 5)
    result = norm.forward(feature_map, cond=conditioning)
    expected_shape = (2, 10, 3, 3)
    assert result.shape == expected_shape
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import torch | ||
from zeta.structs.simple_vision_encoder import SimpleVisionEncoder | ||
|
||
|
||
def test_simple_vision_encoder_init():
    """Check the default configuration of a freshly built encoder."""
    sve = SimpleVisionEncoder()
    assert sve.size == (384, 384)
    assert sve.model_name == "vikhyatk/moondream0"
    assert sve.return_shape is False
    assert isinstance(sve.model, torch.jit.ScriptModule)
    # The preprocessing pipeline is Resize, ToImage, ToDtype, Normalize (plus
    # any extra transforms), so the dtype/scale settings live on index 2; the
    # last transform is Normalize, which does not carry them.
    # NOTE(review): assumes SimpleVisionEncoder builds the same pipeline as the
    # encoder class in this commit - confirm.
    to_dtype = sve.preprocess.transforms[2]
    assert to_dtype.scale is True
    assert to_dtype.dtype == torch.float32
|
||
|
||
def test_simple_vision_encoder_init_custom_size():
    """A user-supplied size is stored unchanged on the encoder."""
    requested_size = (512, 512)
    encoder = SimpleVisionEncoder(size=(512, 512))
    assert encoder.size == requested_size
|
||
|
||
def test_simple_vision_encoder_init_custom_model_name():
    """A user-supplied model name is stored unchanged on the encoder."""
    requested_name = "custom/model"
    encoder = SimpleVisionEncoder(model_name="custom/model")
    assert encoder.model_name == requested_name
|
||
|
||
def test_simple_vision_encoder_init_return_shape():
    """The return_shape flag passed to the constructor is stored as given."""
    sve = SimpleVisionEncoder(return_shape=True)
    # Identity check rather than `== True`, so a merely truthy value such as 1
    # would not satisfy the assertion.
    assert sve.return_shape is True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -69,4 +69,3 @@ def forward(self, fmap): | |
|
||
out = self.nonlin(out) | ||
return self.to_out(out) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from torch import nn, Tensor | ||
from beartype import beartype | ||
import torch.nn.functional as F | ||
|
||
|
||
def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    is_missing = val is None
    return not is_missing
|
||
|
||
def append_dims(t, ndims: int):
    """
    Reshape ``t`` so that ``ndims`` singleton axes follow its existing axes.

    Args:
        t: Object exposing ``shape`` and ``reshape`` (a tensor in this file).
        ndims (int): Number of trailing size-1 axes to add.
    Returns:
        The result of ``t.reshape`` with the original sizes followed by
        ``ndims`` ones.
    """
    trailing_ones = (1,) * ndims
    current_sizes = t.shape
    return t.reshape(*current_sizes, *trailing_ones)
|
||
|
||
class AdaptiveRMSNorm(nn.Module):
    """
    Normalization layer whose gain (and optional shift) is predicted from a
    conditioning vector.

    The input is normalized along its feature axis with ``F.normalize``,
    multiplied by ``dim ** 0.5`` and by a per-feature gain obtained from a
    linear projection of the conditioning tensor. When ``bias`` is enabled a
    second projection of the conditioning tensor is added to the result.

    Args:
        dim (int): Size of the feature axis being normalized.
        dim_cond (int): Size of the conditioning vector.
        channel_first (bool, optional): If True the feature axis is axis 1,
            otherwise it is the last axis. Defaults to False.
        images (bool, optional): Accepted but not used by this
            implementation. Defaults to False.
        bias (bool, optional): Whether to add a conditioned bias term.
            Defaults to False.
    """

    def __init__(
        self, dim, *, dim_cond, channel_first=False, images=False, bias=False
    ):
        super().__init__()

        self.dim_cond = dim_cond
        self.channel_first = channel_first
        self.scale = dim**0.5

        # Projections from the conditioning vector to per-feature gain / bias.
        self.to_gamma = nn.Linear(dim_cond, dim)
        if bias:
            self.to_bias = nn.Linear(dim_cond, dim)
        else:
            self.to_bias = None

        # Zero weight + unit bias: the predicted gain starts at exactly 1
        # whatever the conditioning input is.
        nn.init.ones_(self.to_gamma.bias)
        nn.init.zeros_(self.to_gamma.weight)

        # Zero weight + zero bias: the predicted shift starts at exactly 0.
        if bias:
            for parameter in (self.to_bias.weight, self.to_bias.bias):
                nn.init.zeros_(parameter)

    @beartype
    def forward(self, x: Tensor, *, cond: Tensor):
        """
        Normalize ``x`` and modulate it with values predicted from ``cond``.

        Args:
            x (torch.Tensor): The input tensor; its first axis is the batch.
            cond (torch.Tensor): Conditioning tensor of shape
                ``(batch, dim_cond)``.
        Returns:
            torch.Tensor: The normalized, scaled and (optionally) shifted
            tensor.
        """
        batch_size = x.shape[0]
        assert cond.shape == (batch_size, self.dim_cond)

        has_bias = self.to_bias is not None

        gamma = self.to_gamma(cond)
        bias = self.to_bias(cond) if has_bias else 0.0

        feature_axis = -1
        if self.channel_first:
            feature_axis = 1
            # Add singleton axes after (batch, dim) so the modulation
            # broadcasts over every axis that follows the channel axis.
            trailing_ones = (1,) * (x.ndim - 2)
            gamma = gamma.reshape(*gamma.shape, *trailing_ones)
            if has_bias:
                bias = bias.reshape(*bias.shape, *trailing_ones)
        # NOTE(review): in the channel-last case gamma/bias keep shape
        # (batch, dim) and are combined with x by plain broadcasting - confirm
        # this is intended for inputs with more than two axes.

        normalized = F.normalize(x, dim=feature_axis)
        return normalized * self.scale * gamma + bias
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import torch | ||
from PIL import Image | ||
from torchvision.transforms.v2 import ( | ||
Compose, | ||
Resize, | ||
InterpolationMode, | ||
ToImage, | ||
ToDtype, | ||
Normalize, | ||
) | ||
from typing import Tuple | ||
from torch import nn | ||
from huggingface_hub import snapshot_download | ||
|
||
|
||
class VisionEncoder(nn.Module):
    """
    Image encoder backed by a TorchScript vision model.

    The model snapshot is fetched with ``snapshot_download`` and the file
    ``vision.pt`` inside it is loaded with ``torch.jit.load`` and converted to
    float32.

    Args:
        size (Tuple, optional): Target size given to ``Resize``. Defaults to (384, 384).
        model_name (str, optional): Repository id passed to ``snapshot_download``. Defaults to "vikhyatk/moondream0".
        return_shape (bool, optional): Whether to print the shape of the embedding on every call. Defaults to False.
        *args: Extra transforms appended to the end of the preprocessing pipeline.
        **kwargs: Accepted but not used.
    Examples::
        >>> from zeta.structs import VisionEncoder
        >>> encoder = VisionEncoder()
        >>> embeds = encoder("image.jpg")
        >>> embeds.shape  # depends on the loaded model
    """

    def __init__(
        self,
        size: Tuple = (384, 384),
        model_name: str = "vikhyatk/moondream0",
        return_shape: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.size = size
        self.model_name = model_name
        self.return_shape = return_shape
        model_path = snapshot_download(model_name)

        self.model = torch.jit.load(f"{model_path}/vision.pt").to(
            dtype=torch.float32
        )

        self.preprocess = Compose(
            [
                Resize(size=size, interpolation=InterpolationMode.BICUBIC),
                ToImage(),
                ToDtype(torch.float32, scale=True),
                Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                *args,
            ]
        )

    # NOTE(review): defining __call__ directly replaces nn.Module.__call__,
    # so this module has no forward(); kept as-is for compatibility.
    def __call__(self, image: Image, *args, **kwargs) -> torch.Tensor:
        """
        Load an image, preprocess it and return the model's embedding.

        Args:
            image: Whatever ``PIL.Image.open`` accepts (e.g. a file path); it
                is opened here, so an already-loaded image object is not
                expected. The ``Image`` annotation names the PIL module.
            *args: Extra positional arguments forwarded to the model.
            **kwargs: Extra keyword arguments forwarded to the model.
        Returns:
            torch.Tensor: The embedding of the input image.
        """
        # Image.open is lazy and keeps the underlying file open; convert()
        # produces an independent in-memory image, so the handle can be
        # released as soon as the conversion is done.
        with Image.open(image) as opened:
            rgb_image = opened.convert("RGB")

        with torch.no_grad():
            # unsqueeze(0) adds the leading batch axis for a single image.
            image_vec = self.preprocess(rgb_image).unsqueeze(0)
            embeds = self.model(image_vec, *args, **kwargs)

        if self.return_shape:
            print(f"Embedding shape: {embeds.shape}")

        return embeds