[FEAT][FractoralNorm]
Kye committed May 12, 2024
1 parent b9abb28 commit 743dbba
Showing 4 changed files with 15 additions and 6 deletions.
12 changes: 9 additions & 3 deletions fractoral_norm.py
@@ -1,10 +1,16 @@
-from zeta.nn import FractoralNorm # Importing the FractoralNorm class from the zeta.nn module
+from zeta.nn import (
+    FractoralNorm,
+)  # Importing the FractoralNorm class from the zeta.nn module
 import torch # Importing the torch module for tensor operations
 
 # Norm
 x = torch.randn(2, 3, 4) # Generating a random tensor of size (2, 3, 4)
 
 # FractoralNorm
-normed = FractoralNorm(4, 4)(x) # Applying the FractoralNorm operation to the tensor x
+normed = FractoralNorm(4, 4)(
+    x
+)  # Applying the FractoralNorm operation to the tensor x
 
-print(normed) # Printing the size of the resulting tensor, which should be torch.Size([2, 3, 4])
+print(
+    normed
+)  # Printing the size of the resulting tensor, which should be torch.Size([2, 3, 4])
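Note: the FractoralNorm module definition itself is not part of this commit; the example above only exercises the FractoralNorm(4, 4) constructor. Below is a minimal, hypothetical sketch of what such a layer might look like, assuming (purely as an assumption, not shown in this diff) that the two arguments are the normalized width and a repeat depth and that the layer simply applies LayerNorm that many times:

# Hypothetical sketch only; the real zeta.nn FractoralNorm may differ.
import torch
from torch import nn, Tensor


class FractoralNormSketch(nn.Module):
    """Stand-in for a FractoralNorm-style layer (assumed behaviour, not from this commit)."""

    def __init__(self, dim: int, depth: int):
        super().__init__()
        # One LayerNorm per level, all normalizing over the last dimension `dim`.
        self.layers = nn.ModuleList([nn.LayerNorm(dim) for _ in range(depth)])

    def forward(self, x: Tensor) -> Tensor:
        # Apply the norms sequentially, mirroring the FractoralNorm(4, 4)(x) call above.
        for layer in self.layers:
            x = layer(x)
        return x


x = torch.randn(2, 3, 4)
print(FractoralNormSketch(4, 4)(x).shape)  # torch.Size([2, 3, 4])

Under that assumption the output keeps the input shape, matching the torch.Size([2, 3, 4]) noted in the example's comments.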
1 change: 1 addition & 0 deletions zeta/nn/modules/__init__.py
@@ -216,6 +216,7 @@
 from zeta.nn.modules.kan import KAN
 from zeta.nn.modules.layer_scale import LayerScale
 from zeta.nn.modules.fractoral_norm import FractoralNorm
+
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
 # from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding
1 change: 1 addition & 0 deletions zeta/nn/modules/feedforward.py
@@ -3,6 +3,7 @@
 from zeta.nn.modules.glu import GLU
 from zeta.nn.modules.swiglu import SwiGLU
 from typing import Optional
+
 # from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton
 

7 changes: 4 additions & 3 deletions zeta/nn/modules/layer_scale.py
@@ -1,7 +1,8 @@
 from torch.nn import Module
-import torch
+import torch
 from torch import nn, Tensor
 
+
 class LayerScale(Module):
     """
     Applies layer scaling to the output of a given module.
@@ -17,7 +18,7 @@ class LayerScale(Module):
"""

def __init__(self, fn: Module, dim, init_value=0.):
def __init__(self, fn: Module, dim, init_value=0.0):
super().__init__()
self.fn = fn
self.gamma = nn.Parameter(torch.ones(dim) * init_value)
@@ -29,4 +30,4 @@ def forward(self, x, **kwargs):
             return out * self.gamma
 
         out, *rest = out
-        return out * self.gamma, *rest
+        return out * self.gamma, *rest
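For reference, a minimal usage sketch of the LayerScale signature shown above, assuming its forward pass calls the wrapped module and scales the result by gamma; the wrapped nn.Linear and the init_value used here are illustrative choices, not taken from the commit:

# Usage sketch based on __init__(self, fn: Module, dim, init_value=0.0) from this diff.
import torch
from torch import nn

from zeta.nn.modules.layer_scale import LayerScale

block = nn.Linear(4, 4)  # any sub-module whose output should be scaled
scaled = LayerScale(block, dim=4, init_value=1e-5)

x = torch.randn(2, 3, 4)
y = scaled(x)  # assumed to compute block(x) * gamma, per the forward shown above
print(y.shape)  # torch.Size([2, 3, 4])

Initializing gamma near zero lets the wrapped block start as a near-identity contribution and learn its scale during training, which is the usual motivation for this kind of layer.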
