Merge pull request #4 from kyegomez/master
Catching up 20240103 0918
Showing 7 changed files with 623 additions and 3 deletions.
@@ -0,0 +1,43 @@
import torch
from torch import nn

from zeta.nn.attention.agent_attn import AgentSelfAttention


def test_agent_self_attention_init():
    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    assert isinstance(agent_self_attn, AgentSelfAttention)
    assert agent_self_attn.scale == 64**-0.5
    assert isinstance(agent_self_attn.to_qkv, nn.Sequential)
    assert isinstance(agent_self_attn.to_gates, nn.Sequential)
    assert isinstance(agent_self_attn.agent_tokens, nn.Parameter)
    assert isinstance(agent_self_attn.qa_talking_heads, nn.Conv2d)
    assert isinstance(agent_self_attn.ak_talking_heads, nn.Conv2d)
    assert isinstance(agent_self_attn.qa_dropout, nn.Dropout)
    assert isinstance(agent_self_attn.ak_dropout, nn.Dropout)
    assert isinstance(agent_self_attn.to_out, nn.Sequential)


def test_agent_self_attention_forward():
    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    # Input must be (batch, seq_len, dim): to_qkv rearranges over a sequence axis.
    x = torch.randn(2, 32, 64)
    output = agent_self_attn(x)
    assert output.shape == x.shape


def test_agent_self_attention_forward_with_mask():
    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    x = torch.randn(2, 32, 64)
    # The mask is boolean over the sequence axis, shape (batch, seq_len).
    mask = torch.ones(2, 32).bool()
    output = agent_self_attn(x, mask=mask)
    assert output.shape == x.shape


def test_agent_self_attention_forward_with_agent_tokens():
    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
    x = torch.randn(2, 32, 64)
    # Externally supplied agent tokens: (batch, heads, num_agent_tokens, dim_head).
    agent_tokens = torch.randn(2, 8, 16, 64)
    output, agent_gathered_tokens = agent_self_attn(
        x, agent_tokens=agent_tokens, return_agent_tokens=True
    )
    assert output.shape == x.shape
    assert agent_gathered_tokens.shape == agent_tokens.shape
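A note on the shapes these tests assume: the module's to_qkv ends in Rearrange("b n (qkv h d) -> qkv b h n d"), so inputs must be (batch, seq_len, dim) tensors, and the returned agent tokens are (batch, heads, num_agent_tokens, dim_head). A minimal smoke run illustrating both (the sizes are arbitrary, and this assumes the zeta package is importable):

import torch
from zeta.nn.attention.agent_attn import AgentSelfAttention

# batch=2, seq_len=32, dim=64; heads=8 and dim_head=64 are the defaults
attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
x = torch.randn(2, 32, 64)
mask = torch.ones(2, 32).bool()

out, agents = attn(x, mask=mask, return_agent_tokens=True)
print(out.shape)     # torch.Size([2, 32, 64]) -- same shape as the input
print(agents.shape)  # torch.Size([2, 8, 16, 64]) -- (batch, heads, num_agent_tokens, dim_head)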
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn

from zeta.quant.lfq import LFQ


def test_lfq_init():
    lfq = LFQ(dim=64, codebook_size=16)
    assert isinstance(lfq, LFQ)
    assert lfq.dim == 64
    assert lfq.codebook_dim == 4
    assert lfq.num_codebooks == 1
    assert lfq.keep_num_codebooks_dim is False
    assert isinstance(lfq.project_in, nn.Linear)
    assert isinstance(lfq.project_out, nn.Linear)
    assert lfq.has_projections is False
    assert isinstance(lfq.activation, nn.Identity)
    assert lfq.diversity_gamma == 1.0
    assert lfq.entropy_loss_weight == 0.1
    assert lfq.codebook_scale == 1.0
    assert lfq.commitment_loss_weight == 0.25
    assert torch.all(lfq.mask == 2 ** torch.arange(3, -1, -1))
    assert lfq.zero == 0.0
    assert torch.all(
        lfq.codebook
        == lfq.bits_to_codes(
            ((torch.arange(16)[..., None].int() & lfq.mask) != 0).float()
        )
    )


def test_lfq_init_custom_params():
    lfq = LFQ(
        dim=128,
        codebook_size=32,
        entropy_loss_weight=0.2,
        commitment_loss_weight=0.3,
        diversity_gamma=2.0,
        straight_through_activation=nn.ReLU(),
        num_codebooks=2,
        keep_num_codebooks_dim=True,
        codebook_scale=2.0,
    )
    assert lfq.dim == 128
    assert lfq.codebook_dim == 5
    assert lfq.num_codebooks == 2
    assert lfq.keep_num_codebooks_dim is True
    assert isinstance(lfq.activation, nn.ReLU)
    assert lfq.diversity_gamma == 2.0
    assert lfq.entropy_loss_weight == 0.2
    assert lfq.codebook_scale == 2.0
    assert lfq.commitment_loss_weight == 0.3
    assert torch.all(lfq.mask == 2 ** torch.arange(4, -1, -1))
    assert torch.all(
        lfq.codebook
        == lfq.bits_to_codes(
            ((torch.arange(32)[..., None].int() & lfq.mask) != 0).float()
        )
    )


def test_lfq_forward():
    lfq = LFQ(dim=64, codebook_size=16)
    x = torch.randn(2, 64)
    output, loss, _, _ = lfq(x)
    assert output.shape == x.shape
    assert isinstance(loss, torch.Tensor)
    assert loss.dim() == 0
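The codebook assertions above follow directly from how lookup-free quantization enumerates its codes: codebook_dim is log2(codebook_size), the mask holds the powers of two used to read each index's bits, and bits_to_codes maps each bit to plus or minus codebook_scale. A self-contained sketch of that construction (lfq_codebook is a hypothetical stand-in written for illustration, not part of the zeta API):

import math

import torch


def lfq_codebook(codebook_size: int, codebook_scale: float = 1.0) -> torch.Tensor:
    # number of bits per code: 16 -> 4, 32 -> 5, matching the codebook_dim asserts
    codebook_dim = int(math.log2(codebook_size))
    # mask = [2^(d-1), ..., 2, 1], matching 2 ** torch.arange(codebook_dim - 1, -1, -1)
    mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)
    # each index's bit pattern, as floats in {0., 1.}
    bits = ((torch.arange(codebook_size)[..., None].int() & mask) != 0).float()
    # bits_to_codes: map bit 0 -> -scale and bit 1 -> +scale
    return bits * codebook_scale * 2 - codebook_scale


codes = lfq_codebook(16)
print(codes.shape)  # torch.Size([16, 4]) -- one signed binary code per index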
@@ -0,0 +1,148 @@
import torch
from torch import nn, einsum
from torch.nn import Module

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# functions


def exists(v):
    return v is not None


# main class


class AgentSelfAttention(Module):
    """
    Self-attention module for agent tokens in a neural network.

    Args:
        dim (int): The input dimension.
        num_agent_tokens (int): The number of agent tokens.
        dim_head (int, optional): The dimension of each attention head. Defaults to 64.
        heads (int, optional): The number of attention heads. Defaults to 8.
        dropout (float, optional): The dropout rate. Defaults to 0.0.
        talking_heads (bool, optional): Whether to use the talking-heads mechanism. Defaults to True.
        gate (bool, optional): Whether to apply a gating mechanism to the output. Defaults to True.
        combine_agent_tokens (bool, optional): Whether to combine agent tokens. Defaults to False.

    Examples::
        >>> import torch
        >>> from zeta.nn.attention import AgentSelfAttention
        >>> agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
        >>> x = torch.randn(2, 32, 64)
        >>> output = agent_self_attn(x)
        >>> output.shape
        torch.Size([2, 32, 64])
    """

    def __init__(
        self,
        dim,
        *,
        num_agent_tokens,
        dim_head=64,
        heads=8,
        dropout=0.0,
        talking_heads=True,
        gate=True,
        combine_agent_tokens=False,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        dim_inner = dim_head * heads

        # Project the input to queries, keys, and values in one linear pass,
        # then split heads: (b, n, qkv * h * d) -> (qkv, b, h, n, d).
        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h n d", h=heads, qkv=3),
        )

        # Optional per-head sigmoid gate applied to the attention output.
        self.to_gates = (
            nn.Sequential(
                nn.Linear(dim, heads),
                Rearrange("b n h -> b h n 1"),
                nn.Sigmoid(),
            )
            if gate
            else None
        )

        # Learned agent tokens, shared across the batch: (heads, m, dim_head).
        self.agent_tokens = nn.Parameter(
            torch.zeros(heads, num_agent_tokens, dim_head)
        )
        nn.init.normal_(self.agent_tokens, std=0.02)

        # Talking heads: 1x1 convolutions that mix attention maps across heads.
        self.qa_talking_heads = (
            nn.Conv2d(heads, heads, 1, bias=False)
            if talking_heads
            else nn.Identity()
        )
        self.ak_talking_heads = (
            nn.Conv2d(heads, heads, 1, bias=False)
            if talking_heads
            else nn.Identity()
        )

        self.qa_dropout = nn.Dropout(dropout)
        self.ak_dropout = nn.Dropout(dropout)

        # Merge heads and project back to the model dimension.
        self.to_out = nn.Sequential(
            Rearrange("b h n d -> b n (h d)"),
            nn.Linear(dim_inner, dim, bias=False),
        )

    def forward(
        self, x, mask=None, agent_tokens=None, return_agent_tokens=False
    ):
        batch = x.shape[0]

        q, k, v = self.to_qkv(x)

        # Use externally supplied agent tokens if given, otherwise the learned ones.
        if exists(agent_tokens):
            a = agent_tokens
        else:
            a = repeat(self.agent_tokens, "h m d -> b h m d", b=batch)

        a = a * self.scale

        # Two-stage attention: queries attend to agents, agents attend to keys.
        qa_sim = einsum("b h i d, b h j d -> b h i j", q, a)
        ak_sim = einsum("b h i d, b h j d -> b h i j", a, k)

        if exists(mask):
            max_neg_value = -torch.finfo(qa_sim.dtype).max
            ak_sim = ak_sim.masked_fill(
                ~rearrange(mask, "b j -> b 1 1 j"), max_neg_value
            )

        qa_attn = qa_sim.softmax(dim=-1)
        ak_attn = ak_sim.softmax(dim=-1)

        qa_attn = self.qa_dropout(qa_attn)
        ak_attn = self.ak_dropout(ak_attn)

        qa_attn = self.qa_talking_heads(qa_attn)
        ak_attn = self.ak_talking_heads(ak_attn)

        # Agents gather from the values, then queries read from the agents.
        agent_gathered_tokens = einsum(
            "b h i j, b h j d -> b h i d", ak_attn, v
        )

        out = einsum(
            "b h i j, b h j d -> b h i d", qa_attn, agent_gathered_tokens
        )

        if exists(mask):
            out = out.masked_fill(~rearrange(mask, "b n -> b 1 n 1"), 0.0)

        if exists(self.to_gates):
            out = out * self.to_gates(x)

        out = self.to_out(out)

        if not return_agent_tokens:
            return out

        return out, agent_gathered_tokens
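The payoff of this design is that no (n x n) attention matrix is ever formed: queries attend to the m agent tokens and the agents attend back over the sequence, so the two attention maps are (n x m) and (m x n) with m fixed by num_agent_tokens, and cost grows linearly in sequence length. A shape walkthrough of the two einsum stages (the sizes here are assumed purely for illustration):

import torch
from torch import einsum

b, h, n, m, d = 2, 8, 1024, 16, 64  # m << n is the whole point

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)
a = torch.randn(b, h, m, d)  # agent tokens

qa_attn = einsum("b h i d, b h j d -> b h i j", q, a).softmax(dim=-1)  # (b, h, n, m)
ak_attn = einsum("b h i d, b h j d -> b h i j", a, k).softmax(dim=-1)  # (b, h, m, n)

gathered = einsum("b h i j, b h j d -> b h i d", ak_attn, v)    # (b, h, m, d)
out = einsum("b h i j, b h j d -> b h i d", qa_attn, gathered)  # (b, h, n, d)

assert out.shape == (b, h, n, d)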