Skip to content

Commit

Permalink
[FEAT][FractoralNorm]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed May 7, 2024
1 parent af1b95c commit 2e6e0b6
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions zeta/nn/modules/fractoral_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from torch import nn, Tensor


class FractoralNorm(nn.Module):
"""
FractoralNorm module applies LayerNorm to the input tensor multiple times in a row.
Args:
num_features (int): Number of features in the input tensor.
depth (int): Number of times to apply LayerNorm.
"""

def __init__(self, num_features: int, depth: int):
super().__init__()

self.layers = nn.ModuleList(
[nn.LayerNorm(num_features) for _ in range(depth)]
)

def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the FractoralNorm module.
Args:
x (Tensor): Input tensor.
Returns:
Tensor: Output tensor after applying LayerNorm multiple times.
"""
for layer in self.layers:
x = layer(x)
return x

0 comments on commit 2e6e0b6

Please sign in to comment.