[FEAT][HalfBitLinear] [FEAT][nearest_upsample]
Kye committed Dec 30, 2023
1 parent ddcdc19 commit 57f1e82
Showing 10 changed files with 200 additions and 39 deletions.
38 changes: 0 additions & 38 deletions file_list.txt

This file was deleted.

File renamed without changes.
34 changes: 34 additions & 0 deletions tests/quant/test_half_bit_linear.py
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from zeta.quant.half_bit_linear import HalfBitLinear


def test_half_bit_linear_init():
hbl = HalfBitLinear(10, 5)
assert isinstance(hbl, HalfBitLinear)
assert hbl.in_features == 10
assert hbl.out_features == 5
assert isinstance(hbl.weight, nn.Parameter)
assert isinstance(hbl.bias, nn.Parameter)


def test_half_bit_linear_forward():
hbl = HalfBitLinear(10, 5)
x = torch.randn(1, 10)
output = hbl.forward(x)
assert output.shape == (1, 5)


def test_half_bit_linear_forward_zero_input():
    hbl = HalfBitLinear(10, 5)
    x = torch.zeros(1, 10)
    output = hbl.forward(x)
    assert output.shape == (1, 5)
    # With a zero input the weight term vanishes, so the output reduces to the bias
    assert torch.allclose(output, hbl.bias)


def test_half_bit_linear_forward_one_input():
hbl = HalfBitLinear(10, 5)
x = torch.ones(1, 10)
output = hbl.forward(x)
assert output.shape == (1, 5)
File renamed without changes.
2 changes: 2 additions & 0 deletions zeta/nn/attention/__init__.py
@@ -18,6 +18,7 @@
from zeta.nn.attention.multiquery_attention import MultiQueryAttention
from zeta.nn.attention.sparse_attention import SparseAttention
from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
from zeta.nn.attention.linear_attention import LinearAttention

# from zeta.nn.attention.flash_attention2 import FlashAttentionTwo
# from zeta.nn.attention.mgqa import MGQA
@@ -38,4 +39,5 @@
"MultiModalCrossAttention",
"SparseAttention",
"SpatialLinearAttention",
"LinearAttention",
]
72 changes: 72 additions & 0 deletions zeta/nn/attention/linear_attention.py
@@ -0,0 +1,72 @@
import math

from einops import rearrange
from torch import einsum, nn

from zeta.utils import l2norm


class LinearAttention(nn.Module):
"""
    Linear attention module that applies a linear attention mechanism to a 2D input feature map.
Args:
dim (int): The input feature map dimension.
dim_head (int, optional): The dimension of each attention head. Defaults to 32.
heads (int, optional): The number of attention heads. Defaults to 8.
**kwargs: Additional keyword arguments.
Returns:
torch.Tensor: The output feature map after applying linear attention.
"""

def __init__(self, dim: int, dim_head: int = 32, heads: int = 8, **kwargs):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
        # GroupNorm with a single group normalizes each sample over (C, H, W)
        # and accepts NCHW input, whereas nn.LayerNorm(dim) would expect the
        # last dimension of the feature map to equal dim.
        self.norm = nn.GroupNorm(1, dim)

self.nonlin = nn.GELU()
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias=False),
            nn.GroupNorm(1, dim),  # normalizes the NCHW output over (C, H, W)
        )

def forward(self, fmap):
"""
Forward pass of the LinearAttention module.
Args:
fmap (torch.Tensor): Input feature map tensor of shape (batch_size, channels, height, width).
Returns:
torch.Tensor: Output tensor after applying linear attention, of shape (batch_size, channels, height, width).
"""
h, x, y = self.heads, *fmap.shape[-2:]
seq_len = x * y

fmap = self.norm(fmap)
q, k, v = self.to_qkv(fmap).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
(q, k, v),
)

q = q.softmax(dim=-1)
k = k.softmax(dim=-2)

q = q * self.scale
v = l2norm(v)

k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))

context = einsum("b n d, b n e -> b d e", k, v)
out = einsum("b n d, b d e -> b n e", q, context)
out = rearrange(out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)

out = self.nonlin(out)
return self.to_out(out)
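A minimal usage sketch (illustrative, not part of the commit), assuming the package export added in zeta/nn/attention/__init__.py above and an example 64-channel feature map:

import torch
from zeta.nn.attention import LinearAttention

# Batch of 2, 64 channels, 32x32 spatial grid (example values)
attn = LinearAttention(dim=64, dim_head=32, heads=8)
fmap = torch.randn(2, 64, 32, 32)
out = attn(fmap)
print(out.shape)  # torch.Size([2, 64, 32, 32])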

1 change: 1 addition & 0 deletions zeta/nn/modules/__init__.py
@@ -78,6 +78,7 @@
from zeta.nn.modules.slerp_model_merger import SLERPModelMerger
from zeta.nn.modules.avg_model_merger import AverageModelMerger


# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
# from zeta.nn.modules.scaled_sinusoidal import ScaledSinuosidalEmbedding
20 changes: 20 additions & 0 deletions zeta/nn/modules/nearest_upsample.py
@@ -0,0 +1,20 @@
from torch import nn
from zeta.utils import default


def nearest_upsample(dim: int, dim_out: int = None):
"""Nearest upsampling layer.
Args:
dim (int): _description_
dim_out (int, optional): _description_. Defaults to None.
Returns:
_type_: _description_
"""
dim_out = default(dim_out, dim)

return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(dim, dim_out, 3, padding=1),
)
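A short usage sketch (illustrative, not part of the commit), importing directly from the new module file:

import torch
from zeta.nn.modules.nearest_upsample import nearest_upsample

up = nearest_upsample(dim=64, dim_out=32)
x = torch.randn(1, 64, 16, 16)
y = up(x)
print(y.shape)  # torch.Size([1, 32, 32, 32]): spatial size doubled, channels mapped 64 -> 32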
11 changes: 10 additions & 1 deletion zeta/quant/__init__.py
@@ -4,6 +4,15 @@
from zeta.quant.qlora import QloraLinear
from zeta.quant.niva import niva
from zeta.quant.absmax import absmax_quantize
from zeta.quant.half_bit_linear import HalfBitLinear


__all__ = ["QUIK", "absmax_quantize", "BitLinear", "STE", "QloraLinear", "niva"]
__all__ = [
"QUIK",
"absmax_quantize",
"BitLinear",
"STE",
"QloraLinear",
"niva",
"HalfBitLinear",
]
61 changes: 61 additions & 0 deletions zeta/quant/half_bit_linear.py
@@ -0,0 +1,61 @@
import torch
from torch import nn, Tensor


class HalfBitLinear(nn.Module):
"""
A custom linear layer with half-bit quantization.
Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
Attributes:
in_features (int): Number of input features.
out_features (int): Number of output features.
weight (torch.Tensor): Learnable weight parameters of the layer.
bias (torch.Tensor): Learnable bias parameters of the layer.
Examples:
# Example usage
in_features = 256
out_features = 128
model = HalfBitLinear(in_features, out_features)
input_tensor = torch.randn(1, in_features)
output = model(input_tensor)
print(output)
"""

def __init__(self, in_features: int, out_features: int):
super(HalfBitLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.randn(out_features, in_features))
self.bias = nn.Parameter(torch.randn(out_features))

def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the half-bit linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after applying the half-bit linear transformation.
"""
# Normalize the absolute weights to be in the range [0, 1]
normalized_abs_weights = (
torch.abs(self.weight) / torch.abs(self.weight).max()
)

        # Binarize by sign: positive weights map to 1, non-positive weights to 0
        quantized_weights = torch.where(
            self.weight > 0,
            torch.ones_like(self.weight),
            torch.zeros_like(self.weight),
        )

        # Stochastically keep each binarized weight with probability equal to
        # its normalized magnitude
        stochastic_mask = torch.bernoulli(normalized_abs_weights).to(x.device)
        quantized_weights = quantized_weights * stochastic_mask

return nn.functional.linear(x, quantized_weights, self.bias)
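A small sketch showing how the stochastic masking behaves in practice (illustrative, not part of the commit):

import torch
from zeta.quant.half_bit_linear import HalfBitLinear

torch.manual_seed(0)
layer = HalfBitLinear(in_features=4, out_features=2)
x = torch.randn(1, 4)
print(layer(x).shape)  # torch.Size([1, 2])

# The forward pass draws a fresh Bernoulli mask on every call, so two calls
# on the same input generally produce different outputs.
print(torch.equal(layer(x), layer(x)))  # usually False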
