[FEAT][ImgPatchEmbed] [chore][disable_warnings_and_logs]
Kye committed Dec 21, 2023
1 parent bbb360a commit 6a550fc
Showing 7 changed files with 195 additions and 17 deletions.
76 changes: 76 additions & 0 deletions tests/nn/modules/test_img_patch_embed.py
@@ -0,0 +1,76 @@
# FILEPATH: /Users/defalt/Desktop/Athena/research/zeta/tests/nn/modules/test_img_patch_embed.py

import pytest
from torch import nn
import torch
from zeta.nn.modules.img_patch_embed import ImgPatchEmbed


def test_class_init():
    model = ImgPatchEmbed()

    assert isinstance(model.proj, nn.Conv2d)
    assert model.img_size == 224
    assert model.patch_size == 16
    assert model.num_patches == 196


def test_class_init_with_args():
    model = ImgPatchEmbed(
        img_size=448, patch_size=32, in_chans=1, embed_dim=512
    )

    assert isinstance(model.proj, nn.Conv2d)
    assert model.img_size == 448
    assert model.patch_size == 32
    assert model.num_patches == 196
    assert model.proj.in_channels == 1
    assert model.proj.out_channels == 512


def test_forward():
    model = ImgPatchEmbed()
    x = torch.randn(1, 3, 224, 224)
    out = model(x)

    assert out.shape == torch.Size([1, 196, 768])


def test_forward_with_different_input():
    model = ImgPatchEmbed()
    x = torch.randn(2, 3, 224, 224)
    out = model(x)

    assert out.shape == torch.Size([2, 196, 768])


def test_forward_with_different_img_size():
    model = ImgPatchEmbed(img_size=448)
    x = torch.randn(1, 3, 448, 448)
    out = model(x)

    # a 448x448 input with the default 16x16 patches yields (448 // 16) ** 2 = 784 patches
    assert out.shape == torch.Size([1, 784, 768])


def test_forward_with_different_patch_size():
    model = ImgPatchEmbed(patch_size=32)
    x = torch.randn(1, 3, 224, 224)
    out = model(x)

    assert out.shape == torch.Size([1, 49, 768])


def test_forward_with_different_in_chans():
    model = ImgPatchEmbed(in_chans=1)
    x = torch.randn(1, 1, 224, 224)
    out = model(x)

    assert out.shape == torch.Size([1, 196, 768])


def test_forward_with_different_embed_dim():
    model = ImgPatchEmbed(embed_dim=512)
    x = torch.randn(1, 3, 224, 224)
    out = model(x)

    assert out.shape == torch.Size([1, 196, 512])
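The new suite can also be run in isolation; a minimal sketch using pytest's own entry point (assumes the repository root as the working directory):

import pytest

# run only the new ImgPatchEmbed tests
pytest.main(["tests/nn/modules/test_img_patch_embed.py", "-q"])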
13 changes: 12 additions & 1 deletion tests/nn/modules/test_simple_mamba.py
@@ -5,6 +5,7 @@
from torch import nn
from zeta.nn.modules.simple_mamba import Mamba, ResidualBlock, RMSNorm


def test_mamba_class_init():
    model = Mamba(10000, 512, 6)

@@ -13,13 +14,15 @@ def test_mamba_class_init():
    assert isinstance(model.norm_f, RMSNorm)
    assert isinstance(model.lm_head, nn.Linear)


def test_mamba_forward():
    model = Mamba(10000, 512, 6)
    x = torch.randint(0, 10000, (1, 50))
    out = model(x)

    assert out.shape == torch.Size([1, 50, 10000])


def test_residual_block_class_init():
    block = ResidualBlock(512)

@@ -28,55 +31,63 @@ def test_residual_block_class_init():
    assert isinstance(block.fc1, nn.Linear)
    assert isinstance(block.fc2, nn.Linear)


def test_residual_block_forward():
    block = ResidualBlock(512)
    x = torch.randn(1, 50, 512)
    out = block(x)

    assert out.shape == torch.Size([1, 50, 512])


def test_mamba_different_vocab_size():
    model = Mamba(20000, 512, 6)
    x = torch.randint(0, 20000, (1, 50))
    out = model(x)

    assert out.shape == torch.Size([1, 50, 20000])


def test_mamba_different_dim():
    model = Mamba(10000, 1024, 6)
    x = torch.randint(0, 10000, (1, 50))
    out = model(x)

    assert out.shape == torch.Size([1, 50, 10000])


def test_mamba_different_depth():
    model = Mamba(10000, 512, 12)
    x = torch.randint(0, 10000, (1, 50))
    out = model(x)

    assert out.shape == torch.Size([1, 50, 10000])


def test_residual_block_different_dim():
    block = ResidualBlock(1024)
    x = torch.randn(1, 50, 1024)
    out = block(x)

    assert out.shape == torch.Size([1, 50, 1024])


def test_mamba_with_dropout():
    model = Mamba(10000, 512, 6, dropout=0.5)
    x = torch.randint(0, 10000, (1, 50))
    out = model(x)

    assert out.shape == torch.Size([1, 50, 10000])


def test_residual_block_with_dropout():
    block = ResidualBlock(512, dropout=0.5)
    x = torch.randn(1, 50, 512)
    out = block(x)

    assert out.shape == torch.Size([1, 50, 512])


def test_mamba_with_custom_layer():
    class CustomLayer(nn.Module):
        def forward(self, x):
@@ -86,4 +97,4 @@ def forward(self, x):
    x = torch.randint(0, 10000, (1, 50))
    out = model(x)

    assert out.shape == torch.Size([1, 50, 10000])
    assert out.shape == torch.Size([1, 50, 10000])
1 change: 1 addition & 0 deletions zeta/nn/biases/relative_position_bias.py
@@ -6,6 +6,7 @@
import torch
from torch import nn


class RelativePositionBias(nn.Module):
    def __init__(
        self,
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -46,6 +46,7 @@
from zeta.nn.modules.visual_expert import VisualExpert
from zeta.nn.modules.yolo import yolo
from zeta.nn.modules.swiglu import SwiGLU, SwiGLUStacked
from zeta.nn.modules.img_patch_embed import ImgPatchEmbed

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -111,4 +112,5 @@
"AdaptiveLayerNorm",
"SwiGLU",
"SwiGLUStacked",
"ImgPatchEmbed",
]
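With this export and the __all__ entry in place, the class is importable straight from the package namespace; for example:

from zeta.nn.modules import ImgPatchEmbed

patch_embed = ImgPatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)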
45 changes: 45 additions & 0 deletions zeta/nn/modules/img_patch_embed.py
@@ -0,0 +1,45 @@
from torch import nn


class ImgPatchEmbed(nn.Module):
"""patch embedding module
Args:
img_size (int, optional): image size. Defaults to 224.
patch_size (int, optional): patch size. Defaults to 16.
in_chans (int, optional): input channels. Defaults to 3.
embed_dim (int, optional): embedding dimension. Defaults to 768.
Examples:
>>> x = torch.randn(1, 3, 224, 224)
>>> model = ImgPatchEmbed()
>>> model(x).shape
torch.Size([1, 196, 768])
"""

def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
num_patches = (img_size // patch_size) * (img_size // patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches

self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)

    def forward(self, x):
        """Project an image batch into a sequence of patch embeddings.

        Args:
            x (torch.Tensor): input images of shape (B, C, H, W).

        Returns:
            torch.Tensor: patch embeddings of shape (B, num_patches, embed_dim).
        """
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
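As a quick sanity check of the shape arithmetic — the strided Conv2d yields (img_size // patch_size) ** 2 patches, each projected to embed_dim channels — a minimal usage sketch of the module as committed:

import torch

from zeta.nn.modules.img_patch_embed import ImgPatchEmbed

# 224 x 224 input with 16 x 16 patches -> (224 // 16) ** 2 = 196 patches of width 768
model = ImgPatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([2, 196, 768])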
8 changes: 1 addition & 7 deletions zeta/nn/modules/simple_mamba.py
@@ -6,7 +6,6 @@
from typing import Optional, Union



# [HELPERS] ----------------------------------------------------------------------------------------
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
@@ -57,8 +56,6 @@ def forward(self, x):
        return output




class Mamba(nn.Module):
    def __init__(
        self, vocab_size: int = None, dim: int = None, depth: int = None
@@ -98,7 +95,6 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss
        return logits



class MambaBlock(nn.Module):
    def __init__(
        self,
@@ -107,7 +103,7 @@ def __init__(
        depth: int,
        d_state: int = 16,
        expand: int = 2,
        dt_rank: Union[int, str] = 'auto',
        dt_rank: Union[int, str] = "auto",
        d_conv: int = 4,
        conv_bias: bool = True,
        bias: bool = False,
@@ -136,7 +132,6 @@ def __init__(
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(dim_inner))
        self.out_proj = nn.Linear(dim_inner, dim, bias=bias)


    def forward(self, x):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].
@@ -260,4 +255,3 @@ def selective_scan(self, u, delta, A, B, C, D):
        y = y + u * rearrange(D, "d_in -> d_in 1")

        return y
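The RMSNorm and ResidualBlock bodies are collapsed in this diff. For orientation only, a typical RMSNorm matching the signature and the return output line shown above looks roughly like the following sketch (an assumption, not necessarily the exact code in this commit):

import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # learnable per-channel scale, initialised to ones
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # rescale by the reciprocal root-mean-square over the last dimension
        output = (
            x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        )
        return output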

67 changes: 58 additions & 9 deletions zeta/utils/disable_logging.py
@@ -1,13 +1,55 @@
# import logging
# import os
# import warnings


# def disable_warnings_and_logs():
#     """
#     Disables various warnings and logs.
#     """
#     # disable warnings
#     warnings.filterwarnings("ignore")

#     # disable tensorflow warnings
#     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

#     # disable bnb warnings and others
#     logging.getLogger().setLevel(logging.WARNING)

#     class CustomFilter(logging.Filter):
#         def filter(self, record):
#             unwanted_logs = [
#                 "Setting ds_accelerator to mps (auto detect)",
#                 (
#                     "NOTE: Redirects are currently not supported in Windows or"
#                     " MacOs."
#                 ),
#             ]
#             return not any(log in record.getMessage() for log in unwanted_logs)

#     # add custom filter to root logger
#     logger = logging.getLogger()
#     f = CustomFilter()
#     logger.addFilter(f)

#     # disable specific loggers
#     loggers = [
#         "real_accelerator",
#         "torch.distributed.elastic.multiprocessing.redirects",
#     ]

#     for logger_name in loggers:
#         logger = logging.getLogger(logger_name)
#         logger.setLevel(logging.CRITICAL)


import logging
import os
import warnings


def disable_warnings_and_logs():
"""Disable warnings and logs.
Returns:
_type_: _description_
"""
Disables various warnings and logs.
"""
# disable warnings
warnings.filterwarnings("ignore")
Expand All @@ -20,12 +62,19 @@ def disable_warnings_and_logs():

class CustomFilter(logging.Filter):
def filter(self, record):
msg = "Created a temporary directory at"
return msg not in record.getMessage()
unwanted_logs = [
"Setting ds_accelerator to mps (auto detect)",
(
"NOTE: Redirects are currently not supported in Windows or"
" MacOs."
),
]
return not any(log in record.getMessage() for log in unwanted_logs)

# add custom filter to root logger
logger = logging.getLogger()
f = CustomFilter()
logger.addFilter(f)


disable_warnings_and_logs()
# disable all loggers
logging.disable(logging.CRITICAL)
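For context, the helper is intended to be called once, before noisy imports; a minimal sketch (the import path follows the file location above, and calling it at startup is an assumption rather than something this diff shows):

from zeta.utils.disable_logging import disable_warnings_and_logs

# suppress Python warnings, TensorFlow C++ logs, and the filtered logger messages
disable_warnings_and_logs()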
