-
-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT][DenseBlock] [DualPathBlock] [FeedbackBlock] [HighwayLayer] [MultiScaleBlock] [RecursiveBlock] [SkipConnection]
- Loading branch information
Kye
committed
Dec 25, 2023
1 parent
5c5ad27
commit 0e08a62
Showing
10 changed files
with
249 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "zetascale" | ||
version = "1.2.5" | ||
version = "1.2.6" | ||
description = "Transformers at zeta scales" | ||
authors = ["Zeta Team <[email protected]>"] | ||
license = "MIT" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class DenseBlock(nn.Module):
    """Wrap a submodule and concatenate its output with its own input.

    The forward pass returns ``torch.cat([x, submodule(x)], dim=1)`` — the
    channel-wise concatenation used in densely connected architectures, so
    the submodule's output must match ``x`` on every dimension except dim 1.
    """

    def __init__(self, submodule, *args, **kwargs):
        """Store the wrapped submodule.

        Args:
            submodule (nn.Module): Module applied to the input in forward().
            *args: Accepted for signature compatibility; unused.
            **kwargs: Accepted for signature compatibility; unused.
        """
        super().__init__()
        self.submodule = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the input concatenated with the submodule output along dim 1.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``cat([x, submodule(x)], dim=1)``.
        """
        transformed = self.submodule(x)
        return torch.cat((x, transformed), dim=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from torch import nn | ||
|
||
|
||
class DualPathBlock(nn.Module):
    """Run two submodules on the same input and sum their outputs element-wise."""

    def __init__(self, submodule1, submodule2):
        """Store both branches.

        Args:
            submodule1 (nn.Module): First branch.
            submodule2 (nn.Module): Second branch; its output must be
                broadcast-compatible with the first branch's output.
        """
        super().__init__()
        self.submodule1 = submodule1
        self.submodule2 = submodule2

    def forward(self, x):
        """Return ``submodule1(x) + submodule2(x)``.

        Args:
            x (torch.Tensor): Input tensor fed to both branches.

        Returns:
            torch.Tensor: Element-wise sum of the two branch outputs.
        """
        left = self.submodule1(x)
        right = self.submodule2(x)
        return left + right
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class FeedbackBlock(nn.Module):
    """Add an optional feedback tensor to the input before a submodule runs."""

    def __init__(self, submodule):
        """Store the wrapped submodule.

        Args:
            submodule (nn.Module): Module invoked on the (possibly
                feedback-augmented) input.
        """
        super().__init__()
        self.submodule = submodule

    def forward(self, x: torch.Tensor, feedback, *args, **kwargs):
        """Apply the submodule to ``x`` (plus ``feedback`` when given).

        Args:
            x (torch.Tensor): Input tensor.
            feedback: Tensor added to ``x`` before the submodule, or None
                to skip the feedback path entirely.
            *args: Forwarded to the submodule.
            **kwargs: Forwarded to the submodule.

        Returns:
            torch.Tensor: The submodule's output.
        """
        augmented = x if feedback is None else x + feedback
        return self.submodule(augmented, *args, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class HighwayLayer(nn.Module):
    """Highway layer: a learned gate blends a transformed path with identity.

    Computes ``g * relu(W_h x) + (1 - g) * x`` where ``g = sigmoid(W_g x)``,
    so the layer can interpolate between transforming and passing through.
    """

    def __init__(self, dim):
        """Create the transform and gate projections.

        Args:
            dim (int): Input and output feature dimension (both linear
                layers map dim -> dim).
        """
        super().__init__()
        self.normal_layer = nn.Linear(dim, dim)
        self.gate = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Blend the transformed input with the raw input via the gate.

        Args:
            x (torch.Tensor): Input tensor of feature size ``dim``.

        Returns:
            torch.Tensor: Gated combination, same shape as ``x``.
        """
        transformed = F.relu(self.normal_layer(x))
        carry = torch.sigmoid(self.gate(x))
        return carry * transformed + (1 - carry) * x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class MultiScaleBlock(nn.Module):
    """Apply a submodule at 1x, 0.5x, and 2x spatial scale and sum the results.

    The half- and double-scale outputs are resized back to the input's
    spatial size before summation, so the submodule is assumed to preserve
    spatial dimensions — TODO confirm with callers.
    """

    def __init__(self, module):
        """Store the submodule applied at each scale.

        Args:
            module (nn.Module): The submodule to apply.
        """
        super().__init__()
        self.submodule = module

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Return the sum of submodule outputs over three scales.

        Args:
            x (torch.Tensor): Input tensor with spatial dims at ``shape[2:]``
                (i.e. at least 3-D, e.g. NCHW).
            *args: Forwarded to ``F.interpolate`` when rescaling the input.
            **kwargs: Forwarded to ``F.interpolate`` when rescaling the input.

        Returns:
            torch.Tensor: Sum of the three scale branches, same spatial size
            as ``x``.
        """
        # NOTE(review): *args/**kwargs are forwarded into F.interpolate, not
        # into the submodule — positional args could collide with `size`.
        half = F.interpolate(x, scale_factor=0.5, *args, **kwargs)
        double = F.interpolate(x, scale_factor=2.0, *args, **kwargs)
        out = self.submodule(x)
        out = out + F.interpolate(self.submodule(half), size=x.shape[2:])
        out = out + F.interpolate(self.submodule(double), size=x.shape[2:])
        return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class RecursiveBlock(nn.Module):
    """Apply one module repeatedly, feeding each output back in as input."""

    def __init__(self, modules, iters, *args, **kwargs):
        """Initializes a RecursiveBlock module.

        Args:
            modules (nn.Module): The module to be applied recursively.
            iters (int): The number of iterations to apply the module.
            *args: Variable length argument list (unused).
            **kwargs: Arbitrary keyword arguments (unused).
        """
        super().__init__()
        # BUG FIX: the attribute was previously named ``self.modules``, which
        # collides with the built-in nn.Module.modules() method. nn.Module
        # registers child modules in self._modules, so attribute lookup still
        # found the bound *method* and forward() raised TypeError when calling
        # self.modules(x). Storing under ``module`` avoids the collision.
        self.module = modules
        self.iters = iters

    def forward(self, x: torch.Tensor):
        """Apply the stored module ``iters`` times in sequence.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output after ``iters`` applications.
        """
        for _ in range(self.iters):
            x = self.module(x)
        return x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class SkipConnection(nn.Module):
    """Residual wrapper: forward returns ``x + submodule(x)``."""

    def __init__(self, submodule):
        """Store the wrapped submodule.

        Args:
            submodule (nn.Module): Module whose output is added to the input;
                its output must be broadcast-compatible with ``x``.
        """
        super().__init__()
        self.submodule = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the submodule's output to the input (residual connection).

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: ``x + submodule(x)``.
        """
        residual = self.submodule(x)
        return x + residual
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
import torch.nn as nn | ||
import unittest | ||
|
||
from your_module import DenseBlock | ||
|
||
|
||
class DenseBlockTestCase(unittest.TestCase):
    """Unit tests for the DenseBlock wrapper."""

    def setUp(self):
        self.submodule = nn.Linear(10, 5)
        self.dense_block = DenseBlock(self.submodule)

    def test_forward(self):
        batch = torch.randn(32, 10)
        result = self.dense_block(batch)

        # Output width is input width plus submodule width (10 + 5).
        self.assertEqual(result.shape, (32, 15))
        # The leading columns must be the untouched input.
        self.assertTrue(torch.allclose(result[:, :10], batch))
        # The trailing columns must be the submodule's output.
        self.assertTrue(torch.allclose(result[:, 10:], self.submodule(batch)))

    def test_initialization(self):
        # The wrapper must hold the exact submodule it was given.
        self.assertEqual(self.dense_block.submodule, self.submodule)

    def test_docstrings(self):
        # Both public methods are expected to carry docstrings.
        for method in (DenseBlock.__init__, DenseBlock.forward):
            self.assertIsNotNone(method.__doc__)
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()