[torchlib] Add missing ops (im2col) (#1757)
Co-authored-by: Justin Chu <[email protected]>
shubhambhokare1 and justinchuby authored Aug 5, 2024
1 parent 14f88d3 commit 47ecc6c
Showing 2 changed files with 167 additions and 6 deletions.
135 changes: 129 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
@@ -27,6 +27,7 @@
    TFloat,
    TFloatOrBFloat16,
    TFloatOrUInt8,
    TInt,
    TReal,
    TTensor,
)
@@ -658,16 +659,138 @@ def aten_huber_loss_backward(
    raise NotImplementedError()


def _get_im2col_indices_along_dim(
    input_d: TInt,
    kernel_size_d: int,
    dilation_d: int,
    padding_d: int,
    stride_d: int,
):
    # Input is always 4-D (N, C, H, W)
    # Calculate indices of sliding blocks along spatial dimension d.
    # The kernel slides over the input along each dimension d; its valid starting
    # positions range from 0 to input[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1),
    # in steps of stride[d].

    blocks_d = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1)))

    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = op.Range(0, blocks_d, stride_d)
    blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0])

    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d)
    kernel_mask = op.Unsqueeze(kernel_grid, [1])

    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    block_mask = op.Add(blocks_d_indices, kernel_mask)

    return block_mask
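
For intuition, here is a minimal NumPy sketch (plain arrays rather than ONNX ops) of the index grid this helper builds, assuming input_d=3, kernel_size_d=2, dilation_d=1, padding_d=0, and stride_d=1 to match the 3x3 worked example further below:

import numpy as np

# Assumed toy values; the real helper operates on ONNX tensors via op.Range/op.Unsqueeze/op.Add.
input_d, kernel_size_d, dilation_d, padding_d, stride_d = 3, 2, 1, 0, 1
blocks_d = input_d + 2 * padding_d - dilation_d * (kernel_size_d - 1)        # 2
blocks_d_indices = np.arange(0, blocks_d, stride_d)[None, :]                 # [[0, 1]]
kernel_grid = np.arange(0, kernel_size_d * dilation_d, dilation_d)[:, None]  # [[0], [1]]
block_mask = blocks_d_indices + kernel_grid
print(block_mask)  # [[0 1]
                   #  [1 2]]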


def _get_im2col_padded_input(input, padding_h, padding_w):
    # Input is always a 4-D tensor (N, C, H, W)
    # The padding tensor has the following format: (padding_h, padding_w)
    # Reshape the padding to follow the ONNX format: (dim1_begin, dim2_begin, ..., dim1_end, dim2_end, ...)
    pad = op.Concat(
        op.Constant(value_ints=[0, 0]),
        op.Unsqueeze(padding_h, [0]),
        op.Unsqueeze(padding_w, [0]),
        op.Constant(value_ints=[0, 0]),
        op.Unsqueeze(padding_h, [0]),
        op.Unsqueeze(padding_w, [0]),
        axis=0,
    )
    return op.Pad(input, pad)
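
The pads vector built above follows the ONNX Pad layout (x1_begin, x2_begin, ..., x1_end, x2_end, ...) over the four (N, C, H, W) axes. A small sketch of the resulting vector, assuming padding_h=1 and padding_w=2:

import numpy as np

# Assumed example values; the helper above builds the same layout with op.Concat/op.Unsqueeze.
padding_h, padding_w = 1, 2
pads = np.array([0, 0, padding_h, padding_w, 0, 0, padding_h, padding_w])
print(pads)  # [0 0 1 2 0 0 1 2] -> only H and W are padded, on both sides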


def _get_im2col_output_shape(input, kernel_h, kernel_w):
    input_shape = op.Shape(input)
    batch_dim = op.Gather(input_shape, 0, axis=0)
    channel_dim = op.Gather(input_shape, 1, axis=0)
    channel_unfolded = op.Mul(channel_dim, kernel_h * kernel_w)

    return op.Concat(
        op.Unsqueeze(batch_dim, [0]),
        op.Unsqueeze(channel_unfolded, [0]),
        op.Constant(value_ints=[-1]),
        axis=0,
    )
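
The shape emitted here is (N, C * kernel_h * kernel_w, -1); op.Reshape later infers the last dimension, which equals the number of sliding blocks. A hedged sketch of the expected final shape, assuming the standard sliding-block count formula:

# Hypothetical helper (not part of the commit): the expected unfold output shape,
# using the block-count formula that op.Reshape infers as the -1 dimension.
def expected_unfold_shape(n, c, h, w, kernel, dilation=(1, 1), padding=(0, 0), stride=(1, 1)):
    def num_blocks(size, k, d, p, s):
        return (size + 2 * p - d * (k - 1) - 1) // s + 1

    length = num_blocks(h, kernel[0], dilation[0], padding[0], stride[0]) * num_blocks(
        w, kernel[1], dilation[1], padding[1], stride[1]
    )
    return (n, c * kernel[0] * kernel[1], length)

print(expected_unfold_shape(1, 1, 3, 3, (2, 2)))  # (1, 4, 4)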


@torch_op("aten::im2col", trace_only=True)
def aten_im2col(
self: TensorType,
self: TReal,
kernel_size: Sequence[int],
dilation: Sequence[int],
padding: Sequence[int],
stride: Sequence[int],
dilation: Sequence[int] = (1, 1),
padding: Sequence[int] = (0, 0),
stride: Sequence[int] = (1, 1),
) -> TensorType:
"""im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor"""
"""im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor"""

raise NotImplementedError()
    input_shape = op.Shape(self)
    input_h = op.Gather(input_shape, 2, axis=0)
    input_w = op.Gather(input_shape, 3, axis=0)

    if not isinstance(kernel_size, Sequence):
        kernel_size = (kernel_size, kernel_size)
    kernel_sizes = list(kernel_size)

    if not isinstance(dilation, Sequence):
        dilation = (dilation, dilation)
    dilations = list(dilation)

    if not isinstance(padding, Sequence):
        padding = (padding, padding)
    pads = list(padding)

    if isinstance(stride, int):
        stride = (stride, stride)
    strides = list(stride)

    stride_h, stride_w = strides[0], strides[1]
    padding_h, padding_w = pads[0], pads[1]
    dilation_h, dilation_w = dilations[0], dilations[1]
    kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1]

    blocks_row_indices = _get_im2col_indices_along_dim(
        input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    blocks_col_indices = _get_im2col_indices_along_dim(
        input_w, kernel_w, dilation_w, padding_w, stride_w
    )

    output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(self, padding_h, padding_w)

    # For a 4-D input of shape (1, 1, 3, 3) as below, with kernel_size=2, stride=1, and dilation=1:
    # [[[[1., 2., 3.],
    #    [4., 5., 6.],
    #    [7., 8., 9.]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6.]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to the output shape (1, 4, 4) to get:
    # [[[1., 2., 4., 5.],
    #   [2., 3., 5., 6.],
    #   [4., 5., 7., 8.],
    #   [5., 6., 8., 9.]]]
    output = op.Gather(padded_input, blocks_row_indices, axis=2)
    output = op.Gather(output, blocks_col_indices, axis=4)
    output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5])
    return op.Reshape(output, output_shape)


def aten_infinitely_differentiable_gelu_backward(
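
As a cross-check of the numbers used in the comments above, torch.nn.functional.unfold (which dispatches to aten::im2col) produces exactly the (1, 4, 4) result shown there; a minimal PyTorch run, assuming a local torch install:

import torch

x = torch.arange(1.0, 10.0).reshape(1, 1, 3, 3)
cols = torch.nn.functional.unfold(x, kernel_size=(2, 2), dilation=1, padding=0, stride=1)
print(cols.shape)  # torch.Size([1, 4, 4])
print(cols)
# tensor([[[1., 2., 4., 5.],
#          [2., 3., 5., 6.],
#          [4., 5., 7., 8.],
#          [5., 6., 8., 9.]]])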
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -283,6 +283,35 @@ def _grid_sample_input_wrangler(
    return args, kwargs


def _im2col_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
    # Move kernel_size, dilation, padding and stride from args to kwargs
    if len(args) == 5:
        # Handle stride
        stride = args.pop()
        if isinstance(stride, np.ndarray):  # convert stride to list[int]
            stride = stride.tolist()
        kwargs["stride"] = stride
        # Handle padding
        padding = args.pop()
        if isinstance(padding, np.ndarray):  # convert padding to list[int]
            padding = padding.tolist()
        kwargs["padding"] = padding
        # Handle dilation
        dilation = args.pop()
        if isinstance(dilation, np.ndarray):  # convert dilation to list[int]
            dilation = dilation.tolist()
        kwargs["dilation"] = dilation
        # Handle kernel_size
        kernel_size = args.pop()
        if isinstance(kernel_size, np.ndarray):  # convert kernel_size to list[int]
            kernel_size = kernel_size.tolist()
        kwargs["kernel_size"] = kernel_size

    return args, kwargs
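
A hedged illustration of the wrangler's effect on an OpInfo-style sample, assuming the five positional arguments arrive as (input, kernel_size, dilation, padding, stride) with the trailing four as NumPy arrays:

import numpy as np

# Hypothetical sample; the OpInfo test machinery supplies the real one.
args = [np.zeros((1, 1, 3, 3)), np.array([2, 2]), np.array([1, 1]), np.array([0, 0]), np.array([1, 1])]
args, kwargs = _im2col_input_wrangler(args, {})
# args   -> [the input tensor only]
# kwargs -> {"stride": [1, 1], "padding": [0, 0], "dilation": [1, 1], "kernel_size": [2, 2]}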


def _linalg_vector_norm_input_wrangler(
    args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
@@ -1895,6 +1924,15 @@ def _where_input_wrangler(
        tolerance={torch.float16: (8e-2, 1e-4)},
    ),
    TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
    TorchLibOpInfo(
        "nn.functional.unfold",
        nn_ops.aten_im2col,
        input_wrangler=_im2col_input_wrangler,
    ).xfail(
        matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
        or not sample.input.shape,
        reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
    ),
    TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip(
        # input: input, args: weight, bias; so len(args) == 2 means bias is provided
        matcher=lambda sample: len(sample.args) != 1,