[FEAT][Adaptive Gating] [zeta.quant -> zeta.nn.quant] [FEAT][Multi-Modal Rotary Embeddings] [DEL][Zeta Cloud] + [DEL][ZETA CLI] [General Clean up]
Kye Gomez authored and committed on Aug 13, 2024
1 parent 4f7ae5c · commit feadbfa
Showing 42 changed files with 504 additions and 396 deletions.
Binary file added .DS_Store
6 changes: 3 additions & 3 deletions docs/zeta/quant/bitlinear.md
@@ -61,7 +61,7 @@ Performs the forward pass of the `BitLinear` module.
```python
import torch

-from zeta.quant import BitLinear
+from zeta.nn.quant import BitLinear

# Initialize the BitLinear module
linear = BitLinear(10, 20)
@@ -82,7 +82,7 @@ print(output.size()) # torch.Size([128, 20])
```python
import torch

-from zeta.quant import BitLinear
+from zeta.nn.quant import BitLinear

# Initialize the BitLinear module with 2 groups
linear = BitLinear(10, 20, groups=2)
@@ -103,7 +103,7 @@ print(output.size()) # torch.Size([128, 20])
import torch
from torch import nn

-from zeta.quant import BitLinear
+from zeta.nn.quant import BitLinear


class MyModel(nn.Module):
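For reference, a minimal runnable sketch of the updated usage under the relocated path, assembled from the fragments in this file (the 128×10 input and the expected output size come from the doc's own comments):

```python
import torch

from zeta.nn.quant import BitLinear  # relocated from zeta.quant in this commit

# Initialize the BitLinear module and run a forward pass
linear = BitLinear(10, 20)
x = torch.randn(128, 10)
output = linear(x)
print(output.size())  # torch.Size([128, 20]), per the docs above
```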
2 changes: 1 addition & 1 deletion docs/zeta/quant/qlora.md
@@ -54,7 +54,7 @@ To instantiate a QloraLinear layer:
```python
import torch.nn as nn
-from zeta.quant.qlora import QloraLinear
+from zeta.nn.quant.qlora import QloraLinear
in_features = 20
out_features = 30
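Since this hunk stops before the constructor call, only the import-path change is illustrated below; the instantiation arguments (`in_features = 20`, `out_features = 30`) continue past the hunk:

```python
# Old path (pre-2.6.1):
# from zeta.quant.qlora import QloraLinear

# New path after this commit:
from zeta.nn.quant.qlora import QloraLinear
```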
2 changes: 1 addition & 1 deletion docs/zeta/quant/quik.md
@@ -124,7 +124,7 @@ In this example, we'll initialize the QUIK layer.
```python
import torch

-from zeta.quant import QUIK
+from zeta.nn.quant import QUIK

# Initialize the QUIK module
quik = QUIK(in_features=784, out_features=10)
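A hedged sketch of QUIK usage under the new path; the initialization mirrors the doc fragment above, while the forward call on random data is an assumption added for illustration and may not reflect the layer's full quantize/dequantize workflow:

```python
import torch

from zeta.nn.quant import QUIK  # relocated from zeta.quant in this commit

# Initialize the QUIK module exactly as in the docs above
quik = QUIK(in_features=784, out_features=10)

# Hypothetical forward pass on a batch of 784-dimensional inputs
x = torch.randn(32, 784)
output = quik(x)
print(output.shape)
```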
43 changes: 37 additions & 6 deletions pyproject.toml
@@ -1,20 +1,51 @@
[tool.poetry]
name = "zetascale"
version = "2.5.9"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
version = "2.6.1"
description = "Rapidly Build, Optimize, and Train SOTA AI Models"
authors = ["Zeta Team <[email protected]>"]
license = "MIT"
readme = "README.md"
homepage = "https://github.com/kyegomez/zeta"
keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
keywords = [
"artificial intelligence",
"deep learning",
"optimizers",
"Prompt Engineering",
"swarms",
"agents",
"llms",
"transformers",
"multi-agent",
"swarms of agents",
"Enterprise-Grade Agents",
"Production-Grade Agents",
"Agents",
"Multi-Grade-Agents",
"Swarms",
"Transformers",
"LLMs",
"Prompt Engineering",
"Agents",
"Generative Agents",
"Generative AI",
"Agent Marketplace",
"Agent Store",
"LSTMS",
"GRUs",
"RNNs",
"CNNs",
"MLPs",
"DNNs",
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.9"
"Programming Language :: Python :: 3.10",
]


packages = [
{ include = "zeta" },
{ include = "zeta/**/*.py" },
@@ -65,7 +96,7 @@ target-version = ['py38']
preview = true


-[tool.poetry.scripts]
-zeta = 'zeta.cli.main:main'
+# [tool.poetry.scripts]
+# zeta = 'zeta.cli.main:main'
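With the `[tool.poetry.scripts]` block commented out, installing `zetascale` 2.6.1 no longer provides the `zeta` console command; the project is consumed purely as a library. A hedged migration sketch for downstream imports (the combined import is illustrative; some modules may need to be imported from their own submodules, as in the qlora example above):

```python
# Before this commit (zetascale <= 2.5.9):
# from zeta.quant import BitLinear, QUIK

# After this commit (zetascale >= 2.6.1), quantization lives under zeta.nn:
from zeta.nn.quant import BitLinear, QUIK
```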


2 changes: 1 addition & 1 deletion tests/quant/test_bitlinear.py
@@ -1,7 +1,7 @@
import pytest
import torch

-from zeta.quant.bitlinear import BitLinear, absmax_quantize
+from zeta.nn.quant.bitlinear import BitLinear, absmax_quantize


def test_bitlinear_reset_parameters():
2 changes: 1 addition & 1 deletion tests/quant/test_half_bit_linear.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

-from zeta.quant.half_bit_linear import HalfBitLinear
+from zeta.nn.quant.half_bit_linear import HalfBitLinear


def test_half_bit_linear_init():
2 changes: 1 addition & 1 deletion tests/quant/test_lfq.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

-from zeta.quant.lfq import LFQ
+from zeta.nn.quant.lfq import LFQ


def test_lfg_init():
2 changes: 1 addition & 1 deletion tests/quant/test_niva.py
@@ -5,7 +5,7 @@
import torch.nn as nn

from zeta.nn import QFTSPEmbedding
-from zeta.quant.niva import niva
+from zeta.nn.quant.niva import niva


def test_niva_model_type():
2 changes: 1 addition & 1 deletion tests/quant/test_qlora.py
@@ -2,7 +2,7 @@
import torch
from torch.testing import assert_allclose

-from zeta.quant.qlora import QloraLinear
+from zeta.nn.quant.qlora import QloraLinear

# Sample instantiation values
in_features = 20
2 changes: 1 addition & 1 deletion tests/quant/test_quik.py
@@ -1,6 +1,6 @@
import torch

-from zeta.quant.quick import QUIK
+from zeta.nn.quant.quick import QUIK


def test_quik_initialization():
2 changes: 1 addition & 1 deletion tests/quant/test_resudual_vq.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

-from zeta.quant.residual_vq import ResidualVectorQuantizer
+from zeta.nn.quant.residual_vq import ResidualVectorQuantizer


def test_residual_vector_quantizer_init():
2 changes: 1 addition & 1 deletion tests/utils/test_absmax.py
@@ -1,6 +1,6 @@
import torch

-from zeta.quant.absmax import absmax_quantize
+from zeta.nn.quant.absmax import absmax_quantize


def test_absmax_quantize_default_bits():
107 changes: 107 additions & 0 deletions todo/dit_block.py
@@ -0,0 +1,107 @@
import torch
from torch import nn, Tensor
from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn.modules.feedforward import FeedForward
from zeta.nn.modules.scale import Scale
from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm


def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class AdaLN(nn.Module):
"""
Adaptive Layer Normalization (AdaLN) module.
Args:
dim (int): The input dimension.
eps (float): A small value added to the denominator for numerical stability.
scale (int): The scale factor for the linear layer.
bias (bool): Whether to include a bias term in the linear layer.
"""

def __init__(
self,
dim: int = None,
eps: float = 1e-5,
scale: int = 4,
bias: bool = True,
):
super().__init__()
self.eps = eps
self.scale = scale
self.bias = bias

self.norm = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, dim * scale, bias=bias),
)

def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the AdaLN module.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The normalized output tensor.
"""
return self.norm(x)


class DitBlock(nn.Module):
def __init__(
self,
dim: int,
dim_head: int = None,
dropout: float = 0.1,
heads: int = 8,
):
super().__init__()
self.dim = dim
self.dim_head = dim_head
self.dropout = dropout
self.heads = heads

# Attention
self.attn = MultiQueryAttention(
dim,
heads,
)

# FFN
self.input_ffn = FeedForward(dim, dim, 4, swish=True)

# Conditioning mlp
self.conditioning_mlp = FeedForward(dim, dim, 4, swish=True)

# Shift
# self.shift_op = ShiftTokens()

# Norm
self.norm = AdaptiveLayerNorm(dim)

def forward(self, x: Tensor, conditioning: Tensor) -> Tensor:

# Norm
self.norm(x)

# Scale
# scaled = modulate(
# x,
# normalize,
# normalize
# )

# return scaled
scaled = Scale(fn=self.norm)(x)
return scaled


input = torch.randn(1, 10, 512)
conditioning = torch.randn(1, 10, 512)
dit_block = DitBlock(512)
output = dit_block(input, conditioning)
print(output.shape)
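The `modulate` helper above is defined but not yet wired into `DitBlock.forward` (the modulation path remains commented out in this todo file). A standalone sketch of what it computes, with illustrative shapes:

```python
import torch

# Illustration of the modulate() expression from todo/dit_block.py.
# Assumed shapes: x is (batch, seq, dim); shift and scale are (batch, dim),
# broadcast across the sequence dimension via unsqueeze(1).
x = torch.randn(1, 10, 512)
shift = torch.zeros(1, 512)
scale = torch.zeros(1, 512)

out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
print(out.shape)               # torch.Size([1, 10, 512])
print(torch.allclose(out, x))  # True: zero shift and scale act as the identity
```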
