From 0e08a62ccbd3a05cd1498fb9c5ba6b97ce2b7e80 Mon Sep 17 00:00:00 2001
From: Kye
Date: Mon, 25 Dec 2023 14:17:00 -0500
Subject: [PATCH] [FEAT][DenseBlock] [DualPathBlock] [FeedbackBlock]
 [HighwayLayer] [MultiScaleBlock] [RecursiveBlock] [SkipConnection]

---
 pyproject.toml                        |  2 +-
 zeta/nn/modules/__init__.py           | 15 ++++++++++++++-
 zeta/nn/modules/dense_connect.py      | 28 ++++++++++++++++++++++++
 zeta/nn/modules/dual_path_block.py    | 27 +++++++++++++++++++++++
 zeta/nn/modules/feedback_block.py     | 31 ++++++++++++++++++++++++++
 zeta/nn/modules/highway_layer.py      | 30 +++++++++++++++++++++++++
 zeta/nn/modules/multi_scale_block.py  | 28 ++++++++++++++++++++++++
 zeta/nn/modules/recursive_block.py    | 32 ++++++++++++++++++++++++++
 zeta/nn/modules/skip_connect.py       | 20 ++++++++++++++++
 zeta/nn/modules/test_dense_connect.py | 40 +++++++++++++++++++++++++++
 10 files changed, 251 insertions(+), 2 deletions(-)
 create mode 100644 zeta/nn/modules/dense_connect.py
 create mode 100644 zeta/nn/modules/dual_path_block.py
 create mode 100644 zeta/nn/modules/feedback_block.py
 create mode 100644 zeta/nn/modules/highway_layer.py
 create mode 100644 zeta/nn/modules/multi_scale_block.py
 create mode 100644 zeta/nn/modules/recursive_block.py
 create mode 100644 zeta/nn/modules/skip_connect.py
 create mode 100644 zeta/nn/modules/test_dense_connect.py

diff --git a/pyproject.toml b/pyproject.toml
index cd888710..c6493559 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "1.2.5"
+version = "1.2.6"
 description = "Transformers at zeta scales"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index 3f33195e..e6dad4b9 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -47,6 +47,13 @@
 from zeta.nn.modules.yolo import yolo
 from zeta.nn.modules.swiglu import SwiGLU, SwiGLUStacked
 from zeta.nn.modules.img_patch_embed import ImgPatchEmbed
+from zeta.nn.modules.dense_connect import DenseBlock
+from zeta.nn.modules.highway_layer import HighwayLayer
+from zeta.nn.modules.multi_scale_block import MultiScaleBlock
+from zeta.nn.modules.feedback_block import FeedbackBlock
+from zeta.nn.modules.dual_path_block import DualPathBlock
+from zeta.nn.modules.recursive_block import RecursiveBlock
+from zeta.nn.modules.skip_connect import SkipConnection

 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -60,7 +67,6 @@
 # from zeta.nn.modules.transformations import image_transform
 # from zeta.nn.modules.squeeze_excitation import SqueezeExcitation
 # from zeta.nn.modules.clex import Clex
-
 __all__ = [
     "CNNNew",
     "CombinedLinear",
@@ -113,4 +119,11 @@
     "SwiGLU",
     "SwiGLUStacked",
     "ImgPatchEmbed",
+    "DenseBlock",
+    "HighwayLayer",
+    "MultiScaleBlock",
+    "FeedbackBlock",
+    "DualPathBlock",
+    "RecursiveBlock",
+    "SkipConnection",
 ]
diff --git a/zeta/nn/modules/dense_connect.py b/zeta/nn/modules/dense_connect.py
new file mode 100644
index 00000000..ce1c2923
--- /dev/null
+++ b/zeta/nn/modules/dense_connect.py
@@ -0,0 +1,28 @@
+import torch
+from torch import nn
+
+
+class DenseBlock(nn.Module):
+    def __init__(self, submodule, *args, **kwargs):
+        """
+        Initializes a DenseBlock module.
+
+        Args:
+            submodule (nn.Module): The submodule to be applied in the forward pass.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        super().__init__()
+        self.submodule = submodule
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the DenseBlock module.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying the DenseBlock operation.
+        """
+        return torch.cat([x, self.submodule(x)], dim=1)
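A minimal usage sketch of DenseBlock as defined above; the Linear submodule and the (32, 10) input shape are chosen purely for illustration. Any submodule whose output can be concatenated with its input along dim=1 works:

    import torch
    from torch import nn
    from zeta.nn.modules.dense_connect import DenseBlock

    # DenseBlock concatenates the input with the submodule's output along dim=1.
    block = DenseBlock(nn.Linear(10, 5))
    x = torch.randn(32, 10)
    out = block(x)
    print(out.shape)  # torch.Size([32, 15]): 10 input features + 5 projected
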
diff --git a/zeta/nn/modules/dual_path_block.py b/zeta/nn/modules/dual_path_block.py
new file mode 100644
index 00000000..1d9241c9
--- /dev/null
+++ b/zeta/nn/modules/dual_path_block.py
@@ -0,0 +1,27 @@
+from torch import nn
+
+
+class DualPathBlock(nn.Module):
+    def __init__(self, submodule1, submodule2):
+        """
+        DualPathBlock is a module that combines the outputs of two submodules by element-wise addition.
+
+        Args:
+            submodule1 (nn.Module): The first submodule.
+            submodule2 (nn.Module): The second submodule.
+        """
+        super().__init__()
+        self.submodule1 = submodule1
+        self.submodule2 = submodule2
+
+    def forward(self, x):
+        """
+        Forward pass of the DualPathBlock.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor obtained by adding the outputs of submodule1 and submodule2.
+        """
+        return self.submodule1(x) + self.submodule2(x)
diff --git a/zeta/nn/modules/feedback_block.py b/zeta/nn/modules/feedback_block.py
new file mode 100644
index 00000000..82fa4dd0
--- /dev/null
+++ b/zeta/nn/modules/feedback_block.py
@@ -0,0 +1,31 @@
+import torch
+from torch import nn
+
+
+class FeedbackBlock(nn.Module):
+    def __init__(self, submodule):
+        """
+        Initializes a FeedbackBlock module.
+
+        Args:
+            submodule (nn.Module): The submodule to be used within the FeedbackBlock.
+        """
+        super().__init__()
+        self.submodule = submodule
+
+    def forward(self, x: torch.Tensor, feedback=None, *args, **kwargs):
+        """
+        Performs a forward pass through the FeedbackBlock.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+            feedback (torch.Tensor, optional): The feedback tensor, added to x when provided.
+            *args: Additional positional arguments to be passed to the submodule's forward method.
+            **kwargs: Additional keyword arguments to be passed to the submodule's forward method.
+
+        Returns:
+            torch.Tensor: The output tensor after passing through the FeedbackBlock.
+        """
+        if feedback is not None:
+            x = x + feedback
+        return self.submodule(x, *args, **kwargs)
diff --git a/zeta/nn/modules/highway_layer.py b/zeta/nn/modules/highway_layer.py
new file mode 100644
index 00000000..3802f3e2
--- /dev/null
+++ b/zeta/nn/modules/highway_layer.py
@@ -0,0 +1,30 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class HighwayLayer(nn.Module):
+    def __init__(self, dim):
+        """
+        Initializes a HighwayLayer instance.
+
+        Args:
+            dim (int): The input and output dimension of the layer.
+        """
+        super().__init__()
+        self.normal_layer = nn.Linear(dim, dim)
+        self.gate = nn.Linear(dim, dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Performs a forward pass through the HighwayLayer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
+        normal_result = F.relu(self.normal_layer(x))
+        gate = torch.sigmoid(self.gate(x))
+        return gate * normal_result + (1 - gate) * x
diff --git a/zeta/nn/modules/multi_scale_block.py b/zeta/nn/modules/multi_scale_block.py
new file mode 100644
index 00000000..fc686e2a
--- /dev/null
+++ b/zeta/nn/modules/multi_scale_block.py
@@ -0,0 +1,28 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class MultiScaleBlock(nn.Module):
+    """
+    A module that applies a given submodule to the input tensor at multiple scales and sums the results.
+
+    Args:
+        module (nn.Module): The submodule to apply at each scale.
+
+    Returns:
+        torch.Tensor: The output tensor after applying the submodule at multiple scales.
+    """
+
+    def __init__(self, module):
+        super().__init__()
+        self.submodule = module
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        x1 = F.interpolate(x, scale_factor=0.5, *args, **kwargs)
+        x2 = F.interpolate(x, scale_factor=2.0, *args, **kwargs)
+        return (
+            self.submodule(x)
+            + F.interpolate(self.submodule(x1), size=x.shape[2:])
+            + F.interpolate(self.submodule(x2), size=x.shape[2:])
+        )
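A minimal usage sketch of MultiScaleBlock as defined above, assuming an illustrative channel-preserving Conv2d: the submodule runs on the input at full, half, and double resolution, the off-scale results are interpolated back to the input's spatial size, and the three branches are summed. The wrapped module must therefore keep the channel count unchanged, and the input must be spatial (4D here) for F.interpolate to accept a scale_factor:

    import torch
    from torch import nn
    from zeta.nn.modules.multi_scale_block import MultiScaleBlock

    # Channel-preserving conv so the 1x, 0.5x, and 2x branches can be summed.
    block = MultiScaleBlock(nn.Conv2d(3, 3, kernel_size=3, padding=1))
    x = torch.randn(1, 3, 32, 32)
    out = block(x)
    print(out.shape)  # torch.Size([1, 3, 32, 32])
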
diff --git a/zeta/nn/modules/recursive_block.py b/zeta/nn/modules/recursive_block.py
new file mode 100644
index 00000000..f1ab54de
--- /dev/null
+++ b/zeta/nn/modules/recursive_block.py
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+
+
+class RecursiveBlock(nn.Module):
+    def __init__(self, module, iters, *args, **kwargs):
+        """
+        Initializes a RecursiveBlock module.
+
+        Args:
+            module (nn.Module): The module to be applied recursively; named "module" so it does not shadow nn.Module.modules().
+            iters (int): The number of iterations to apply the module.
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+        """
+        super().__init__()
+        self.module = module
+        self.iters = iters
+
+    def forward(self, x: torch.Tensor):
+        """
+        Forward pass of the RecursiveBlock module.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying the module recursively.
+        """
+        for _ in range(self.iters):
+            x = self.module(x)
+        return x
diff --git a/zeta/nn/modules/skip_connect.py b/zeta/nn/modules/skip_connect.py
new file mode 100644
index 00000000..21d4c50b
--- /dev/null
+++ b/zeta/nn/modules/skip_connect.py
@@ -0,0 +1,20 @@
+import torch
+from torch import nn
+
+
+class SkipConnection(nn.Module):
+    def __init__(self, submodule):
+        super().__init__()
+        self.submodule = submodule
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the SkipConnection module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after adding the input tensor to the submodule output.
+        """
+        return x + self.submodule(x)
diff --git a/zeta/nn/modules/test_dense_connect.py b/zeta/nn/modules/test_dense_connect.py
new file mode 100644
index 00000000..0cf6d5d8
--- /dev/null
+++ b/zeta/nn/modules/test_dense_connect.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+import unittest
+
+from zeta.nn.modules.dense_connect import DenseBlock
+
+
+class DenseBlockTestCase(unittest.TestCase):
+    def setUp(self):
+        self.submodule = nn.Linear(10, 5)
+        self.dense_block = DenseBlock(self.submodule)
+
+    def test_forward(self):
+        x = torch.randn(32, 10)
+        output = self.dense_block(x)
+
+        self.assertEqual(output.shape, (32, 15))  # Check output shape
+        self.assertTrue(
+            torch.allclose(output[:, :10], x)
+        )  # Check if input is preserved
+        self.assertTrue(
+            torch.allclose(output[:, 10:], self.submodule(x))
+        )  # Check submodule output
+
+    def test_initialization(self):
+        self.assertEqual(
+            self.dense_block.submodule, self.submodule
+        )  # Check submodule assignment
+
+    def test_docstrings(self):
+        self.assertIsNotNone(
+            DenseBlock.__init__.__doc__
+        )  # Check if __init__ has a docstring
+        self.assertIsNotNone(
+            DenseBlock.forward.__doc__
+        )  # Check if forward has a docstring
+
+
+if __name__ == "__main__":
+    unittest.main()
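A minimal usage sketch of RecursiveBlock and SkipConnection as defined above; the Linear modules and shapes are illustrative. The module wrapped by RecursiveBlock must map its input space to itself so that repeated application is well-defined:

    import torch
    from torch import nn
    from zeta.nn.modules.recursive_block import RecursiveBlock
    from zeta.nn.modules.skip_connect import SkipConnection

    x = torch.randn(2, 10)

    # Weight-tied repetition: the same Linear is applied three times in sequence.
    recursive = RecursiveBlock(nn.Linear(10, 10), iters=3)
    print(recursive(x).shape)  # torch.Size([2, 10])

    # Residual wrapper: the submodule's output is added back onto its input.
    skip = SkipConnection(nn.Linear(10, 10))
    print(skip(x).shape)  # torch.Size([2, 10])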