Merge pull request #4 from kyegomez/master
Catching up 20240103 0918
evelynmitchell authored Jan 3, 2024
2 parents 6cda234 + c710f42 commit f54eea2
Showing 7 changed files with 623 additions and 3 deletions.
43 changes: 43 additions & 0 deletions tests/nn/attentions/test_agent_self_attn.py
@@ -0,0 +1,43 @@
import torch
from torch import nn
from zeta.nn.attention.agent_attn import AgentSelfAttention


def test_agent_self_attention_init():
agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
assert isinstance(agent_self_attn, AgentSelfAttention)
assert agent_self_attn.scale == 64**-0.5
assert isinstance(agent_self_attn.to_qkv, nn.Sequential)
assert isinstance(agent_self_attn.to_gates, nn.Sequential)
assert isinstance(agent_self_attn.agent_tokens, nn.Parameter)
assert isinstance(agent_self_attn.qa_talking_heads, nn.Conv2d)
assert isinstance(agent_self_attn.ak_talking_heads, nn.Conv2d)
assert isinstance(agent_self_attn.qa_dropout, nn.Dropout)
assert isinstance(agent_self_attn.ak_dropout, nn.Dropout)
assert isinstance(agent_self_attn.to_out, nn.Sequential)


def test_agent_self_attention_forward():
agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    x = torch.randn(2, 32, 64)  # (batch, seq_len, dim); to_qkv expects a 3D input
output = agent_self_attn(x)
assert output.shape == x.shape


def test_agent_self_attention_forward_with_mask():
agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    x = torch.randn(2, 32, 64)  # (batch, seq_len, dim)
    mask = torch.ones(2, 32).bool()  # (batch, seq_len)
output = agent_self_attn(x, mask=mask)
assert output.shape == x.shape


def test_agent_self_attention_forward_with_agent_tokens():
agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    x = torch.randn(2, 32, 64)  # (batch, seq_len, dim)
agent_tokens = torch.randn(2, 8, 16, 64)
output, agent_gathered_tokens = agent_self_attn(
x, agent_tokens=agent_tokens, return_agent_tokens=True
)
assert output.shape == x.shape
assert agent_gathered_tokens.shape == agent_tokens.shape
67 changes: 67 additions & 0 deletions tests/quant/test_lfq.py
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
from zeta.quant.lfq import LFQ


def test_lfq_init():
    lfq = LFQ(dim=64, codebook_size=16)
    assert isinstance(lfq, LFQ)
    assert lfq.dim == 64
    assert lfq.codebook_dim == 4
    assert lfq.num_codebooks == 1
    assert lfq.keep_num_codebooks_dim is False
    assert isinstance(lfq.project_in, nn.Linear)
    assert isinstance(lfq.project_out, nn.Linear)
    assert lfq.has_projections is False
    assert isinstance(lfq.activation, nn.Identity)
    assert lfq.diversity_gamma == 1.0
    assert lfq.entropy_loss_weight == 0.1
    assert lfq.codebook_scale == 1.0
    assert lfq.commitment_loss_weight == 0.25
    assert torch.all(lfq.mask == 2 ** torch.arange(3, -1, -1))
    assert lfq.zero == 0.0
    assert torch.all(
        lfq.codebook
        == lfq.bits_to_codes(
            ((torch.arange(16)[..., None].int() & lfq.mask) != 0).float()
        )
    )


def test_lfq_init_custom_params():
    lfq = LFQ(
        dim=128,
        codebook_size=32,
        entropy_loss_weight=0.2,
        commitment_loss_weight=0.3,
        diversity_gamma=2.0,
        straight_through_activation=nn.ReLU(),
        num_codebooks=2,
        keep_num_codebooks_dim=True,
        codebook_scale=2.0,
    )
    assert lfq.dim == 128
    assert lfq.codebook_dim == 5
    assert lfq.num_codebooks == 2
    assert lfq.keep_num_codebooks_dim is True
    assert isinstance(lfq.activation, nn.ReLU)
    assert lfq.diversity_gamma == 2.0
    assert lfq.entropy_loss_weight == 0.2
    assert lfq.codebook_scale == 2.0
    assert lfq.commitment_loss_weight == 0.3
    assert torch.all(lfq.mask == 2 ** torch.arange(4, -1, -1))
    assert torch.all(
        lfq.codebook
        == lfq.bits_to_codes(
            ((torch.arange(32)[..., None].int() & lfq.mask) != 0).float()
        )
    )


def test_lfq_forward():
lfq = LFQ(dim=64, codebook_size=16)
x = torch.randn(2, 64)
output, loss, _, _ = lfq(x)
assert output.shape == x.shape
assert isinstance(loss, torch.Tensor)
assert loss.dim() == 0
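
The mask and codebook assertions in the init tests above encode how LFQ builds its code vectors from the binary expansion of each codebook index. Below is a minimal sketch of that construction in plain torch, assuming bits_to_codes maps bits {0, 1} to {-codebook_scale, +codebook_scale} (the convention in reference LFQ implementations; here codebook_scale == 1.0):

import math

import torch

codebook_size = 16
codebook_dim = int(math.log2(codebook_size))        # 4 bits per code
mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)  # tensor([8, 4, 2, 1])

# big-endian binary expansion of every index 0..15, shape (16, 4)
bits = ((torch.arange(codebook_size)[..., None].int() & mask) != 0).float()

codes = bits * 2.0 - 1.0                            # {0, 1} -> {-1, +1}
assert bits[5].tolist() == [0.0, 1.0, 0.0, 1.0]     # 5 == 0b0101
assert codes.shape == (codebook_size, codebook_dim)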
4 changes: 2 additions & 2 deletions zeta/nn/attention/__init__.py
@@ -1,6 +1,4 @@
"""Zeta Halo"""


from zeta.nn.attention.attend import Attend, Intermediates
from zeta.nn.attention.cross_attn_images import MultiModalCrossAttention
from zeta.nn.attention.flash_attention import FlashAttention
@@ -19,6 +17,7 @@
from zeta.nn.attention.sparse_attention import SparseAttention
from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
from zeta.nn.attention.linear_attention import LinearAttention
from zeta.nn.attention.agent_attn import AgentSelfAttention

# from zeta.nn.attention.flash_attention2 import FlashAttentionTwo
# from zeta.nn.attention.mgqa import MGQA
@@ -40,4 +39,5 @@
"SparseAttention",
"SpatialLinearAttention",
"LinearAttention",
"AgentSelfAttention",
]
148 changes: 148 additions & 0 deletions zeta/nn/attention/agent_attn.py
@@ -0,0 +1,148 @@
import torch
from torch.nn import Module
from torch import nn, einsum

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# functions


def exists(v):
return v is not None


# main class


class AgentSelfAttention(Module):
"""
Self-attention module for agent tokens in a neural network.
Args:
dim (int): The input dimension.
num_agent_tokens (int): The number of agent tokens.
dim_head (int, optional): The dimension of each attention head. Defaults to 64.
heads (int, optional): The number of attention heads. Defaults to 8.
dropout (float, optional): The dropout rate. Defaults to 0.0.
talking_heads (bool, optional): Whether to use talking heads mechanism. Defaults to True.
gate (bool, optional): Whether to apply gating mechanism. Defaults to True.
combine_agent_tokens (bool, optional): Whether to combine agent tokens. Defaults to False.
Examples::
>>> import torch
>>> from zeta.nn.attention import AgentSelfAttention
>>> agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
        >>> x = torch.randn(2, 32, 64)  # (batch, seq_len, dim)
        >>> output = agent_self_attn(x)
        >>> output.shape
        torch.Size([2, 32, 64])
"""

def __init__(
self,
dim,
*,
num_agent_tokens,
dim_head=64,
heads=8,
dropout=0.0,
talking_heads=True,
gate=True,
combine_agent_tokens=False,
):
super().__init__()
self.scale = dim_head**-0.5
dim_inner = dim_head * heads

self.to_qkv = nn.Sequential(
nn.Linear(dim, dim_inner * 3, bias=False),
Rearrange("b n (qkv h d) -> qkv b h n d", h=heads, qkv=3),
)

self.to_gates = (
nn.Sequential(
nn.Linear(dim, heads),
Rearrange("b n h -> b h n 1"),
nn.Sigmoid(),
)
if gate
else None
)

self.agent_tokens = nn.Parameter(
torch.zeros(heads, num_agent_tokens, dim_head)
)
nn.init.normal_(self.agent_tokens, std=0.02)

self.qa_talking_heads = (
nn.Conv2d(heads, heads, 1, bias=False)
if talking_heads
else nn.Identity()
)
self.ak_talking_heads = (
nn.Conv2d(heads, heads, 1, bias=False)
if talking_heads
else nn.Identity()
)

self.qa_dropout = nn.Dropout(dropout)
self.ak_dropout = nn.Dropout(dropout)

self.to_out = nn.Sequential(
Rearrange("b h n d -> b n (h d)"),
nn.Linear(dim_inner, dim, bias=False),
)

def forward(
self, x, mask=None, agent_tokens=None, return_agent_tokens=False
):
batch = x.shape[0]

q, k, v = self.to_qkv(x)

if exists(agent_tokens):
a = agent_tokens
else:
a = repeat(self.agent_tokens, "h m d -> b h m d", b=batch)

a = a * self.scale

        # two-stage agent attention: queries attend to the (few) agent tokens,
        # and the agent tokens attend to the full key/value sequence
        qa_sim = einsum("b h i d, b h j d -> b h i j", q, a)
        ak_sim = einsum("b h i d, b h j d -> b h i j", a, k)

if exists(mask):
max_neg_value = -torch.finfo(qa_sim.dtype).max
ak_sim = ak_sim.masked_fill(
~rearrange(mask, "b j -> b 1 1 j"), max_neg_value
)

qa_attn = qa_sim.softmax(dim=-1)
ak_attn = ak_sim.softmax(dim=-1)

qa_attn = self.qa_dropout(qa_attn)
ak_attn = self.ak_dropout(ak_attn)

qa_attn = self.qa_talking_heads(qa_attn)
ak_attn = self.ak_talking_heads(ak_attn)

agent_gathered_tokens = einsum(
"b h i j, b h j d -> b h i d", ak_attn, v
)

out = einsum(
"b h i j, b h j d -> b h i d", qa_attn, agent_gathered_tokens
)

if exists(mask):
out = out.masked_fill(~rearrange(mask, "b n -> b 1 n 1"), 0.0)

if exists(self.to_gates):
out = out * self.to_gates(x)

out = self.to_out(out)

if not return_agent_tokens:
return out

return out, agent_gathered_tokens
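
For quick reference, a usage sketch of the module above, based on the shapes its forward pass implies: x is (batch, seq_len, dim), an optional boolean mask is (batch, seq_len), and externally supplied agent_tokens are (batch, heads, num_agent_tokens, dim_head). This only restates the code shown above, not additional API.

import torch
from zeta.nn.attention.agent_attn import AgentSelfAttention

attn = AgentSelfAttention(dim=64, num_agent_tokens=16)  # defaults: heads=8, dim_head=64

x = torch.randn(2, 32, 64)                 # (batch, seq_len, dim)
mask = torch.ones(2, 32).bool()            # (batch, seq_len)

out = attn(x, mask=mask)                   # (2, 32, 64)

# supply external agent tokens and also return the pooled agent representations
agent_tokens = torch.randn(2, 8, 16, 64)   # (batch, heads, num_agent_tokens, dim_head)
out, agents = attn(x, agent_tokens=agent_tokens, return_agent_tokens=True)
print(out.shape, agents.shape)             # torch.Size([2, 32, 64]) torch.Size([2, 8, 16, 64])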
3 changes: 2 additions & 1 deletion zeta/quant/__init__.py
@@ -5,7 +5,7 @@
from zeta.quant.niva import niva
from zeta.quant.absmax import absmax_quantize
from zeta.quant.half_bit_linear import HalfBitLinear

from zeta.quant.lfq import LFQ

__all__ = [
"QUIK",
@@ -15,4 +15,5 @@
"QloraLinear",
"niva",
"HalfBitLinear",
"LFQ",
]
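
Together, the two __init__.py updates expose both new modules at the package level. A minimal import sketch follows; the four-value unpacking of the LFQ call simply mirrors tests/quant/test_lfq.py above and is otherwise an assumption about its return signature:

import torch

from zeta.nn.attention import AgentSelfAttention
from zeta.quant import LFQ

attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
lfq = LFQ(dim=64, codebook_size=16)

x = torch.randn(2, 64)
quantized, loss, _, _ = lfq(x)  # scalar auxiliary loss, as asserted in the test above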