From 47ecc6cec0f518d4ebb7a2b11a8b98504eaaad5c Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Mon, 5 Aug 2024 13:14:30 -0700 Subject: [PATCH] [torchlib] Add missing ops (im2col) (#1757) Co-authored-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/nn.py | 135 +++++++++++++++++- .../function_libs/torch_lib/ops_test_data.py | 38 +++++ 2 files changed, 167 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 943390213..84f75b1a4 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -27,6 +27,7 @@ TFloat, TFloatOrBFloat16, TFloatOrUInt8, + TInt, TReal, TTensor, ) @@ -658,16 +659,138 @@ def aten_huber_loss_backward( raise NotImplementedError() +def _get_im2col_indices_along_dim( + input_d: TInt, + kernel_size_d: int, + dilation_d: int, + padding_d: int, + stride_d: int, +): + # Input is always 4-D (N, C, H, W) + # Calculate indices of sliding blocks along spatial dimension + # Slide kernel over input each dim d: + # each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1) + # with steps = stride + + blocks_d = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1))) + + # Stride kernel over input and find starting indices along dim d + blocks_d_indices = op.Range(0, blocks_d, stride_d) + blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0]) + + # Apply dilation on kernel and find its indices along dim d + kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d) + kernel_mask = op.Unsqueeze(kernel_grid, [1]) + + # Broadcast and add kernel starting positions (indices) with + # kernel_grid along dim d, to get block indices along dim d + block_mask = op.Add(blocks_d_indices, kernel_mask) + + return block_mask + + +def _get_im2col_padded_input(input, padding_h, padding_w): + # Input is always 4-D tensor (N, C, H, W) + # 
Padding tensor has the following format: (padding_h, padding_w) + # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...) + pad = op.Concat( + op.Constant(value_ints=[0, 0]), + op.Unsqueeze(padding_h, [0]), + op.Unsqueeze(padding_w, [0]), + op.Constant(value_ints=[0, 0]), + op.Unsqueeze(padding_h, [0]), + op.Unsqueeze(padding_w, [0]), + axis=0, + ) + return op.Pad(input, pad) + + +def _get_im2col_output_shape(input, kernel_h, kernel_w): + input_shape = op.Shape(input) + batch_dim = op.Gather(input_shape, 0, axis=0) + channel_dim = op.Gather(input_shape, 1, axis=0) + channel_unfolded = op.Mul(channel_dim, kernel_h * kernel_w) + + return op.Concat( + op.Unsqueeze(batch_dim, [0]), + op.Unsqueeze(channel_unfolded, [0]), + op.Constant(value_ints=[-1]), + axis=0, + ) + + +@torch_op("aten::im2col", trace_only=True) def aten_im2col( - self: TensorType, + self: TReal, kernel_size: Sequence[int], - dilation: Sequence[int], - padding: Sequence[int], - stride: Sequence[int], + dilation: Sequence[int] = (1, 1), + padding: Sequence[int] = (0, 0), + stride: Sequence[int] = (1, 1), ) -> TensorType: - """im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor""" + """im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor""" - raise NotImplementedError() + input_shape = op.Shape(self) + input_h = op.Gather(input_shape, 2, axis=0) + input_w = op.Gather(input_shape, 3, axis=0) + + if not isinstance(kernel_size, Sequence): + kernel_size = (kernel_size, kernel_size) + kernel_sizes = list(kernel_size) + + if not isinstance(dilation, Sequence): + dilation = (dilation, dilation) + dilations = list(dilation) + + if not isinstance(padding, Sequence): + padding = (padding, padding) + pads = list(padding) + + if isinstance(stride, int): + stride = (stride, stride) + strides = list(stride) + + stride_h, stride_w = strides[0], strides[1] + padding_h, padding_w = 
pads[0], pads[1] + dilation_h, dilation_w = dilations[0], dilations[1] + kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1] + + blocks_row_indices = _get_im2col_indices_along_dim( + input_h, kernel_h, dilation_h, padding_h, stride_h + ) + blocks_col_indices = _get_im2col_indices_along_dim( + input_w, kernel_w, dilation_w, padding_w, stride_w + ) + + output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w) + padded_input = _get_im2col_padded_input(self, padding_h, padding_w) + + # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1 + # [[[[1., 2., 3.,], + # [4., 5., 6.,], + # [7., 8., 9.,]]]] + # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get: + # [[[[[1., 2., 3.], + # [4., 5., 6.]], + # [[4., 5., 6.], + # [7., 8., 9.]]]]] + # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get: + # [[[[[[1., 2.], + # [4., 5.]], + # [[2., 3.], + # [5., 6]]], + # [[[4., 5.], + # [7., 8.]], + # [[5., 6.], + # [8., 9.]]]]]] + # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 4, 4) to get: + # [[[1., 2., 4., 5.], + # [2., 3., 5., 6.], + # [4., 5., 7., 8.], + # [5., 6., 8., 9.]]] + output = op.Gather(padded_input, blocks_row_indices, axis=2) + output = op.Gather(output, blocks_col_indices, axis=4) + output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5]) + return op.Reshape(output, output_shape) def aten_infinitely_differentiable_gelu_backward( diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 9546adaa4..e0c588297 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -283,6 +283,35 @@ def _grid_sample_input_wrangler( return args, kwargs +def _im2col_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Move kernel_size, dilation, padding and stride from args to 
kwargs + if len(args) == 5: + # Handle stride + stride = args.pop() + if isinstance(stride, np.ndarray): # convert stride to list[int] + stride = stride.tolist() + kwargs["stride"] = stride + # Handle padding + padding = args.pop() + if isinstance(padding, np.ndarray): # convert padding to list[int] + padding = padding.tolist() + kwargs["padding"] = padding + # Handle dilation + dilation = args.pop() + if isinstance(dilation, np.ndarray): # convert dilation to list[int] + dilation = dilation.tolist() + kwargs["dilation"] = dilation + # Handle kernel_size + kernel_size = args.pop() + if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int] + kernel_size = kernel_size.tolist() + kwargs["kernel_size"] = kernel_size + + return args, kwargs + + def _linalg_vector_norm_input_wrangler( args: list[Any], kwargs: dict[str, Any] ) -> tuple[list[Any], dict[str, Any]]: @@ -1895,6 +1924,15 @@ def _where_input_wrangler( tolerance={torch.float16: (8e-2, 1e-4)}, ), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), + TorchLibOpInfo( + "nn.functional.unfold", + nn_ops.aten_im2col, + input_wrangler=_im2col_input_wrangler, + ).xfail( + matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) + or not sample.input.shape, + reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", + ), TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip( # input: input, args: weight, bias; so len(args) == 2 means bias is provided matcher=lambda sample: len(sample.args) != 1,