From 743dbbaf06d083778f901c9df451b70f855119ad Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 11 May 2024 22:56:23 -0700 Subject: [PATCH] [FEAT][FractoralNorm] --- fractoral_norm.py | 12 +++++++++--- zeta/nn/modules/__init__.py | 1 + zeta/nn/modules/feedforward.py | 1 + zeta/nn/modules/layer_scale.py | 7 ++++--- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/fractoral_norm.py b/fractoral_norm.py index 832509e5..e9720a5a 100644 --- a/fractoral_norm.py +++ b/fractoral_norm.py @@ -1,10 +1,16 @@ -from zeta.nn import FractoralNorm # Importing the FractoralNorm class from the zeta.nn module +from zeta.nn import ( + FractoralNorm, +) # Importing the FractoralNorm class from the zeta.nn module import torch # Importing the torch module for tensor operations # Norm x = torch.randn(2, 3, 4) # Generating a random tensor of size (2, 3, 4) # FractoralNorm -normed = FractoralNorm(4, 4)(x) # Applying the FractoralNorm operation to the tensor x +normed = FractoralNorm(4, 4)( + x +) # Applying the FractoralNorm operation to the tensor x -print(normed) # Printing the size of the resulting tensor, which should be torch.Size([2, 3, 4]) \ No newline at end of file +print( + normed +) # Printing the size of the resulting tensor, which should be torch.Size([2, 3, 4]) diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index fc2bf595..639cfc9f 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -216,6 +216,7 @@ from zeta.nn.modules.kan import KAN from zeta.nn.modules.layer_scale import LayerScale from zeta.nn.modules.fractoral_norm import FractoralNorm + # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features # from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding diff --git a/zeta/nn/modules/feedforward.py b/zeta/nn/modules/feedforward.py index bee66c71..18925ff2 100644 --- a/zeta/nn/modules/feedforward.py +++ b/zeta/nn/modules/feedforward.py @@ -3,6 
+3,7 @@ from zeta.nn.modules.glu import GLU from zeta.nn.modules.swiglu import SwiGLU from typing import Optional + # from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton diff --git a/zeta/nn/modules/layer_scale.py b/zeta/nn/modules/layer_scale.py index 58e5083c..6552394a 100644 --- a/zeta/nn/modules/layer_scale.py +++ b/zeta/nn/modules/layer_scale.py @@ -1,7 +1,8 @@ from torch.nn import Module -import torch +import torch from torch import nn, Tensor + class LayerScale(Module): """ Applies layer scaling to the output of a given module. @@ -17,7 +18,7 @@ class LayerScale(Module): """ - def __init__(self, fn: Module, dim, init_value=0.): + def __init__(self, fn: Module, dim, init_value=0.0): super().__init__() self.fn = fn self.gamma = nn.Parameter(torch.ones(dim) * init_value) @@ -29,4 +30,4 @@ def forward(self, x, **kwargs): return out * self.gamma out, *rest = out - return out * self.gamma, *rest \ No newline at end of file + return out * self.gamma, *rest