Commit feadbfa: [FEAT][Adaptive Gating] [zeta.quant -> zeta.nn.quant] [FEAT][Multi-Modal Rotary Embeddings] [DEL][Zeta Cloud] + [DEL][ZETA CLI] [General Clean up]
Kye Gomez authored and committed Aug 13, 2024
1 parent 4f7ae5c · commit feadbfa
Showing 42 changed files with 504 additions and 396 deletions.
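Note the module move in the title: quantization utilities now live under zeta.nn.quant rather than zeta.quant. Downstream imports would be updated along these lines; the function name used here, absmax_quantize, is an assumption based on the package's quantization helpers and the names are assumed to be re-exported unchanged under the new location:

# Old path (pre-commit):
# from zeta.quant import absmax_quantize

# New path after this commit (name assumed unchanged by the move):
from zeta.nn.quant import absmax_quantize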
pyproject.toml
@@ -1,20 +1,51 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.5.9"
-description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
+version = "2.6.1"
+description = "Rapidly Build, Optimize, and Train SOTA AI Models"
 authors = ["Zeta Team <[email protected]>"]
 license = "MIT"
 readme = "README.md"
 homepage = "https://github.com/kyegomez/zeta"
-keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"]
+keywords = [
+    "artificial intelligence",
+    "deep learning",
+    "optimizers",
+    "Prompt Engineering",
+    "swarms",
+    "agents",
+    "llms",
+    "transformers",
+    "multi-agent",
+    "swarms of agents",
+    "Enterprise-Grade Agents",
+    "Production-Grade Agents",
+    "Agents",
+    "Multi-Grade-Agents",
+    "Swarms",
+    "Transformers",
+    "LLMs",
+    "Prompt Engineering",
+    "Agents",
+    "Generative Agents",
+    "Generative AI",
+    "Agent Marketplace",
+    "Agent Store",
+    "LSTMS",
+    "GRUs",
+    "RNNs",
+    "CNNs",
+    "MLPs",
+    "DNNs",
+]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Developers",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "License :: OSI Approved :: MIT License",
-    "Programming Language :: Python :: 3.9"
+    "Programming Language :: Python :: 3.10",
 ]

 packages = [
     { include = "zeta" },
     { include = "zeta/**/*.py" },

@@ -65,7 +96,7 @@ target-version = ['py38']
 preview = true

-[tool.poetry.scripts]
-zeta = 'zeta.cli.main:main'
+# [tool.poetry.scripts]
+# zeta = 'zeta.cli.main:main'
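With the [tool.poetry.scripts] entry commented out, this release no longer installs the zeta console command; the package is consumed as a library. A quick standard-library check of which release is installed, assuming the distribution keeps its PyPI name "zetascale":

# Sanity check after upgrading; uses only the standard library.
from importlib.metadata import version

print(version("zetascale"))  # expected: 2.6.1 as of this commit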
New file (107 additions)
@@ -0,0 +1,107 @@
import torch
from torch import nn, Tensor

from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn.modules.feedforward import FeedForward
from zeta.nn.modules.scale import Scale
from zeta.nn.modules.adaptive_layernorm import AdaptiveLayerNorm


def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Apply a per-sample shift and scale to a sequence tensor.

    x is (batch, seq_len, dim); shift and scale are (batch, dim) and are
    broadcast across the sequence dimension.
    """
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class AdaLN(nn.Module):
    """
    Adaptive layer-norm conditioning (AdaLN) projection.

    Projects its input through SiLU and a linear layer, expanding the
    feature dimension by `scale` to produce modulation parameters.

    Args:
        dim (int): The input dimension.
        eps (float): A small value for numerical stability (currently unused).
        scale (int): The expansion factor for the linear layer.
        bias (bool): Whether to include a bias term in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        eps: float = 1e-5,
        scale: int = 4,
        bias: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.scale = scale
        self.bias = bias

        self.norm = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * scale, bias=bias),
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the AdaLN module.

        Args:
            x (Tensor): The input tensor of shape (..., dim).

        Returns:
            Tensor: The projected tensor of shape (..., dim * scale).
        """
        return self.norm(x)


class DitBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_head: int = None,
        dropout: float = 0.1,
        heads: int = 8,
    ):
        super().__init__()
        self.dim = dim
        self.dim_head = dim_head
        self.dropout = dropout
        self.heads = heads

        # Attention
        self.attn = MultiQueryAttention(
            dim,
            heads,
        )

        # FFN
        self.input_ffn = FeedForward(dim, dim, 4, swish=True)

        # Conditioning MLP
        self.conditioning_mlp = FeedForward(dim, dim, 4, swish=True)

        # Shift (not yet enabled)
        # self.shift_op = ShiftTokens()

        # Norm
        self.norm = AdaptiveLayerNorm(dim)

    def forward(self, x: Tensor, conditioning: Tensor) -> Tensor:
        # NOTE: the attention, feedforward, and conditioning branches are
        # constructed above but not wired in yet; `conditioning` is
        # currently unused. The intended modulation step is:
        #   scaled = modulate(x, shift, scale)

        # Normalize, wrapped in Scale
        scaled = Scale(fn=self.norm)(x)
        return scaled


input = torch.randn(1, 10, 512)
conditioning = torch.randn(1, 10, 512)
dit_block = DitBlock(512)
output = dit_block(input, conditioning)
print(output.shape)
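The committed forward pass leaves the attention, feedforward, and conditioning paths stubbed out. For reference, a minimal self-contained sketch of how `modulate` is typically wired into a DiT-style block, using plain torch modules rather than zeta's internals; everything beyond `modulate` here is an illustrative assumption, not the committed design:

import torch
from torch import nn, Tensor


def modulate(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    # Broadcast per-sample (batch, dim) shift/scale over the sequence axis.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class DiTBlockSketch(nn.Module):
    """Illustrative DiT-style block: adaLN modulation around attention + MLP."""

    def __init__(self, dim: int, heads: int = 8, mlp_mult: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_mult),
            nn.SiLU(),
            nn.Linear(dim * mlp_mult, dim),
        )
        # Conditioning MLP emits six modulation tensors:
        # (shift, scale, gate) for attention and for the MLP.
        self.ada_ln = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
        # cond: (batch, dim) conditioning vector (e.g., timestep embedding).
        shift_a, scale_a, gate_a, shift_m, scale_m, gate_m = self.ada_ln(
            cond
        ).chunk(6, dim=-1)

        # Attention branch, modulated and gated.
        h = modulate(self.norm1(x), shift_a, scale_a)
        attn_out, _ = self.attn(h, h, h)
        x = x + gate_a.unsqueeze(1) * attn_out

        # MLP branch, modulated and gated.
        h = modulate(self.norm2(x), shift_m, scale_m)
        x = x + gate_m.unsqueeze(1) * self.mlp(h)
        return x


x = torch.randn(1, 10, 512)   # (batch, seq, dim)
cond = torch.randn(1, 512)    # per-sample conditioning vector
block = DiTBlockSketch(512)
print(block(x, cond).shape)   # torch.Size([1, 10, 512])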