
Commit

[CLEANUP]
Kye committed Feb 5, 2024
2 parents 0ba4110 + 2e50e6f · commit 355b270
Showing 13 changed files with 92 additions and 28 deletions.
2 changes: 1 addition & 1 deletion docs/zeta/nn/modules/dynamicroutingblock.md
@@ -74,7 +74,7 @@ drb = DynamicRoutingBlock(sb1, sb2, routing_module)
The input can be passed to this block to yield the output:

```python
-x = torch.randn(10, 5)
+x = torch.randn(3, 5)
y = drb(x)
```
In the process, the dynamic routing block has learned to route between `sb1` and `sb2` depending on `routing_module`'s weights, allowing the module to discover which sub-block is more 'helpful' for any given input.
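For context, the snippet above assumes `sb1`, `sb2`, and `routing_module` are already defined. A minimal, self-contained sketch follows; the concrete sub-block and routing-module definitions (and the import path) are illustrative assumptions inferred from the constructor call shown here, not the documented API.

```python
import torch
from torch import nn

from zeta.nn.modules import DynamicRoutingBlock  # import path assumed

# Two candidate sub-blocks over 5-dimensional features (illustrative choices).
sb1 = nn.Linear(5, 5)
sb2 = nn.Linear(5, 5)

# Hypothetical routing module: produces a per-sample gate used to weight sb1 vs. sb2.
routing_module = nn.Sequential(nn.Linear(5, 1), nn.Sigmoid())

drb = DynamicRoutingBlock(sb1, sb2, routing_module)

x = torch.randn(3, 5)
y = drb(x)  # output routed between sb1(x) and sb2(x) by the routing weights
```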
2 changes: 2 additions & 0 deletions tests/cloud/test_main.py
@@ -1,3 +1,5 @@
"""Test cases for the main module of the cloud package."""

import pytest
from unittest.mock import MagicMock, patch
from zeta.cloud.main import zetacloud
17 changes: 17 additions & 0 deletions tests/nn/attentions/test_attend.py
@@ -1,3 +1,5 @@
""" Test cases for the Attend module. """

import torch
from zeta.nn.attention.attend import Attend

@@ -121,6 +123,21 @@ def test_attend_flash_attention():
assert out.shape == (1, 8, 32, 64)


# Test case for configuring flash attention
def test_flash_attention():
    import torch
    from zeta.nn import FlashAttention

    q = torch.randn(2, 4, 6, 8)
    k = torch.randn(2, 4, 10, 8)
    v = torch.randn(2, 4, 10, 8)

    attention = FlashAttention(causal=False, dropout=0.1, flash=True)
    output = attention(q, k, v)

    assert output.shape == (2, 4, 6, 8)


# Test case for gradient checking using torch.autograd.gradcheck
def test_attend_gradient_check():
    attend = Attend()
Empty file removed tests/test___init__.py
25 changes: 0 additions & 25 deletions tests/test_init.py

This file was deleted.

1 change: 1 addition & 0 deletions zeta/cloud/__init__.py
@@ -1,3 +1,4 @@
""" init file for cloud module """
from zeta.cloud.sky_api import SkyInterface
from zeta.cloud.main import zetacloud

2 changes: 2 additions & 0 deletions zeta/cloud/main.py
@@ -1,3 +1,5 @@
"""Cloud """

import logging
from typing import Any

4 changes: 4 additions & 0 deletions zeta/cloud/sky_api.py
@@ -1,3 +1,7 @@
""" sky_api module """
""" This module provides a simplified interface for launching, executing,
stopping, starting, and tearing down clusters. """

from typing import List

import sky
1 change: 1 addition & 0 deletions zeta/nn/__init__.py
@@ -1,3 +1,4 @@
""" Neural network modules. zeta/nn """
from zeta.nn.attention import * # noqa: F403
from zeta.nn.embeddings import * # noqa: F403
from zeta.nn.modules import * # noqa: F403
2 changes: 1 addition & 1 deletion zeta/nn/attention/__init__.py
@@ -1,4 +1,4 @@
"""Zeta Halo"""
"""Zeta Attention init file"""
from zeta.nn.attention.attend import Attend, Intermediates
from zeta.nn.attention.cross_attn_images import MultiModalCrossAttention
from zeta.nn.attention.flash_attention import FlashAttention
4 changes: 3 additions & 1 deletion zeta/nn/modules/__init__.py
@@ -1,3 +1,4 @@
""" init file for nn modules """
from zeta.nn.modules.adaptive_conv import AdaptiveConv3DMod
from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm
from zeta.nn.modules.cnn_text import CNNNew
@@ -175,7 +176,7 @@
)
from zeta.nn.modules.qformer import QFormer
from zeta.nn.modules.poly_expert_fusion_network import MLPProjectionFusion

from zeta.nn.modules.norm_fractorals import NormalizationFractral

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -353,4 +354,5 @@
"reparameterize_aux_into_target_model",
"QFormer",
"MLPProjectionFusion",
"NormalizationFractral",
]
52 changes: 52 additions & 0 deletions zeta/nn/modules/norm_fractorals.py
@@ -0,0 +1,52 @@
from torch import nn


class NormalizationFractral(nn.Module):
    """
    A module that performs normalization using fractal layers.

    Args:
        dim (int): The input dimension.
        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-8.
        fi (int, optional): The number of fractal layers. Default is 4.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.

    Attributes:
        eps (float): A small value added to the denominator for numerical stability.
        fi (int): The number of fractal layers.
        norm (nn.LayerNorm): The initial normalization layer.
        norm_i (nn.LayerNorm): Fractal normalization layers.
    """

    def __init__(
        self, dim: int, eps=1e-8, fi: int = 4, *args, **kwargs  # Fractal index
    ):
        super(NormalizationFractral, self).__init__(*args, **kwargs)
        self.eps = eps
        self.fi = fi

        self.norm = nn.LayerNorm(dim)

        for i in range(fi):
            setattr(self, f"norm_{i}", nn.LayerNorm(dim))

    def forward(self, x):
        """
        Forward pass of the NormalizationFractral module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized output tensor.
        """
        x = self.norm(x)

        for i in range(self.fi):
            norm = getattr(self, f"norm_{i}")
            x = norm(x)

        return x
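A brief usage sketch for the new module, assuming the export added to `zeta/nn/modules/__init__.py` in this commit; the input shape is purely illustrative.

```python
import torch

from zeta.nn.modules.norm_fractorals import NormalizationFractral

# Normalize 512-dimensional features; fi=4 adds four LayerNorm passes after the initial one.
norm = NormalizationFractral(dim=512, fi=4)

x = torch.randn(2, 64, 512)  # (batch, sequence, dim)
out = norm(x)

assert out.shape == x.shape
```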
8 changes: 8 additions & 0 deletions zeta/nn/modules/qformer.py
@@ -1,10 +1,18 @@
""" QFormer module for processing text and image inputs. """

from einops import rearrange, reduce
from torch import Tensor, nn

from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn.modules.simple_feedforward import SimpleFeedForward
from zeta.nn.attention.cross_attention import CrossAttention


