Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Add missing ops (im2col) #1757

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 129 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TFloat,
TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
TTensor,
)
Expand Down Expand Up @@ -658,16 +659,138 @@ def aten_huber_loss_backward(
raise NotImplementedError()


def _get_im2col_indices_along_dim(
    input_d: TInt,
    kernel_size_d: int,
    dilation_d: int,
    padding_d: int,
    stride_d: int,
):
    """Compute gather indices of all sliding blocks along one spatial dim.

    Input is always 4-D (N, C, H, W). For dimension d, block start
    positions range from 0 to input[d] + 2*padding[d] - dilation[d]*(kernel_size[d]-1),
    stepped by the stride.

    Returns a 2-D index tensor of shape (kernel_size_d, num_blocks_d).
    """
    # Exclusive upper bound for block starting positions along dim d:
    # padded extent minus the extra span covered by the dilated kernel.
    slide_limit = input_d + ((padding_d * 2) - (dilation_d * (kernel_size_d - 1)))

    # Starting index of each block, stepped by stride; shape (1, num_blocks_d).
    starts = op.Unsqueeze(op.Range(0, slide_limit, stride_d), [0])

    # Offsets of the dilated kernel taps along dim d; shape (kernel_size_d, 1).
    taps = op.Unsqueeze(op.Range(0, kernel_size_d * dilation_d, dilation_d), [1])

    # Broadcast-add kernel tap offsets onto the block starting positions
    # to obtain the full (kernel_size_d, num_blocks_d) index grid.
    return op.Add(starts, taps)


def _get_im2col_padded_input(input, padding_h, padding_w):
    """Pad a 4-D (N, C, H, W) tensor symmetrically along H and W.

    ONNX Pad expects (dim1_begin, dim2_begin, ..., dim1_end, dim2_end, ...),
    so the same (0, 0, padding_h, padding_w) run appears twice: once for the
    "begin" half and once for the "end" half.
    """
    half = op.Concat(
        op.Constant(value_ints=[0, 0]),
        op.Unsqueeze(padding_h, [0]),
        op.Unsqueeze(padding_w, [0]),
        axis=0,
    )
    onnx_pads = op.Concat(half, half, axis=0)
    return op.Pad(input, onnx_pads)


def _get_im2col_output_shape(input, kernel_h, kernel_w):
    """Build the (N, C * kernel_h * kernel_w, -1) shape for the final Reshape."""
    shape = op.Shape(input)
    batch = op.Gather(shape, 0, axis=0)
    channels = op.Gather(shape, 1, axis=0)
    # Each output channel corresponds to one (channel, kernel position) pair.
    unfolded_channels = op.Mul(channels, kernel_h * kernel_w)

    return op.Concat(
        op.Unsqueeze(batch, [0]),
        op.Unsqueeze(unfolded_channels, [0]),
        op.Constant(value_ints=[-1]),  # flatten all block positions into one dim
        axis=0,
    )


@torch_op("aten::im2col", trace_only=True)
def aten_im2col(
    self: TReal,
    kernel_size: Sequence[int],
    dilation: Sequence[int] = (1, 1),
    padding: Sequence[int] = (0, 0),
    stride: Sequence[int] = (1, 1),
) -> TensorType:
    """im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor"""

    input_shape = op.Shape(self)
    input_h = op.Gather(input_shape, 2, axis=0)
    input_w = op.Gather(input_shape, 3, axis=0)

    # Each parameter may be given as a single int; normalize to an (h, w) pair.
    # NOTE: stride previously used `isinstance(stride, int)`; checked against
    # Sequence here for consistency with the other parameters.
    if not isinstance(kernel_size, Sequence):
        kernel_size = (kernel_size, kernel_size)
    kernel_sizes = list(kernel_size)

    if not isinstance(dilation, Sequence):
        dilation = (dilation, dilation)
    dilations = list(dilation)

    if not isinstance(padding, Sequence):
        padding = (padding, padding)
    pads = list(padding)

    if not isinstance(stride, Sequence):
        stride = (stride, stride)
    strides = list(stride)

    stride_h, stride_w = strides[0], strides[1]
    padding_h, padding_w = pads[0], pads[1]
    dilation_h, dilation_w = dilations[0], dilations[1]
    kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1]

    blocks_row_indices = _get_im2col_indices_along_dim(
        input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    blocks_col_indices = _get_im2col_indices_along_dim(
        input_w, kernel_w, dilation_w, padding_w, stride_w
    )

    output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(self, padding_h, padding_w)

    # For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.,],
    #    [4., 5., 6.,],
    #    [7., 8., 9.,]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6.]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
    #  [[[1., 2., 4., 5.],
    #    [2., 3., 5., 6.],
    #    [4., 5., 7., 8.],
    #    [5., 6., 8., 9.]]]
    output = op.Gather(padded_input, blocks_row_indices, axis=2)
    output = op.Gather(output, blocks_col_indices, axis=4)
    output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5])
    return op.Reshape(output, output_shape)


def aten_infinitely_differentiable_gelu_backward(
Expand Down
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,35 @@ def _grid_sample_input_wrangler(
return args, kwargs


def _im2col_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Move kernel_size, dilation, padding and stride from args to kwargs
if len(args) == 5:
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
# Handle stride
stride = args.pop()
if isinstance(stride, np.ndarray): # convert stride to list[int]
stride = stride.tolist()
kwargs["stride"] = stride
# Handle padding
padding = args.pop()
if isinstance(padding, np.ndarray): # convert padding to list[int]
padding = padding.tolist()
kwargs["padding"] = padding
# Handle dilation
dilation = args.pop()
if isinstance(dilation, np.ndarray): # convert dilation to list[int]
dilation = dilation.tolist()
kwargs["dilation"] = dilation
# Handle kernel_size
kernel_size = args.pop()
if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int]
kernel_size = kernel_size.tolist()
kwargs["kernel_size"] = kernel_size

return args, kwargs


def _linalg_vector_norm_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -1895,6 +1924,15 @@ def _where_input_wrangler(
tolerance={torch.float16: (8e-2, 1e-4)},
),
TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
TorchLibOpInfo(
"nn.functional.unfold",
nn_ops.aten_im2col,
input_wrangler=_im2col_input_wrangler,
).xfail(
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
or not sample.input.shape,
reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
),
TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip(
# input: input, args: weight, bias; so len(args) == 2 means bias is provided
matcher=lambda sample: len(sample.args) != 1,
Expand Down
Loading