from torch import nn, Tensor


class FractoralNorm(nn.Module):
    """
    Applies ``nn.LayerNorm`` to the input tensor ``depth`` times in a row.

    Each application uses its own ``nn.LayerNorm`` instance, so the affine
    parameters (weight/bias) of every step are learned independently.

    NOTE(review): "Fractoral" looks like a misspelling of "Fractal"; the
    class name is kept unchanged for backward compatibility with callers.

    Args:
        num_features (int): Size of the last dimension of the input tensor
            (the normalized shape passed to each ``nn.LayerNorm``).
        depth (int): Number of times to apply LayerNorm. Must be >= 1.
        **kwargs: Extra keyword arguments forwarded to every
            ``nn.LayerNorm`` (e.g. ``eps``, ``elementwise_affine``).

    Raises:
        ValueError: If ``depth`` is less than 1 (a zero-depth module would
            silently act as the identity with no parameters).
    """

    def __init__(self, num_features: int, depth: int, **kwargs):
        super().__init__()

        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        # One independent LayerNorm per application step.
        self.layers = nn.ModuleList(
            [nn.LayerNorm(num_features, **kwargs) for _ in range(depth)]
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply each LayerNorm in sequence.

        Args:
            x (Tensor): Input tensor whose last dimension has size
                ``num_features``.

        Returns:
            Tensor: Tensor of the same shape as ``x``, normalized
                ``depth`` times.
        """
        for layer in self.layers:
            x = layer(x)
        return x