[FEAT][PixelShuffleDownscale]
Kye committed Apr 26, 2024
1 parent 6a64734 commit fd16add
Showing 8 changed files with 193 additions and 1 deletion.
1 change: 1 addition & 0 deletions zeta/experimental/__init__.py
@@ -0,0 +1 @@
from zeta.experimental.triton.activations import * # noqa
Empty file.
Empty file.
Empty file.
98 changes: 98 additions & 0 deletions zeta/experimental/triton/triton_modules/linear_proj.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn

if torch.cuda.is_available():
    try:
        import triton
        import triton.language as tl
    except ImportError:
        print(
            "Triton is not installed. Please install it using `pip install"
            " triton`."
        )


@triton.jit
def linear_projection_kernel(
    X, W, Y, M, N, K, stride_x, stride_w, stride_y, BLOCK_SIZE: tl.constexpr
):
    # Compute indices
    row_idx = tl.program_id(0)
    col_idx = tl.program_id(1)

    # Offsets for X, W, and Y
    x_off = row_idx * stride_x
    w_off = col_idx * stride_w
    y_off = row_idx * stride_y + col_idx

    # Dot product
    acc = tl.zeros((), dtype=tl.float32)
    for k in range(K):
        acc += tl.load(X + x_off + k) * tl.load(W + w_off + k)
    tl.store(Y + y_off, acc)


class LinearTriton(nn.Module):
    """
    A custom linear module implemented using Triton.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If set to True, the module has a learnable bias. Default is True.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(LinearTriton, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Performs a forward pass through the linear module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        # Prepare the output tensor
        output = torch.empty(
            x.shape[0], self.out_features, device=x.device, dtype=x.dtype
        )

        # Grid and block dimensions
        grid = (x.shape[0], self.out_features)
        block = 128  # Example block size

        # Launch the Triton kernel
        linear_projection_kernel[grid](
            x,
            self.weight,
            output,
            x.shape[0],
            self.out_features,
            self.in_features,
            x.stride(0),
            self.weight.stride(0),
            output.stride(0),
            block,
        )

        # Add bias if present
        if self.bias is not None:
            output += self.bias.unsqueeze(0)  # Broadcasting the bias
        return output


# # Example usage
# model = LinearTriton(128, 64).cuda()
# input_tensor = torch.randn(10, 128).cuda()
# output_tensor = model(input_tensor)
# print(output_tensor.shape)  # torch.Size([10, 64])
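For reference, a minimal usage sketch for LinearTriton (not part of this commit). It assumes a CUDA device and an installed triton package, and feeds the 2D input the forward docstring expects:

import torch
from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton

# Sketch only: requires CUDA and triton; input must be 2D (batch_size, in_features).
model = LinearTriton(128, 64).cuda()
x = torch.randn(10, 128, device="cuda")
y = model(x)
print(y.shape)  # torch.Size([10, 64])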
2 changes: 2 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -212,6 +212,7 @@
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm

from zeta.nn.modules.query_proposal import TextHawkQueryProposal
from zeta.nn.modules.pixel_shuffling import PixelShuffleDownscale

# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -422,4 +423,5 @@
"cls_tokens",
"video_patch_linear_flatten",
"TextHawkQueryProposal",
"PixelShuffleDownscale",
]
23 changes: 22 additions & 1 deletion zeta/nn/modules/feedforward.py
@@ -3,6 +3,7 @@
from zeta.nn.modules.glu import GLU
from zeta.nn.modules.swiglu import SwiGLU
from typing import Optional
from zeta.experimental.triton.triton_modules.linear_proj import LinearTriton


class ReluSquared(nn.Module):
@@ -40,6 +41,7 @@ def __init__(
        zero_init_output: Optional[bool] = False,
        custom_act: Optional[nn.Module] = None,
        swiglu: Optional[bool] = False,
        triton_kernels_on: bool = False,
    ):
"""
FeedForward module that applies a series of linear transformations and activations.
@@ -60,6 +62,21 @@
            swiglu (bool, optional): Whether to use SwiGLU activation. Defaults to False.
        """
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.mult = mult
        self.glu = glu
        self.glu_mult_bias = glu_mult_bias
        self.swish = swish
        self.relu_squared = relu_squared
        self.post_act_ln = post_act_ln
        self.dropout = dropout
        self.no_bias = no_bias
        self.zero_init_output = zero_init_output
        self.custom_act = custom_act
        self.swiglu = swiglu
        self.triton_kernels_on = triton_kernels_on

        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

Expand All @@ -78,6 +95,10 @@ def __init__(
            project_in = GLU(
                dim, inner_dim, activation, mult_bias=glu_mult_bias
            )
        elif triton_kernels_on is True:
            project_in = nn.Sequential(
                LinearTriton(dim, inner_dim, bias=no_bias), activation
            )
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, inner_dim, bias=not no_bias), activation
@@ -88,7 +109,7 @@ def __init__(
                project_in,
                nn.LayerNorm(inner_dim),
                nn.Dropout(dropout),
                nn.Linear(inner_dim, dim_out, bias=not no_bias),
                nn.Linear(inner_dim, dim_out, bias=no_bias),
            )
        else:
            self.ff = nn.Sequential(
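For reference, a hedged sketch of the new triton_kernels_on flag (not part of this commit). It assumes the class shown above is exported as FeedForward, that the constructor arguments in the diff are accepted as keywords, and that a CUDA device with triton is available, since the LinearTriton projection path has no CPU fallback:

import torch
from zeta.nn.modules.feedforward import FeedForward

# Sketch only: with triton_kernels_on=True the input projection uses LinearTriton.
ff = FeedForward(dim=512, dim_out=512, mult=4, triton_kernels_on=True).cuda()
x = torch.randn(2, 512, device="cuda")
out = ff(x)
print(out.shape)  # torch.Size([2, 512])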
70 changes: 70 additions & 0 deletions zeta/nn/modules/pixel_shuffling.py
@@ -0,0 +1,70 @@
from torch import nn, Tensor


class PixelShuffleDownscale(nn.Module):
    def __init__(self, downscale_factor: int = 2):
        """
        Initializes a PixelShuffleDownscale module.

        Args:
            downscale_factor (int): The factor by which the input will be downscaled.

        Example:
            >>> downscale_factor = 2
            >>> model = PixelShuffleDownscale(downscale_factor)
            >>> input_tensor = torch.rand(1, 256, 448, 448)
            >>> output_tensor = model(input_tensor)
            >>> print(output_tensor.shape)
            torch.Size([1, 64, 896, 896])
        """
        super(PixelShuffleDownscale, self).__init__()
        self.downscale_factor = downscale_factor
        # Initialize the pixel shuffle with an upscale factor which will actually be used to downscale
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=downscale_factor)

    def forward(self, x: Tensor) -> Tensor:
        """
        Performs a forward pass of the PixelShuffleDownscale module.

        Args:
            x (torch.Tensor): The input tensor with shape [batch_size, channels, height, width].

        Returns:
            torch.Tensor: The output tensor after downsampling using pixel shuffle.
        """
        # x should have a shape of [batch_size, channels, height, width]
        # We first need to adapt the number of channels so that pixel shuffle can be applied
        batch_size, channels, height, width = x.shape
        new_channels = channels // (self.downscale_factor**2)
        if new_channels * (self.downscale_factor**2) != channels:
            raise ValueError(
                "The number of channels must be divisible by"
                " (downscale_factor^2)"
            )

        # Reshape x to the shape expected by pixel shuffle
        x = x.reshape(
            batch_size, new_channels, self.downscale_factor**2, height, width
        )
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(
            batch_size,
            new_channels * (self.downscale_factor**2),
            height,
            width,
        )
        # Apply pixel shuffle, which trades channel depth for larger spatial dimensions
        x = self.pixel_shuffle(x)

        return x


# # Example of usage
# downscale_factor = (
# 2 # This factor needs to be determined based on the required reduction
# )
# model = PixelShuffleDownscale(downscale_factor)
# input_tensor = torch.rand(1, 256, 448, 448) # Example input tensor
# output_tensor = model(input_tensor)
# print(output_tensor.shape) # This will print the shape of the output tensor
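For reference, a short shape check (not part of this commit). nn.PixelShuffle with factor r keeps channels / r**2 channels and multiplies height and width by r, which matches the docstring example above:

import torch
from zeta.nn.modules.pixel_shuffling import PixelShuffleDownscale

# Sketch only: channels must be divisible by downscale_factor**2.
model = PixelShuffleDownscale(downscale_factor=2)
x = torch.rand(1, 256, 448, 448)
y = model(x)
assert y.shape == (1, 256 // 4, 448 * 2, 448 * 2)  # torch.Size([1, 64, 896, 896])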
