Skip to content

Commit

Permalink
[FEAT][DenseBlock] [DualPathBlock] [FeedbackBlock] [HighwayLayer] [Mu…
Browse files Browse the repository at this point in the history
…ltiScaleBlock] [RecursiveBlock] [SkipConnection]
  • Loading branch information
Kye committed Dec 25, 2023
1 parent 5c5ad27 commit 0e08a62
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "1.2.5"
version = "1.2.6"
description = "Transformers at zeta scales"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
Expand Down
13 changes: 12 additions & 1 deletion zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@
from zeta.nn.modules.yolo import yolo
from zeta.nn.modules.swiglu import SwiGLU, SwiGLUStacked
from zeta.nn.modules.img_patch_embed import ImgPatchEmbed
from zeta.nn.modules.dense_connect import DenseBlock
from zeta.nn.modules.highway_layer import HighwayLayer
from zeta.nn.modules.multi_scale_block import MultiScaleBlock
from zeta.nn.modules.feedback_block import FeedbackBlock
from zeta.nn.modules.dual_path_block import DualPathBlock
from zeta.nn.modules.recursive_block import RecursiveBlock

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
Expand All @@ -60,7 +66,6 @@
# from zeta.nn.modules.transformations import image_transform
# from zeta.nn.modules.squeeze_excitation import SqueezeExcitation
# from zeta.nn.modules.clex import Clex

__all__ = [
"CNNNew",
"CombinedLinear",
Expand Down Expand Up @@ -113,4 +118,10 @@
"SwiGLU",
"SwiGLUStacked",
"ImgPatchEmbed",
"DenseBlock",
"HighwayLayer",
"MultiScaleBlock",
"FeedbackBlock",
"DualPathBlock",
"RecursiveBlock",
]
28 changes: 28 additions & 0 deletions zeta/nn/modules/dense_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
from torch import nn


class DenseBlock(nn.Module):
    """Dense (concatenative) residual wrapper, DenseNet-style.

    Runs the wrapped submodule on the input and concatenates the result
    onto the original input along the channel dimension (dim=1).
    """

    def __init__(self, submodule, *args, **kwargs):
        """Store the wrapped submodule.

        Args:
            submodule (nn.Module): Module applied in the forward pass.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.
        """
        super().__init__()
        self.submodule = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` concatenated with ``submodule(x)`` along dim 1.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Concatenation of input and submodule output.
        """
        transformed = self.submodule(x)
        return torch.cat((x, transformed), dim=1)
27 changes: 27 additions & 0 deletions zeta/nn/modules/dual_path_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from torch import nn


class DualPathBlock(nn.Module):
    """Combines two parallel submodules by element-wise addition."""

    def __init__(self, submodule1, submodule2):
        """Store the two parallel submodules.

        Args:
            submodule1 (nn.Module): The first submodule.
            submodule2 (nn.Module): The second submodule.
        """
        super().__init__()
        self.submodule1 = submodule1
        self.submodule2 = submodule2

    def forward(self, x):
        """Run both submodules on ``x`` and sum their outputs.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Element-wise sum of the two submodule outputs.
        """
        out_a = self.submodule1(x)
        out_b = self.submodule2(x)
        return out_a + out_b
31 changes: 31 additions & 0 deletions zeta/nn/modules/feedback_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
from torch import nn


class FeedbackBlock(nn.Module):
    """Wraps a submodule and additively injects an optional feedback signal."""

    def __init__(self, submodule):
        """
        Initializes a FeedbackBlock module.

        Args:
            submodule (nn.Module): The submodule to be used within the FeedbackBlock.
        """
        super().__init__()
        self.submodule = submodule

    def forward(self, x: torch.Tensor, feedback=None, *args, **kwargs):
        """
        Performs a forward pass through the FeedbackBlock.

        Args:
            x (torch.Tensor): The input tensor.
            feedback (torch.Tensor, optional): Feedback tensor added to ``x``
                before the submodule runs. Defaults to None (no feedback),
                so feedback-less callers no longer need to pass None explicitly.
            *args: Additional positional arguments forwarded to the submodule.
            **kwargs: Additional keyword arguments forwarded to the submodule.

        Returns:
            torch.Tensor: The output tensor after passing through the FeedbackBlock.
        """
        # Inject the feedback additively before the wrapped module runs.
        # Shapes are assumed broadcastable — TODO confirm against callers.
        if feedback is not None:
            x = x + feedback
        return self.submodule(x, *args, **kwargs)
30 changes: 30 additions & 0 deletions zeta/nn/modules/highway_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
from torch import nn
import torch.nn.functional as F


class HighwayLayer(nn.Module):
    """Highway layer: a learned sigmoid gate blends a transform with identity."""

    def __init__(self, dim):
        """
        Initializes a HighwayLayer instance.

        Args:
            dim (int): The input and output dimension of the layer.
        """
        super().__init__()
        self.normal_layer = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blend the transformed and raw input via the gate.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: ``gate * relu(W x) + (1 - gate) * x``.
        """
        transformed = F.relu(self.normal_layer(x))
        mix = torch.sigmoid(self.gate(x))
        # mix == 1 -> fully transformed path; mix == 0 -> pure identity path.
        return mix * transformed + (1 - mix) * x
28 changes: 28 additions & 0 deletions zeta/nn/modules/multi_scale_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
from torch import nn
import torch.nn.functional as F


class MultiScaleBlock(nn.Module):
    """
    Applies a submodule at three scales (native, 0.5x, 2x) and sums the results.

    The half- and double-resolution outputs are resized back to the input's
    spatial size before summation.

    Args:
        module (nn.Module): The submodule to apply at each scale.
    """

    def __init__(self, module):
        super().__init__()
        self.submodule = module

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Sum the submodule's outputs at native, half, and double scale.

        ``*args``/``**kwargs`` are forwarded to ``F.interpolate`` for the
        initial rescaling (e.g. ``mode=``, ``align_corners=``).
        """
        # Assumes x has spatial dims at index 2+ (e.g. NCHW) — required by
        # F.interpolate with a scale_factor.
        spatial_size = x.shape[2:]
        half = F.interpolate(x, *args, scale_factor=0.5, **kwargs)
        double = F.interpolate(x, *args, scale_factor=2.0, **kwargs)
        native_out = self.submodule(x)
        half_out = F.interpolate(self.submodule(half), size=spatial_size)
        double_out = F.interpolate(self.submodule(double), size=spatial_size)
        return native_out + half_out + double_out
32 changes: 32 additions & 0 deletions zeta/nn/modules/recursive_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
from torch import nn


class RecursiveBlock(nn.Module):
    """Applies one module to the input repeatedly, feeding output back as input."""

    def __init__(self, modules, iters, *args, **kwargs):
        """
        Initializes a RecursiveBlock module.

        Args:
            modules (nn.Module): The module to be applied recursively.
            iters (int): The number of iterations to apply the module.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.
        """
        super().__init__()
        # BUG FIX: the attribute was previously named ``self.modules``, which
        # collides with nn.Module.modules(). When an nn.Module is assigned,
        # nn.Module.__setattr__ stores it in self._modules, so ``self.modules``
        # still resolves to the built-in modules() method and calling
        # ``self.modules(x)`` in forward raised TypeError. Store under a
        # non-colliding name instead (the parameter name is kept for callers).
        self.module = modules
        self.iters = iters

    def forward(self, x: torch.Tensor):
        """
        Forward pass of the RecursiveBlock module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying the module
            ``iters`` times in sequence.
        """
        for _ in range(self.iters):
            x = self.module(x)
        return x
20 changes: 20 additions & 0 deletions zeta/nn/modules/skip_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from torch import nn


class SkipConnection(nn.Module):
    """Residual wrapper: adds the submodule's output back onto its input."""

    def __init__(self, submodule):
        """
        Args:
            submodule (nn.Module): Module whose output is added to the input.
        """
        super().__init__()
        self.submodule = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the SkipConnection module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x + submodule(x)``.
        """
        residual = self.submodule(x)
        return x + residual
40 changes: 40 additions & 0 deletions zeta/nn/modules/test_dense_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import unittest

import torch
import torch.nn as nn

# BUG FIX: the import used the placeholder ``your_module``; DenseBlock lives
# in zeta/nn/modules/dense_connect.py (this same commit adds that file).
from zeta.nn.modules.dense_connect import DenseBlock


class DenseBlockTestCase(unittest.TestCase):
    """Unit tests for the DenseBlock concatenative wrapper."""

    def setUp(self):
        self.submodule = nn.Linear(10, 5)
        self.dense_block = DenseBlock(self.submodule)

    def test_forward(self):
        x = torch.randn(32, 10)
        output = self.dense_block(x)

        self.assertEqual(output.shape, (32, 15))  # Check output shape
        self.assertTrue(
            torch.allclose(output[:, :10], x)
        )  # Check if input is preserved
        self.assertTrue(
            torch.allclose(output[:, 10:], self.submodule(x))
        )  # Check submodule output

    def test_initialization(self):
        self.assertEqual(
            self.dense_block.submodule, self.submodule
        )  # Check submodule assignment

    def test_docstrings(self):
        self.assertIsNotNone(
            DenseBlock.__init__.__doc__
        )  # Check if __init__ has a docstring
        self.assertIsNotNone(
            DenseBlock.forward.__doc__
        )  # Check if forward has a docstring


if __name__ == "__main__":
    unittest.main()

0 comments on commit 0e08a62

Please sign in to comment.