From 57f1e82051a31dbaaa757a27df85dc89ec4a348a Mon Sep 17 00:00:00 2001 From: Kye Date: Fri, 29 Dec 2023 20:32:48 -0500 Subject: [PATCH] [FEAT][HalfBitLinear] [FEAT][nearest_upsample] --- file_list.txt | 38 ---------- .../auto_tests_docs/auto_docs_functions.py | 0 tests/quant/test_half_bit_linear.py | 34 +++++++++ tests/{__init__.py => test___init__.py} | 0 zeta/nn/attention/__init__.py | 2 + zeta/nn/attention/linear_attention.py | 72 +++++++++++++++++++ zeta/nn/modules/__init__.py | 1 + zeta/nn/modules/nearest_upsample.py | 20 ++++++ zeta/quant/__init__.py | 11 ++- zeta/quant/half_bit_linear.py | 61 ++++++++++++++++ 10 files changed, 200 insertions(+), 39 deletions(-) delete mode 100644 file_list.txt rename auto_docs_functions.py => scripts/auto_tests_docs/auto_docs_functions.py (100%) create mode 100644 tests/quant/test_half_bit_linear.py rename tests/{__init__.py => test___init__.py} (100%) create mode 100644 zeta/nn/attention/linear_attention.py create mode 100644 zeta/nn/modules/nearest_upsample.py create mode 100644 zeta/quant/half_bit_linear.py diff --git a/file_list.txt b/file_list.txt deleted file mode 100644 index d096b5fb..00000000 --- a/file_list.txt +++ /dev/null @@ -1,38 +0,0 @@ -- img_compose_decompose: "zeta/ops/img_compose_decompose.md" -- rearrange: "zeta/ops/rearrange.md" -- img_transpose_2daxis: "zeta/ops/img_transpose_2daxis.md" -- img_transpose: "zeta/ops/img_transpose.md" -- img_order_of_axes: "zeta/ops/img_order_of_axes.md" -- mos: "zeta/ops/mos.md" -- merge_small_dims: "zeta/ops/merge_small_dims.md" -- multi_dim_cat: "zeta/ops/multi_dim_cat.md" -- img_compose_bw: "zeta/ops/img_compose_bw.md" -- squeeze_2d_new: "zeta/ops/squeeze_2d_new.md" -- temp_softmax: "zeta/ops/temp_softmax.md" -- gumbelmax: "zeta/ops/gumbelmax.md" -- _matrix_inverse_root_newton: "zeta/ops/_matrix_inverse_root_newton.md" -- compute_matrix_root_inverse_residuals: "zeta/ops/compute_matrix_root_inverse_residuals.md" -- matrix_root_diagonal: "zeta/ops/matrix_root_diagonal.md" -- sparse_softmax: "zeta/ops/sparse_softmax.md" -- reshape_audio_to_text: "zeta/ops/reshape_audio_to_text.md" -- local_softmax: "zeta/ops/local_softmax.md" -- softmaxes: "zeta/ops/softmaxes.md" -- _matrix_root_eigen: "zeta/ops/_matrix_root_eigen.md" -- main: "zeta/ops/main.md" -- norm_exp_softmax: "zeta/ops/norm_exp_softmax.md" -- multi_dim_split: "zeta/ops/multi_dim_split.md" -- img_width_to_height: "zeta/ops/img_width_to_height.md" -- fast_softmax: "zeta/ops/fast_softmax.md" -- standard_softmax: "zeta/ops/standard_softmax.md" -- unitwise_norm: "zeta/ops/unitwise_norm.md" -- reshape_video_to_text: "zeta/ops/reshape_video_to_text.md" -- img_decompose: "zeta/ops/img_decompose.md" -- unsqueeze_2d_new: "zeta/ops/unsqueeze_2d_new.md" -- reshape_img_to_text: "zeta/ops/reshape_img_to_text.md" -- channel_shuffle_new: "zeta/ops/channel_shuffle_new.md" -- matrix_inverse_root: "zeta/ops/matrix_inverse_root.md" -- sparsemax: "zeta/ops/sparsemax.md" -- gram_matrix_new: "zeta/ops/gram_matrix_new.md" -- logit_scaled_softmax: "zeta/ops/logit_scaled_softmax.md" -- selu_softmax: "zeta/ops/selu_softmax.md" -- reshape_text_to_img: "zeta/ops/reshape_text_to_img.md" diff --git a/auto_docs_functions.py b/scripts/auto_tests_docs/auto_docs_functions.py similarity index 100% rename from auto_docs_functions.py rename to scripts/auto_tests_docs/auto_docs_functions.py diff --git a/tests/quant/test_half_bit_linear.py b/tests/quant/test_half_bit_linear.py new file mode 100644 index 00000000..108a3b98 --- /dev/null +++ b/tests/quant/test_half_bit_linear.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from zeta.quant.half_bit_linear import HalfBitLinear + + +def test_half_bit_linear_init(): + hbl = HalfBitLinear(10, 5) + assert isinstance(hbl, HalfBitLinear) + assert hbl.in_features == 10 + assert hbl.out_features == 5 + assert isinstance(hbl.weight, nn.Parameter) + assert isinstance(hbl.bias, nn.Parameter) + + +def test_half_bit_linear_forward(): + hbl = HalfBitLinear(10, 5) + x = torch.randn(1, 10) + output = hbl.forward(x) + assert output.shape == (1, 5) + + +def test_half_bit_linear_forward_zero_input(): + hbl = HalfBitLinear(10, 5) + x = torch.zeros(1, 10) + output = hbl.forward(x) + assert output.shape == (1, 5) + assert torch.all(output == 0) + + +def test_half_bit_linear_forward_one_input(): + hbl = HalfBitLinear(10, 5) + x = torch.ones(1, 10) + output = hbl.forward(x) + assert output.shape == (1, 5) diff --git a/tests/__init__.py b/tests/test___init__.py similarity index 100% rename from tests/__init__.py rename to tests/test___init__.py diff --git a/zeta/nn/attention/__init__.py b/zeta/nn/attention/__init__.py index 73ecf77a..b22b4e3e 100644 --- a/zeta/nn/attention/__init__.py +++ b/zeta/nn/attention/__init__.py @@ -18,6 +18,7 @@ from zeta.nn.attention.multiquery_attention import MultiQueryAttention from zeta.nn.attention.sparse_attention import SparseAttention from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention +from zeta.nn.attention.linear_attention import LinearAttention # from zeta.nn.attention.flash_attention2 import FlashAttentionTwo # from zeta.nn.attention.mgqa import MGQA @@ -38,4 +39,5 @@ "MultiModalCrossAttention", "SparseAttention", "SpatialLinearAttention", + "LinearAttention", ] diff --git a/zeta/nn/attention/linear_attention.py b/zeta/nn/attention/linear_attention.py new file mode 100644 index 00000000..a01bf345 --- /dev/null +++ b/zeta/nn/attention/linear_attention.py @@ -0,0 +1,72 @@ +import math + +from einops import rearrange +from torch import einsum, nn + +from zeta.utils import l2norm + + +class LinearAttention(nn.Module): + """ + Linear Attention module that performs attention mechanism on the input feature map. + + Args: + dim (int): The input feature map dimension. + dim_head (int, optional): The dimension of each attention head. Defaults to 32. + heads (int, optional): The number of attention heads. Defaults to 8. + **kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: The output feature map after applying linear attention. + + """ + + def __init__(self, dim: int, dim_head: int = 32, heads: int = 8, **kwargs): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + inner_dim = dim_head * heads + self.norm = nn.LayerNorm(dim) + + self.nonlin = nn.GELU() + self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False) + + self.to_out = nn.Sequential( + nn.Conv2d(inner_dim, dim, 1, bias=False), nn.LayerNorm(dim) + ) + + def forward(self, fmap): + """ + Forward pass of the LinearAttention module. + + Args: + fmap (torch.Tensor): Input feature map tensor of shape (batch_size, channels, height, width). + + Returns: + torch.Tensor: Output tensor after applying linear attention, of shape (batch_size, channels, height, width). + """ + h, x, y = self.heads, *fmap.shape[-2:] + seq_len = x * y + + fmap = self.norm(fmap) + q, k, v = self.to_qkv(fmap).chunk(3, dim=1) + q, k, v = map( + lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h), + (q, k, v), + ) + + q = q.softmax(dim=-1) + k = k.softmax(dim=-2) + + q = q * self.scale + v = l2norm(v) + + k, v = map(lambda t: t / math.sqrt(seq_len), (k, v)) + + context = einsum("b n d, b n e -> b d e", k, v) + out = einsum("b n d, b d e -> b n e", q, context) + out = rearrange(out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y) + + out = self.nonlin(out) + return self.to_out(out) + diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index a0e0e376..84f1ecad 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -78,6 +78,7 @@ from zeta.nn.modules.slerp_model_merger import SLERPModelMerger from zeta.nn.modules.avg_model_merger import AverageModelMerger + # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features # from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding diff --git a/zeta/nn/modules/nearest_upsample.py b/zeta/nn/modules/nearest_upsample.py new file mode 100644 index 00000000..4f2b2379 --- /dev/null +++ b/zeta/nn/modules/nearest_upsample.py @@ -0,0 +1,20 @@ +from torch import nn +from zeta.utils import default + + +def nearest_upsample(dim: int, dim_out: int = None): + """Nearest upsampling layer. + + Args: + dim (int): _description_ + dim_out (int, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + dim_out = default(dim_out, dim) + + return nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(dim, dim_out, 3, padding=1), + ) diff --git a/zeta/quant/__init__.py b/zeta/quant/__init__.py index aa16a321..225cccf1 100644 --- a/zeta/quant/__init__.py +++ b/zeta/quant/__init__.py @@ -4,6 +4,15 @@ from zeta.quant.qlora import QloraLinear from zeta.quant.niva import niva from zeta.quant.absmax import absmax_quantize +from zeta.quant.half_bit_linear import HalfBitLinear -__all__ = ["QUIK", "absmax_quantize", "BitLinear", "STE", "QloraLinear", "niva"] +__all__ = [ + "QUIK", + "absmax_quantize", + "BitLinear", + "STE", + "QloraLinear", + "niva", + "HalfBitLinear", +] diff --git a/zeta/quant/half_bit_linear.py b/zeta/quant/half_bit_linear.py new file mode 100644 index 00000000..b48f1f66 --- /dev/null +++ b/zeta/quant/half_bit_linear.py @@ -0,0 +1,61 @@ +import torch +from torch import nn, Tensor + + +class HalfBitLinear(nn.Module): + """ + A custom linear layer with half-bit quantization. + + Args: + in_features (int): Number of input features. + out_features (int): Number of output features. + + Attributes: + in_features (int): Number of input features. + out_features (int): Number of output features. + weight (torch.Tensor): Learnable weight parameters of the layer. + bias (torch.Tensor): Learnable bias parameters of the layer. + + Examples: + # Example usage + in_features = 256 + out_features = 128 + model = HalfBitLinear(in_features, out_features) + input_tensor = torch.randn(1, in_features) + output = model(input_tensor) + print(output) + + """ + + def __init__(self, in_features: int, out_features: int): + super(HalfBitLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + self.bias = nn.Parameter(torch.randn(out_features)) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the half-bit linear layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the half-bit linear transformation. + """ + # Normalize the absolute weights to be in the range [0, 1] + normalized_abs_weights = ( + torch.abs(self.weight) / torch.abs(self.weight).max() + ) + + # Stochastic quantization + quantized_weights = torch.where( + self.weight > 0, + torch.ones_like(self.weight), + torch.zeros_like(self.weight), + ) + stochastic_mask = torch.bernoulli(normalized_abs_weights).to(x.device) + quantized_weights = quantized_weights * stochastic_mask + + return nn.functional.linear(x, quantized_weights, self.bias)