[FEAT][HalfBitLinear] [FEAT][nearest_upsample]
Kye committed Dec 30, 2023
1 parent ddcdc19 · commit 57f1e82
Showing 10 changed files with 200 additions and 39 deletions.
This file was deleted.
File renamed without changes.
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
from zeta.quant.half_bit_linear import HalfBitLinear


def test_half_bit_linear_init():
    hbl = HalfBitLinear(10, 5)
    assert isinstance(hbl, HalfBitLinear)
    assert hbl.in_features == 10
    assert hbl.out_features == 5
    assert isinstance(hbl.weight, nn.Parameter)
    assert isinstance(hbl.bias, nn.Parameter)


def test_half_bit_linear_forward():
    hbl = HalfBitLinear(10, 5)
    x = torch.randn(1, 10)
    output = hbl.forward(x)
    assert output.shape == (1, 5)


def test_half_bit_linear_forward_zero_input():
    hbl = HalfBitLinear(10, 5)
    x = torch.zeros(1, 10)
    output = hbl.forward(x)
    assert output.shape == (1, 5)
    # With an all-zero input the linear term vanishes, so the output equals the bias.
    assert torch.allclose(output, hbl.bias)


def test_half_bit_linear_forward_one_input():
    hbl = HalfBitLinear(10, 5)
    x = torch.ones(1, 10)
    output = hbl.forward(x)
    assert output.shape == (1, 5)
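For context on the zero-input test above: the layer keeps a learnable bias, so an all-zero input does not give an all-zero output; only the bias remains. A quick standalone check (a sketch, using the HalfBitLinear module added later in this commit):

import torch
from zeta.quant.half_bit_linear import HalfBitLinear

hbl = HalfBitLinear(10, 5)
x = torch.zeros(1, 10)
out = hbl(x)

# x @ W_q.T is all zeros, so only the bias term remains.
print(torch.allclose(out, hbl.bias))  # True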
File renamed without changes.
@@ -0,0 +1,72 @@
import math

from einops import rearrange
from torch import einsum, nn

from zeta.utils import l2norm


class LinearAttention(nn.Module):
    """
    Linear attention module that applies the attention mechanism to an input feature map.

    Args:
        dim (int): The input feature map dimension.
        dim_head (int, optional): The dimension of each attention head. Defaults to 32.
        heads (int, optional): The number of attention heads. Defaults to 8.
        **kwargs: Additional keyword arguments.

    Returns:
        torch.Tensor: The output feature map after applying linear attention.
    """

    def __init__(self, dim: int, dim_head: int = 32, heads: int = 8, **kwargs):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = nn.LayerNorm(dim)

        self.nonlin = nn.GELU()
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1, bias=False), nn.LayerNorm(dim)
        )

    def forward(self, fmap):
        """
        Forward pass of the LinearAttention module.

        Args:
            fmap (torch.Tensor): Input feature map tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying linear attention, of shape (batch_size, channels, height, width).
        """
        h, x, y = self.heads, *fmap.shape[-2:]
        seq_len = x * y

        fmap = self.norm(fmap)
        q, k, v = self.to_qkv(fmap).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
            (q, k, v),
        )

        q = q.softmax(dim=-1)
        k = k.softmax(dim=-2)

        q = q * self.scale
        v = l2norm(v)

        k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))

        context = einsum("b n d, b n e -> b d e", k, v)
        out = einsum("b n d, b d e -> b n e", q, context)
        out = rearrange(out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)

        out = self.nonlin(out)
        return self.to_out(out)
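A minimal usage sketch, assuming the LinearAttention class above is in scope. The 32x32 spatial size is chosen so that the trailing dimension matches dim, which is the dimension the nn.LayerNorm layers above normalize over:

import torch

attn = LinearAttention(dim=32, dim_head=32, heads=8)

fmap = torch.randn(1, 32, 32, 32)  # (batch, channels, height, width)
out = attn(fmap)
print(out.shape)  # torch.Size([1, 32, 32, 32]), same shape as the input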
@@ -0,0 +1,20 @@
from torch import nn
from zeta.utils import default


def nearest_upsample(dim: int, dim_out: int = None):
    """Nearest-neighbor upsampling layer.

    Args:
        dim (int): Number of input channels.
        dim_out (int, optional): Number of output channels. Defaults to None,
            in which case dim is used.

    Returns:
        nn.Sequential: A 2x nearest-neighbor upsample followed by a 3x3 convolution.
    """
    dim_out = default(dim_out, dim)

    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, dim_out, 3, padding=1),
    )
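A short usage sketch (channel counts and spatial size chosen for illustration): the layer doubles the spatial resolution with nearest-neighbor interpolation, then remaps the channel count with the 3x3 convolution.

import torch

up = nearest_upsample(dim=64, dim_out=32)
x = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
y = up(x)
print(y.shape)  # torch.Size([1, 32, 32, 32]): spatial dims doubled, 64 -> 32 channels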
@@ -0,0 +1,61 @@
import torch
from torch import nn, Tensor


class HalfBitLinear(nn.Module):
    """
    A custom linear layer with half-bit quantization.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.

    Attributes:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        weight (torch.Tensor): Learnable weight parameters of the layer.
        bias (torch.Tensor): Learnable bias parameters of the layer.

    Examples:
        # Example usage
        in_features = 256
        out_features = 128
        model = HalfBitLinear(in_features, out_features)
        input_tensor = torch.randn(1, in_features)
        output = model(input_tensor)
        print(output)
    """

    def __init__(self, in_features: int, out_features: int):
        super(HalfBitLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the half-bit linear layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after applying the half-bit linear transformation.
        """
        # Normalize the absolute weights to be in the range [0, 1]
        normalized_abs_weights = (
            torch.abs(self.weight) / torch.abs(self.weight).max()
        )

        # Stochastic quantization
        quantized_weights = torch.where(
            self.weight > 0,
            torch.ones_like(self.weight),
            torch.zeros_like(self.weight),
        )
        stochastic_mask = torch.bernoulli(normalized_abs_weights).to(x.device)
        quantized_weights = quantized_weights * stochastic_mask

        return nn.functional.linear(x, quantized_weights, self.bias)
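For intuition, a small worked sketch of the quantization step inside forward (toy weight values chosen for illustration, runnable on its own): positive weights are kept as 1 with probability equal to their normalized magnitude, and non-positive weights are zeroed.

import torch

# Toy weight matrix standing in for self.weight.
w = torch.tensor([[0.8, -0.4, 0.1],
                  [-0.9, 0.5, 0.0]])

# Normalize absolute weights into [0, 1], as in HalfBitLinear.forward.
normalized_abs = w.abs() / w.abs().max()

# Sign gate: 1 where the weight is positive, 0 otherwise.
sign_gate = torch.where(w > 0, torch.ones_like(w), torch.zeros_like(w))

# Stochastic mask: each entry survives with probability equal to its normalized magnitude.
mask = torch.bernoulli(normalized_abs)

quantized = sign_gate * mask
print(quantized)  # entries are 0.0 or 1.0; large positive weights survive most often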