Commit fd16add
Kye committed Apr 26, 2024 (1 parent: 6a64734)
Showing 8 changed files with 193 additions and 1 deletion.
@@ -0,0 +1 @@
from zeta.experimental.triton.activations import *  # noqa
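If this line is the initializer of the zeta.experimental.triton package, the wildcard import re-exports the activations module's public names at the package level. A hypothetical consumer might then look like the sketch below; the activation name is an assumption for illustration, since the activations module itself is not shown in this diff:

# Hypothetical usage of the re-export. `tanh_activation` is an assumed
# name; substitute whatever the activations module actually defines.
import torch
from zeta.experimental.triton import tanh_activation

y = tanh_activation(torch.randn(8, device="cuda"))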
Empty file.
Empty file.
Empty file.
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn

# Triton is required at import time because @triton.jit decorates a
# module-level function below, so the import is not gated on CUDA
# availability; fail loudly rather than letting a NameError surface later.
try:
    import triton
    import triton.language as tl
except ImportError as e:
    raise ImportError(
        "Triton is not installed. Please install it using `pip install"
        " triton`."
    ) from e


@triton.jit
def linear_projection_kernel(
    X, W, Y, M, N, K, stride_x, stride_w, stride_y, BLOCK_SIZE: tl.constexpr
):
    # One program instance per output element: program_id(0) indexes the
    # row of X, program_id(1) indexes the row of W (an output feature)
    row_idx = tl.program_id(0)
    col_idx = tl.program_id(1)

    # Offsets for X, W, and Y
    x_off = row_idx * stride_x
    w_off = col_idx * stride_w
    y_off = row_idx * stride_y + col_idx

    # Scalar dot product over the K dimension. BLOCK_SIZE is accepted for
    # the launch signature but unused by this simple element-wise loop.
    acc = 0.0
    for k in range(K):
        acc += tl.load(X + x_off + k) * tl.load(W + w_off + k)
    tl.store(Y + y_off, acc)


class LinearTriton(nn.Module):
    """
    A custom linear module implemented using Triton.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If set to True, the module has a learnable
            bias. Default is True.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.randn(out_features))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Performs a forward pass through the linear module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_features).
        """
        # Prepare the output tensor
        output = torch.empty(
            x.shape[0], self.out_features, device=x.device, dtype=x.dtype
        )

        # One kernel instance per output element
        grid = (x.shape[0], self.out_features)
        block = 128  # Example block size

        # Launch the Triton kernel
        linear_projection_kernel[grid](
            x,
            self.weight,
            output,
            x.shape[0],
            self.out_features,
            self.in_features,
            x.stride(0),
            self.weight.stride(0),
            output.stride(0),
            block,
        )

        # Add bias if present
        if self.bias is not None:
            output += self.bias.unsqueeze(0)  # Broadcasting the bias
        return output


# # Example usage
# model = LinearTriton(128, 64).cuda()
# input_tensor = torch.randn(10, 128).cuda()  # 2D: (batch_size, in_features)
# output_tensor = model(input_tensor)
# print(output_tensor.shape)  # Should be torch.Size([10, 64])
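A quick way to sanity-check the kernel is to compare it against PyTorch's reference implementation. This is a minimal sketch, assuming a CUDA device with Triton installed; the tolerance value is an arbitrary choice:

# Correctness check against torch.nn.functional.linear, which computes
# x @ weight.T + bias -- the same contraction as the kernel's scalar loop.
import torch
import torch.nn.functional as F

model = LinearTriton(128, 64).cuda()
x = torch.randn(10, 128, device="cuda")

out_triton = model(x)
out_ref = F.linear(x, model.weight, model.bias)

assert out_triton.shape == (10, 64)
# Allow a small tolerance for floating-point accumulation differences.
assert torch.allclose(out_triton, out_ref, atol=1e-4)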
@@ -0,0 +1,70 @@
import torch
from torch import nn, Tensor


class PixelShuffleDownscale(nn.Module):
    def __init__(self, downscale_factor: int = 2):
        """
        Initializes a PixelShuffleDownscale module.

        Args:
            downscale_factor (int): The factor applied by pixel shuffle:
                channel depth shrinks by downscale_factor**2 while spatial
                dimensions grow by downscale_factor.

        Example:
            >>> downscale_factor = 2
            >>> model = PixelShuffleDownscale(downscale_factor)
            >>> input_tensor = torch.rand(1, 256, 448, 448)
            >>> output_tensor = model(input_tensor)
            >>> print(output_tensor.shape)
            torch.Size([1, 64, 896, 896])
        """
        super().__init__()
        self.downscale_factor = downscale_factor
        # nn.PixelShuffle's upscale factor is reused here as the factor by
        # which the channel dimension is reduced
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=downscale_factor)

    def forward(self, x: Tensor) -> Tensor:
        """
        Performs a forward pass of the PixelShuffleDownscale module.

        Args:
            x (torch.Tensor): The input tensor with shape
                [batch_size, channels, height, width].

        Returns:
            torch.Tensor: The output tensor of shape
                [batch_size, channels // downscale_factor**2,
                height * downscale_factor, width * downscale_factor].
        """
        # The channel count must be divisible by downscale_factor**2 so that
        # pixel shuffle can redistribute channels into spatial positions
        batch_size, channels, height, width = x.shape
        new_channels = channels // (self.downscale_factor**2)
        if new_channels * (self.downscale_factor**2) != channels:
            raise ValueError(
                "The number of channels must be divisible by"
                " (downscale_factor^2)"
            )

        # Regroup the channel dimension before applying pixel shuffle
        x = x.reshape(
            batch_size, new_channels, self.downscale_factor**2, height, width
        )
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(
            batch_size,
            new_channels * (self.downscale_factor**2),
            height,
            width,
        )

        # Pixel shuffle trades channel depth for spatial resolution
        x = self.pixel_shuffle(x)

        return x


# # Example of usage
# downscale_factor = (
#     2  # This factor needs to be determined based on the required reduction
# )
# model = PixelShuffleDownscale(downscale_factor)
# input_tensor = torch.rand(1, 256, 448, 448)  # Example input tensor
# output_tensor = model(input_tensor)
# print(output_tensor.shape)  # This will print the shape of the output tensor
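For reference, the shape arithmetic can be verified directly, and PyTorch's built-in nn.PixelUnshuffle performs the inverse trade. A minimal sketch (the PixelUnshuffle round trip is an aside for comparison, not part of this commit):

# nn.PixelShuffle semantics: channels shrink by r**2, spatial dims grow by r.
import torch

model = PixelShuffleDownscale(downscale_factor=2)
x = torch.rand(1, 256, 448, 448)
y = model(x)
assert y.shape == torch.Size([1, 64, 896, 896])

# nn.PixelUnshuffle makes the opposite trade: spatial dims shrink by r while
# channel depth grows by r**2, restoring the original shape here.
z = torch.nn.PixelUnshuffle(downscale_factor=2)(y)
assert z.shape == torch.Size([1, 256, 448, 448])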