diff --git a/tests/nn/attentions/test_agent_self_attn.py b/tests/nn/attentions/test_agent_self_attn.py
new file mode 100644
index 00000000..c121692d
--- /dev/null
+++ b/tests/nn/attentions/test_agent_self_attn.py
@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+from zeta.nn.attention.agent_attn import AgentSelfAttention
+
+
+def test_agent_self_attention_init():
+    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
+    assert isinstance(agent_self_attn, AgentSelfAttention)
+    assert agent_self_attn.scale == 64**-0.5
+    assert isinstance(agent_self_attn.to_qkv, nn.Sequential)
+    assert isinstance(agent_self_attn.to_gates, nn.Sequential)
+    assert isinstance(agent_self_attn.agent_tokens, nn.Parameter)
+    assert isinstance(agent_self_attn.qa_talking_heads, nn.Conv2d)
+    assert isinstance(agent_self_attn.ak_talking_heads, nn.Conv2d)
+    assert isinstance(agent_self_attn.qa_dropout, nn.Dropout)
+    assert isinstance(agent_self_attn.ak_dropout, nn.Dropout)
+    assert isinstance(agent_self_attn.to_out, nn.Sequential)
+
+
+def test_agent_self_attention_forward():
+    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
+    x = torch.randn(2, 32, 64)
+    output = agent_self_attn(x)
+    assert output.shape == x.shape
+
+
+def test_agent_self_attention_forward_with_mask():
+    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
+    x = torch.randn(2, 32, 64)
+    mask = torch.ones(2, 32).bool()
+    output = agent_self_attn(x, mask=mask)
+    assert output.shape == x.shape
+
+
+def test_agent_self_attention_forward_with_agent_tokens():
+    agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
+    x = torch.randn(2, 32, 64)
+    agent_tokens = torch.randn(2, 8, 16, 64)
+    output, agent_gathered_tokens = agent_self_attn(
+        x, agent_tokens=agent_tokens, return_agent_tokens=True
+    )
+    assert output.shape == x.shape
+    assert agent_gathered_tokens.shape == agent_tokens.shape
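A minimal additional check worth sketching here (the 3-D input shape and the gate=False / talking_heads=False settings are illustrative; they simply exercise the nn.Identity branches of the constructor):

import torch
from zeta.nn.attention.agent_attn import AgentSelfAttention


def test_agent_self_attention_forward_no_gate_no_talking_heads():
    # both optional mechanisms disabled; only the output shape is checked
    agent_self_attn = AgentSelfAttention(
        dim=64, num_agent_tokens=16, gate=False, talking_heads=False
    )
    x = torch.randn(2, 32, 64)
    output = agent_self_attn(x)
    assert output.shape == x.shape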
diff --git a/tests/quant/test_lfq.py b/tests/quant/test_lfq.py
new file mode 100644
index 00000000..6da5ee2b
--- /dev/null
+++ b/tests/quant/test_lfq.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+from zeta.quant.lfq import LFQ
+
+
+def test_lfq_init():
+    lfq = LFQ(dim=64, codebook_size=16)
+    assert isinstance(lfq, LFQ)
+    assert lfq.dim == 64
+    assert lfq.codebook_dim == 4
+    assert lfq.num_codebooks == 1
+    assert lfq.keep_num_codebooks_dim is False
+    assert isinstance(lfq.project_in, nn.Linear)
+    assert isinstance(lfq.project_out, nn.Linear)
+    assert lfq.has_projections is True
+    assert isinstance(lfq.activation, nn.Identity)
+    assert lfq.diversity_gamma == 1.0
+    assert lfq.entropy_loss_weight == 0.1
+    assert lfq.codebook_scale == 1.0
+    assert lfq.commitment_loss_weight == 0.25
+    assert torch.all(lfq.mask == 2 ** torch.arange(3, -1, -1))
+    assert lfq.zero == 0.0
+    assert torch.all(
+        lfq.codebook
+        == lfq.bits_to_codes(
+            ((torch.arange(16)[..., None].int() & lfq.mask) != 0).float()
+        )
+    )
+
+
+def test_lfq_init_custom_params():
+    lfq = LFQ(
+        dim=128,
+        codebook_size=32,
+        entropy_loss_weight=0.2,
+        commitment_loss_weight=0.3,
+        diversity_gamma=2.0,
+        straight_through_activation=nn.ReLU(),
+        num_codebooks=2,
+        keep_num_codebooks_dim=True,
+        codebook_scale=2.0,
+    )
+    assert lfq.dim == 128
+    assert lfq.codebook_dim == 5
+    assert lfq.num_codebooks == 2
+    assert lfq.keep_num_codebooks_dim is True
+    assert isinstance(lfq.activation, nn.ReLU)
+    assert lfq.diversity_gamma == 2.0
+    assert lfq.entropy_loss_weight == 0.2
+    assert lfq.codebook_scale == 2.0
+    assert lfq.commitment_loss_weight == 0.3
+    assert torch.all(lfq.mask == 2 ** torch.arange(4, -1, -1))
+    assert torch.all(
+        lfq.codebook
+        == lfq.bits_to_codes(
+            ((torch.arange(32)[..., None].int() & lfq.mask) != 0).float()
+        )
+    )
+
+
+def test_lfq_forward():
+    lfq = LFQ(dim=64, codebook_size=16)
+    x = torch.randn(2, 16, 64)
+    quantized, indices, loss = lfq(x)
+    assert quantized.shape == x.shape
+    assert isinstance(loss, torch.Tensor)
+    assert loss.dim() == 0
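A further sketch that could accompany these tests, based on the round-trip property stated in the LFQ docstring below (eval mode is assumed so the straight-through path is bypassed and the comparison is exact; shapes are illustrative):

import torch
from zeta.quant.lfq import LFQ


def test_lfq_indices_round_trip():
    lfq = LFQ(dim=64, codebook_size=16)
    lfq.eval()  # disable the straight-through estimator and the aux losses
    x = torch.randn(2, 16, 64)
    quantized, indices, _ = lfq(x)
    assert indices.shape == (2, 16)
    assert torch.all(quantized == lfq.indices_to_codes(indices))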
diff --git a/zeta/nn/attention/__init__.py b/zeta/nn/attention/__init__.py
index b22b4e3e..44e7c8f5 100644
--- a/zeta/nn/attention/__init__.py
+++ b/zeta/nn/attention/__init__.py
@@ -1,6 +1,4 @@
 """Zeta Halo"""
-
-
 from zeta.nn.attention.attend import Attend, Intermediates
 from zeta.nn.attention.cross_attn_images import MultiModalCrossAttention
 from zeta.nn.attention.flash_attention import FlashAttention
@@ -19,6 +17,7 @@
 from zeta.nn.attention.sparse_attention import SparseAttention
 from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
 from zeta.nn.attention.linear_attention import LinearAttention
+from zeta.nn.attention.agent_attn import AgentSelfAttention
 
 # from zeta.nn.attention.flash_attention2 import FlashAttentionTwo
 # from zeta.nn.attention.mgqa import MGQA
@@ -40,4 +39,5 @@
     "SparseAttention",
     "SpatialLinearAttention",
     "LinearAttention",
+    "AgentSelfAttention",
 ]
diff --git a/zeta/nn/attention/agent_attn.py b/zeta/nn/attention/agent_attn.py
new file mode 100644
index 00000000..53faf38f
--- /dev/null
+++ b/zeta/nn/attention/agent_attn.py
@@ -0,0 +1,148 @@
+import torch
+from torch.nn import Module
+from torch import nn, einsum
+
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+
+# functions
+
+
+def exists(v):
+    return v is not None
+
+
+# main class
+
+
+class AgentSelfAttention(Module):
+    """
+    Self-attention module for agent tokens in a neural network.
+
+    Args:
+        dim (int): The input dimension.
+        num_agent_tokens (int): The number of agent tokens.
+        dim_head (int, optional): The dimension of each attention head. Defaults to 64.
+        heads (int, optional): The number of attention heads. Defaults to 8.
+        dropout (float, optional): The dropout rate. Defaults to 0.0.
+        talking_heads (bool, optional): Whether to use talking heads mechanism. Defaults to True.
+        gate (bool, optional): Whether to apply gating mechanism. Defaults to True.
+        combine_agent_tokens (bool, optional): Whether to combine agent tokens. Defaults to False. Currently unused.
+
+    Examples::
+        >>> import torch
+        >>> from zeta.nn.attention import AgentSelfAttention
+        >>> agent_self_attn = AgentSelfAttention(dim=64, num_agent_tokens=16)
+        >>> x = torch.randn(2, 32, 64)
+        >>> output = agent_self_attn(x)
+        >>> output.shape
+        torch.Size([2, 32, 64])
+    """
+
+    def __init__(
+        self,
+        dim,
+        *,
+        num_agent_tokens,
+        dim_head=64,
+        heads=8,
+        dropout=0.0,
+        talking_heads=True,
+        gate=True,
+        combine_agent_tokens=False,
+    ):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        dim_inner = dim_head * heads
+
+        self.to_qkv = nn.Sequential(
+            nn.Linear(dim, dim_inner * 3, bias=False),
+            Rearrange("b n (qkv h d) -> qkv b h n d", h=heads, qkv=3),
+        )
+
+        self.to_gates = (
+            nn.Sequential(
+                nn.Linear(dim, heads),
+                Rearrange("b n h -> b h n 1"),
+                nn.Sigmoid(),
+            )
+            if gate
+            else None
+        )
+
+        self.agent_tokens = nn.Parameter(
+            torch.zeros(heads, num_agent_tokens, dim_head)
+        )
+        nn.init.normal_(self.agent_tokens, std=0.02)
+
+        self.qa_talking_heads = (
+            nn.Conv2d(heads, heads, 1, bias=False)
+            if talking_heads
+            else nn.Identity()
+        )
+        self.ak_talking_heads = (
+            nn.Conv2d(heads, heads, 1, bias=False)
+            if talking_heads
+            else nn.Identity()
+        )
+
+        self.qa_dropout = nn.Dropout(dropout)
+        self.ak_dropout = nn.Dropout(dropout)
+
+        self.to_out = nn.Sequential(
+            Rearrange("b h n d -> b n (h d)"),
+            nn.Linear(dim_inner, dim, bias=False),
+        )
+
+    def forward(
+        self, x, mask=None, agent_tokens=None, return_agent_tokens=False
+    ):
+        batch = x.shape[0]
+
+        q, k, v = self.to_qkv(x)
+
+        if exists(agent_tokens):
+            a = agent_tokens
+        else:
+            a = repeat(self.agent_tokens, "h m d -> b h m d", b=batch)
+
+        a = a * self.scale
+
+        qa_sim = einsum("b h i d, b h j d -> b h i j", q, a)
+        ak_sim = einsum("b h i d, b h j d -> b h i j", a, k)
+
+        if exists(mask):
+            max_neg_value = -torch.finfo(qa_sim.dtype).max
+            ak_sim = ak_sim.masked_fill(
+                ~rearrange(mask, "b j -> b 1 1 j"), max_neg_value
+            )
+
+        qa_attn = qa_sim.softmax(dim=-1)
+        ak_attn = ak_sim.softmax(dim=-1)
+
+        qa_attn = self.qa_dropout(qa_attn)
+        ak_attn = self.ak_dropout(ak_attn)
+
+        qa_attn = self.qa_talking_heads(qa_attn)
+        ak_attn = self.ak_talking_heads(ak_attn)
+
+        agent_gathered_tokens = einsum(
+            "b h i j, b h j d -> b h i d", ak_attn, v
+        )
+
+        out = einsum(
+            "b h i j, b h j d -> b h i d", qa_attn, agent_gathered_tokens
+        )
+
+        if exists(mask):
+            out = out.masked_fill(~rearrange(mask, "b n -> b 1 n 1"), 0.0)
+
+        if exists(self.to_gates):
+            out = out * self.to_gates(x)
+
+        out = self.to_out(out)
+
+        if not return_agent_tokens:
+            return out
+
+        return out, agent_gathered_tokens
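For reference, a minimal usage sketch of the module above with a padding mask and the gathered agent tokens returned (batch, sequence, and model sizes here are illustrative):

import torch
from zeta.nn.attention import AgentSelfAttention

attn = AgentSelfAttention(dim=512, num_agent_tokens=128)
x = torch.randn(2, 1024, 512)  # (batch, seq_len, dim)
mask = torch.ones(2, 1024).bool()  # True marks tokens to keep
out, agent_tokens = attn(x, mask=mask, return_agent_tokens=True)
# out: (2, 1024, 512); agent_tokens: (2, 8, 128, 64) with the default heads=8, dim_head=64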
+""" + +from collections import namedtuple +from math import ceil, log2 + +import torch +import torch.nn.functional as F +from einops import pack, rearrange, reduce, unpack +from torch import Tensor, einsum, nn +from torch.nn import Module + +# constants + +Return = namedtuple("Return", ["quantized", "indices", "entropy_aux_loss"]) + +LossBreakdown = namedtuple( + "LossBreakdown", ["per_sample_entropy", "batch_entropy", "commitment"] +) + +# helper functions + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# entropy + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) + + +# class + + +class LFQ(Module): + """ + Initializes the Lookup-Free Quantization (LFQ) module. + + Args: + dim (int, optional): The input dimension. If not specified, it is calculated based on the codebook size and number of codebooks. Defaults to None. + codebook_size (int, optional): The size of the codebook. If not specified, it is calculated based on the input dimension. Defaults to None. + entropy_loss_weight (float, optional): The weight for the entropy loss. Defaults to 0.1. + commitment_loss_weight (float, optional): The weight for the commitment loss. Defaults to 0.25. + diversity_gamma (float, optional): The gamma parameter for diversity regularization. Defaults to 1.0. + straight_through_activation (nn.Module, optional): The activation function to be used during the forward pass. Defaults to nn.Identity(). + num_codebooks (int, optional): The number of codebooks. Defaults to 1. + keep_num_codebooks_dim (bool, optional): Whether to keep the number of codebooks dimension. Defaults to None. + codebook_scale (float, optional): The scale factor for the codebook. Defaults to 1.0. + + Examples:: + import torch + from zeta.nn import LFQ + + # you can specify either dim or codebook_size + # if both specified, will be validated against each other + + quantizer = LFQ( + codebook_size = 65536, # codebook size, must be a power of 2 + dim = 16, # this is the input feature dimension, defaults to log2(codebook_size) if not defined + entropy_loss_weight = 0.1, # how much weight to place on entropy loss + diversity_gamma = 1. 
# within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894 + ) + + image_feats = torch.randn(1, 16, 32, 32) + + quantized, indices, entropy_aux_loss = quantizer(image_feats) + + # (1, 16, 32, 32), (1, 32, 32), (1,) + + assert image_feats.shape == quantized.shape + assert (quantized == quantizer.indices_to_codes(indices)).all() + """ + + def __init__( + self, + *, + dim=None, + codebook_size=None, + entropy_loss_weight=0.1, + commitment_loss_weight=0.25, + diversity_gamma=1.0, + straight_through_activation=nn.Identity(), + num_codebooks=1, + keep_num_codebooks_dim=None, + codebook_scale=1.0, # for residual LFQ, codebook scaled down by 2x at each layer + ): + super().__init__() + + # some assert validations + + assert exists(dim) or exists( + codebook_size + ), "either dim or codebook_size must be specified for LFQ" + assert not exists(codebook_size) or log2(codebook_size).is_integer(), ( + "your codebook size must be a power of 2 for lookup free" + f" quantization (suggested {2 ** ceil(log2(codebook_size))})" + ) + + codebook_size = default(codebook_size, lambda: 2**dim) + codebook_dim = int(log2(codebook_size)) + + codebook_dims = codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + + has_projections = dim != codebook_dims + self.project_in = ( + nn.Linear(dim, codebook_dims) if has_projections else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_dims, dim) if has_projections else nn.Identity() + ) + self.has_projections = has_projections + + self.dim = dim + self.codebook_dim = codebook_dim + self.num_codebooks = num_codebooks + + keep_num_codebooks_dim = default( + keep_num_codebooks_dim, num_codebooks > 1 + ) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + # straight through activation + + self.activation = straight_through_activation + + # entropy aux loss related weights + + self.diversity_gamma = diversity_gamma + self.entropy_loss_weight = entropy_loss_weight + + # codebook scale + + self.codebook_scale = codebook_scale + + # commitment loss + + self.commitment_loss_weight = commitment_loss_weight + + # for no auxiliary loss, during inference + + self.register_buffer( + "mask", 2 ** torch.arange(codebook_dim - 1, -1, -1) + ) + self.register_buffer("zero", torch.tensor(0.0), persistent=False) + + # codes + + all_codes = torch.arange(codebook_size) + bits = ((all_codes[..., None].int() & self.mask) != 0).float() + codebook = self.bits_to_codes(bits) + + self.register_buffer("codebook", codebook, persistent=False) + + def bits_to_codes(self, bits): + return bits * self.codebook_scale * 2 - self.codebook_scale + + @property + def dtype(self): + return self.codebook.dtype + + def indices_to_codes(self, indices, project_out=True): + """Indices to codes. + + Args: + indices (_type_): _description_ + project_out (bool, optional): _description_. Defaults to True. + + Returns: + _type_: _description_ + """ + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... -> ... 1") + + # indices to codes, which are bits of either -1 or 1 + + bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) + + codes = self.bits_to_codes(bits) + + codes = rearrange(codes, "... c d -> ... 
(c d)") + + # whether to project codes out to original dimensions + # if the input feature dimensions were not log2(codebook size) + + if project_out: + codes = self.project_out(codes) + + # rearrange codes back to original shape + + if is_img_or_video: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes + + def forward( + self, + x: Tensor, + inv_temperature=100.0, + return_loss_breakdown=False, + mask=None, + ) -> Tensor: + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + + is_img_or_video = x.ndim >= 4 + + # standardize image or video into (batch, seq, dimension) + + if is_img_or_video: + x = rearrange(x, "b d ... -> b ... d") + x, ps = pack_one(x, "b * d") + + assert ( + x.shape[-1] == self.dim + ), f"expected dimension of {self.dim} but received {x.shape[-1]}" + + x = self.project_in(x) + + # split out number of codebooks + + x = rearrange(x, "b n (c d) -> b n c d", c=self.num_codebooks) + + # quantize by eq 3. + + original_input = x + + codebook_value = torch.ones_like(x) * self.codebook_scale + quantized = torch.where(x > 0, codebook_value, -codebook_value) + + # use straight-through gradients (optionally with custom activation fn) if training + + if self.training: + x = self.activation(x) + x = x + (quantized - x).detach() + else: + x = quantized + + # calculate indices + + indices = reduce( + (x > 0).int() * self.mask.int(), "b n c d -> b n c", "sum" + ) + + # entropy aux loss + + if self.training: + # the same as euclidean distance up to a constant + distance = -2 * einsum( + "... i d, j d -> ... i j", original_input, self.codebook + ) + + prob = (-distance * inv_temperature).softmax(dim=-1) + + per_sample_entropy = entropy(prob).mean() + + # account for mask + + if exists(mask): + prob = prob[mask] + + # distribution over all available tokens in the batch + + avg_prob = reduce(prob, "... c d -> c d", "mean") + codebook_entropy = entropy(avg_prob).mean() + + # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions + # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch + + entropy_aux_loss = ( + per_sample_entropy - self.diversity_gamma * codebook_entropy + ) + else: + # if not training, just return dummy 0 + entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero + + # commit loss + + if self.training: + commit_loss = F.mse_loss( + original_input, quantized.detach(), reduction="none" + ) + + if exists(mask): + commit_loss = commit_loss[mask] + + commit_loss = commit_loss.mean() + else: + commit_loss = self.zero + + # merge back codebook dim + + x = rearrange(x, "b n c d -> b n (c d)") + + # project out to feature dimension if needed + + x = self.project_out(x) + + # reconstitute image or video dimensions + + if is_img_or_video: + x = unpack_one(x, ps, "b * d") + x = rearrange(x, "b ... d -> b d ...") + + indices = unpack_one(indices, ps, "b * c") + + # whether to remove single codebook dim + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, "... 
1 -> ...") + + # complete aux loss + + aux_loss = ( + entropy_aux_loss * self.entropy_loss_weight + + commit_loss * self.commitment_loss_weight + ) + + ret = Return(x, indices, aux_loss) + + if not return_loss_breakdown: + return ret + + return ret, LossBreakdown( + per_sample_entropy, codebook_entropy, commit_loss + ) diff --git a/zeta/quant/random_proj_quan.py b/zeta/quant/random_proj_quan.py new file mode 100644 index 00000000..e69de29b