Skip to content

Commit

Permalink
[CLEANUP]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Feb 23, 2024
1 parent 0114dc2 commit 1f8d0fd
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.1.6"
version = "2.1.7"
description = "Transformers at zeta scales"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
Expand Down
1 change: 1 addition & 0 deletions zeta/nn/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
from zeta.nn.modules.ws_conv2d import WSConv2d
from zeta.nn.modules.yolo import yolo
from zeta.nn.modules.palo_ldp import PaloLDP

# from zeta.nn.modules.g_shard_moe import (
# Top1Gate,
# Top2Gate,
Expand Down
7 changes: 2 additions & 5 deletions zeta/nn/modules/palo_ldp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

from torch import Tensor, nn
from zeta.utils.log_pytorch_op import log_torch_op


class PaloLDP(nn.Module):
"""
Implementation of the PaloLDP module.
Expand Down Expand Up @@ -75,7 +75,6 @@ def forward(self, x: Tensor) -> Tensor:
x = self.pointwise_conv(x)
print(x.shape) # torch.Size([2, 1, 4, 4]


# Depthwise convolution with 1 stride
x = self.depthwise_conv(x)
print(x.shape)
Expand All @@ -89,7 +88,7 @@ def forward(self, x: Tensor) -> Tensor:
print(x.shape)

# Norm
x = self.norm(x) #+ skip
x = self.norm(x) # + skip
print(x.shape)

# Depthwise convolution with 2 stride
Expand All @@ -109,5 +108,3 @@ def forward(self, x: Tensor) -> Tensor:
x = nn.LayerNorm(w)(x)

return x


2 changes: 1 addition & 1 deletion zeta/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,5 @@
"check_cuda",
"VerboseExecution",
"seek_all_images",
"log_torch_op"
"log_torch_op",
]
5 changes: 3 additions & 2 deletions zeta/utils/log_pytorch_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
format="<green>{time}</green> <level>{message}</level>",
backtrace=True,
diagnose=True,
enqueue = True,
catch = True,
enqueue=True,
catch=True,
)


Expand All @@ -39,6 +39,7 @@ def log_torch_op(
Returns:
function: The decorated function.
"""

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand Down

0 comments on commit 1f8d0fd

Please sign in to comment.