add support for avgpool_3d op
kamalrajkannan78 committed Nov 15, 2024
1 parent 4b24864 commit dd7f329
Showing 9 changed files with 557 additions and 20 deletions.
4 changes: 2 additions & 2 deletions forge/forge/op/__init__.py
@@ -4,8 +4,8 @@

from .matmul import Matmul, SparseMatmul

-from .convolution import Conv2d, Conv2dTranspose, Conv3d
-from .pooling import MaxPool1d, MaxPool2d, MaxPool3d, AvgPool1d, AvgPool2d
from .convolution import Conv2d, Conv2dTranspose, Conv3d
from .pooling import MaxPool1d, MaxPool2d, MaxPool3d, AvgPool1d, AvgPool2d, AvgPool3d
from .eltwise_binary import (
Add,
Divide,
36 changes: 24 additions & 12 deletions forge/forge/op/convolution.py
@@ -155,7 +155,7 @@ def Conv3d(
bias: Optional[Union[Tensor, Parameter]] = None,
stride: int = 1,
padding: Union[int, str, List] = "same",
-    dilation: int = 1,
    dilation: Union[int, List] = 1,
groups: int = 1,
channel_last: bool = False,
) -> Tensor:
@@ -168,34 +168,46 @@ def Conv3d(
Op name, unique to the module, or leave blank to autoset
activations: Tensor
-        Input activations of shape (N, Cin, Din, iH, iW)
        Input activations of shape (N, Cin, iD, iH, iW) if channel_last=False,
        or (N, iD, iH, iW, Cin) if channel_last=True
weights:
Tensor
Input weights of shape (Cout, Cin / groups, kD, kH, kW)
    [Tensor]
        Internal use: pre-split weights.
        Optional list of weight tensors of shape [(weight_grouping, Cin / groups, Cout)],
        of length (K*K // weight_grouping)
-    bias: Tenor, optional
    bias: Tensor, optional
Optional bias tensor of shape (Cout)
"""
assert not channel_last, "Decomposition for channel-last Conv3d is not added yet"

# Ensure stride, dilation, and padding are in the correct format
if isinstance(stride, int):
stride = [stride] * 3
if isinstance(dilation, int):
dilation = [dilation] * 3

-    padding = conv3d_padding_to_canonical(padding, (weights.shape[2], weights.shape[3], weights.shape[4]))
    # Adjust padding to handle 3D dimensions; the kernel dims of the (Cout, Cin/groups, kD, kH, kW) weights are shape[2:5]
    padding = conv3d_padding_to_canonical(padding, (weights.shape[2], weights.shape[3], weights.shape[4]))

# Assemble inputs list
inputs = [activations, weights]
if bias is not None:
inputs.append(bias)

-    attrs = stride + [dilation, groups] + padding + [channel_last]
return op(
"conv3d",
name,
*inputs,
-        attrs=attrs,
stride_depth=stride[0],
stride_height=stride[1],
stride_width=stride[2],
dilation_depth=dilation[0],
dilation_height=dilation[1],
dilation_width=dilation[2],
groups=groups,
padding_front=padding[0],
padding_back=padding[1],
padding_top=padding[2],
padding_bottom=padding[3],
padding_left=padding[4],
padding_right=padding[5],
channel_last=channel_last,
).get_tensor()
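
For reference, a minimal usage sketch of the updated op (the shapes are made up, and forge.Tensor.create_from_torch is assumed to be the tensor constructor; treat this as a sketch, not the canonical API):

import torch
import forge

# Hypothetical shapes: N=1, Cin=3, D=H=W=8, Cout=16, 3x3x3 kernel
act = forge.Tensor.create_from_torch(torch.randn(1, 3, 8, 8, 8))
wgt = forge.Tensor.create_from_torch(torch.randn(16, 3, 3, 3, 3))
out = forge.op.Conv3d("conv3d_0", act, wgt, stride=1, padding=0, dilation=1)
print(out.shape)  # expected (1, 16, 6, 6, 6): 1 + (8 - 2 - 1) / 1 = 6 per spatial dim
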
5 changes: 3 additions & 2 deletions forge/forge/op/eval/forge/__init__.py
@@ -19,7 +19,7 @@
from .clip import Clip
from .cumulativesum import CumulativeSum
from .argmax import Argmax
-from .convolution import Conv2d
from .convolution import Conv2d, Conv3d
from .convolution import Conv2dTranspose
from .pooling import MaxPool2d
from .cast import Cast
@@ -110,12 +110,13 @@
"grouped_reduce_avg": "reduce",
"conv2d": Conv2d,
"conv2d_transpose": Conv2dTranspose,
"conv3d": "convolution",
"conv3d": Conv3d,
"max_pool1d": "pooling",
"max_pool2d": MaxPool2d,
"max_pool3d": "pooling",
"avg_pool1d": "pooling",
"avg_pool2d": "pooling",
"avg_pool3d": "pooling",
"constant": "constant",
"resize2d": "resize",
"resize3d": "resize",
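
The new "avg_pool3d" entry routes to the shared "pooling" eval path. Whatever that path does must match PyTorch's reference semantics; a torch-only sketch of the expected behavior (arbitrary shapes):

import torch
import torch.nn.functional as F

# Reference semantics the new avg_pool3d op is expected to reproduce
x = torch.randn(1, 3, 8, 8, 8)                # (N, C, D, H, W)
y = F.avg_pool3d(x, kernel_size=2, stride=2)  # average over 2x2x2 windows
print(y.shape)                                # torch.Size([1, 3, 4, 4, 4])
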
191 changes: 191 additions & 0 deletions forge/forge/op/eval/forge/convolution.py
@@ -372,3 +372,194 @@ def is_eltwise_unary(self) -> bool:

def is_eltwise_nary(self) -> bool:
return False


class Conv3d(PyOp):
@classmethod
def create(
cls,
stride_depth,
stride_height,
stride_width,
dilation_depth,
dilation_height,
dilation_width,
groups,
padding_front,
padding_back,
padding_left,
padding_right,
padding_top,
padding_bottom,
channel_last,
):
self = cls("conv3d")
self.stride_depth = stride_depth
self.stride_height = stride_height
self.stride_width = stride_width
self.dilation_depth = dilation_depth
self.dilation_height = dilation_height
self.dilation_width = dilation_width
self.groups = groups
self.padding_front = padding_front
self.padding_back = padding_back
self.padding_left = padding_left
self.padding_right = padding_right
self.padding_top = padding_top
self.padding_bottom = padding_bottom
self.channel_last = int(channel_last)
return self

def eval(self, tensors):
assert len(tensors) <= 3, "Conv ops should have up to three inputs (input, weight, bias)"
assert len(tensors) >= 2, "Conv ops should have at least two inputs (input, weight)"
t_ops = to_torch_operands(*tensors)

activations = t_ops[0]
weights = t_ops[1]
bias = t_ops[2] if len(t_ops) == 3 else None

stride = [self.stride_depth, self.stride_height, self.stride_width]
dilation = [self.dilation_depth, self.dilation_height, self.dilation_width]
groups = self.groups
# F.pad pads the last dimension first, so the order must be (left, right, top, bottom, front, back)
padding = [
    self.padding_left,
    self.padding_right,
    self.padding_top,
    self.padding_bottom,
    self.padding_front,
    self.padding_back,
]

if self.channel_last:
activations = activations.permute((0, 4, 1, 2, 3))

padded_activations = torch.nn.functional.pad(
activations,
padding,
)

result = torch.nn.functional.conv3d(
padded_activations,
weights,
bias=bias,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
)

if self.channel_last:
result = result.permute((0, 2, 3, 4, 1))

result = result.to(activations.dtype)
return result
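
A quick torch-only check of the two conventions eval() relies on (a sketch with arbitrary shapes): F.pad consumes pads last-dim-first, and the two channel-last permutes round-trip exactly.

import torch
import torch.nn.functional as F

# F.pad pads the last dim first: (left, right, top, bottom, front, back) for 5D input
x = torch.randn(1, 3, 4, 4, 4)          # (N, C, D, H, W)
padded = F.pad(x, [3, 3, 2, 2, 1, 1])   # left/right=3, top/bottom=2, front/back=1
print(padded.shape)                     # torch.Size([1, 3, 6, 8, 10])

# The channel-last permutes used in eval() are inverses of each other
y = torch.randn(2, 5, 6, 7, 3)          # (N, D, H, W, C)
assert torch.equal(y, y.permute(0, 4, 1, 2, 3).permute(0, 2, 3, 4, 1))
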

def shape(self, tensor_shapes):

act, weight = tensor_shapes[:2]
batch_size = act[0]
cout = weight[0]

d_in = act[-4] if self.channel_last else act[-3]
h_in = act[-3] if self.channel_last else act[-2]
w_in = act[-2] if self.channel_last else act[-1]

d_numerator = d_in + (self.padding_front + self.padding_back) - self.dilation_depth * (weight[-3] - 1) - 1
d_out = math.floor(1 + (d_numerator / self.stride_depth))

h_numerator = h_in + (self.padding_top + self.padding_bottom) - self.dilation_height * (weight[-2] - 1) - 1
h_out = math.floor(1 + (h_numerator / self.stride_height))

w_numerator = w_in + (self.padding_left + self.padding_right) - self.dilation_width * (weight[-1] - 1) - 1
w_out = math.floor(1 + (w_numerator / self.stride_width))

out_shape = (
[batch_size, d_out, h_out, w_out, cout] if self.channel_last else [batch_size, cout, d_out, h_out, w_out]
)

return out_shape, []
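
A quick arithmetic check of the output-size formula above (values are arbitrary): with d_in=8, symmetric padding of 1, dilation 1, kernel depth 3, and stride 2, d_out = floor(1 + (8 + 2 - 2 - 1) / 2) = floor(4.5) = 4, matching torch.nn.Conv3d's documented shape rule.

import math

d_in, pad_front, pad_back, dilation, kernel, stride = 8, 1, 1, 1, 3, 2
d_out = math.floor(1 + (d_in + pad_front + pad_back - dilation * (kernel - 1) - 1) / stride)
assert d_out == 4
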

def decompose(self, dc, inputs):
# The conv3d op is not yet supported in TTNN; following the conv2d reference, the transformations below are applied.
# TTNN can only perform a channel-last convolution with its conv3d op.
# The TTNN conv3d requires the input to be in the shape: (N, D, H, W, C) or (1, 1, N*D*H*W, C).
# It requires the weight to be in the shape: (C_out, C_in, kernel_depth, kernel_height, kernel_width).
# It requires the bias to be in the shape: (1, 1, 1, 1, C_out).
#
# If the forge conv3d op is channel-first, we must permute the input (N, C, D, H, W) tensor to (N, D, H, W, C)
# and then transpose it back to (N, C_out, D_out, H_out, W_out) afterward.
# - This is done with three transposes
# - (N, C, D, H, W) --> transpose(-4, -3): (N, D, C, H, W) --> transpose(-3, -2): (N, D, H, C, W)
# --> transpose(-2, -1): (N, D, H, W, C)
# Afterward:
# - (N, D_out, H_out, W_out, C_out) --> transpose(-3, -2): (N, D_out, H_out, C_out, W_out)
# --> transpose(-4, -3): (N, C_out, D_out, H_out, W_out)

activations = inputs[0]
weight = inputs[1]
bias = inputs[2] if len(inputs) == 3 else None

is_channel_last = self.channel_last

if bias is not None and len(bias.shape) < len(activations.shape):
while len(bias.shape) < len(activations.shape):
bias = dc.op("unsqueeze", [bias], (0, len(bias.shape)))
is_bias_unchanged = bias is None or bias == inputs[2]

if not is_channel_last:
activations = dc.op(TransposeTM.create(dim0=-4, dim1=-3), [activations])
activations = dc.op(TransposeTM.create(dim0=-3, dim1=-2), [activations])
activations = dc.op(TransposeTM.create(dim0=-2, dim1=-1), [activations])

# Only re-create the Conv3d op if something has changed. Otherwise the compiler will infinitely
# decompose the same Conv3d over and over.
if not is_bias_unchanged or not is_channel_last:

new_inputs = [activations, weight] if bias is None else [activations, weight, bias]
result = dc.op(
Conv3d.create(
self.stride_depth,
self.stride_height,
self.stride_width,
self.dilation_depth,
self.dilation_height,
self.dilation_width,
self.groups,
self.padding_front,
self.padding_back,
self.padding_left,
self.padding_right,
self.padding_top,
self.padding_bottom,
True,
),
new_inputs,
)

if not is_channel_last:
result = dc.op(TransposeTM.create(dim0=-1, dim1=-2), [result])
result = dc.op(TransposeTM.create(dim0=-2, dim1=-3), [result])
result = dc.op(TransposeTM.create(dim0=-3, dim1=-4), [result])
dc.fuse(result)
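
The three-transpose chains in decompose() compose to the single permutes used in eval(); a torch-only check (arbitrary shapes):

import torch

x = torch.randn(2, 3, 4, 5, 6)  # (N, C, D, H, W)
to_last = x.transpose(-4, -3).transpose(-3, -2).transpose(-2, -1)
assert torch.equal(to_last, x.permute(0, 2, 3, 4, 1))  # (N, D, H, W, C)
# The reverse chain restores channel-first
back = to_last.transpose(-1, -2).transpose(-2, -3).transpose(-3, -4)
assert torch.equal(back, x)
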

def backward(self, ac, operand, inputs, output, grad):
pass

def lower(self, lc, tensors, outputs):
pass

def is_tm(self) -> bool:
return False

def is_eltwise(self) -> bool:
return False

def is_eltwise_binary(self) -> bool:
return False

def is_eltwise_unary(self) -> bool:
return False

def is_eltwise_nary(self) -> bool:
return False