Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Add missing ops (im2col) #1757

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 129 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TFloat,
TFloatOrBFloat16,
TFloatOrUInt8,
TInt,
TReal,
TTensor,
)
Expand Down Expand Up @@ -658,16 +659,138 @@
raise NotImplementedError()


def _get_im2col_indices_along_dim(
input_d: TInt,
kernel_size_d: int,
dilation_d: int,
padding_d: int,
stride_d: int,
):
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
# Input is always 4-D (N, C, H, W)
# Calculate indices of sliding blocks along spatial dimension
# Slide kernel over input each dim d:
# each dimension d ranges from 0 to input[d]+2xpadding[d]-dilation[d]x(kernel_size[d]-1)
# with steps = stride

blocks_d = input_d + (padding_d * 2) - (dilation_d * (kernel_size_d - 1))

Check warning on line 675 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L675

Added line #L675 was not covered by tests
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved

# Stride kernel over input and find starting indices along dim d
blocks_d_indices = op.Range(0, blocks_d, stride_d)
blocks_d_indices = op.Unsqueeze(blocks_d_indices, [0])

Check warning on line 679 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L678-L679

Added lines #L678 - L679 were not covered by tests

# Apply dilation on kernel and find its indices along dim d
kernel_grid = op.Range(0, kernel_size_d * dilation_d, dilation_d)
kernel_mask = op.Unsqueeze(kernel_grid, [1])

Check warning on line 683 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L682-L683

Added lines #L682 - L683 were not covered by tests

# Broadcast and add kernel staring positions (indices) with
# kernel_grid along dim d, to get block indices along dim d
block_mask = op.Add(blocks_d_indices, kernel_mask)

Check warning on line 687 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L687

Added line #L687 was not covered by tests

return block_mask

Check warning on line 689 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L689

Added line #L689 was not covered by tests


def _get_im2col_padded_input(input, padding_h, padding_w):
# Input is always 4-D tensor (N, C, H, W)
# Padding tensor has the following format: (padding_h, padding_w)
# Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin,...,dim1_end, dim2_end,...)
pad = op.Concat(

Check warning on line 696 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L696

Added line #L696 was not covered by tests
op.Constant(value_ints=[0, 0]),
op.Unsqueeze(padding_h, [0]),
op.Unsqueeze(padding_w, [0]),
op.Constant(value_ints=[0, 0]),
op.Unsqueeze(padding_h, [0]),
op.Unsqueeze(padding_w, [0]),
axis=0,
)
return op.Pad(input, pad)

Check warning on line 705 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L705

Added line #L705 was not covered by tests


def _get_im2col_output_shape(input, kernel_h, kernel_w):
input_shape = op.Shape(input)
batch_dim = op.Gather(input_shape, 0, axis=0)
channel_dim = op.Gather(input_shape, 1, axis=0)
channel_unfolded = op.Mul(channel_dim, kernel_h * kernel_w)

Check warning on line 712 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L709-L712

Added lines #L709 - L712 were not covered by tests

return op.Concat(

Check warning on line 714 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L714

Added line #L714 was not covered by tests
op.Unsqueeze(batch_dim, [0]),
op.Unsqueeze(channel_unfolded, [0]),
op.Constant(value_ints=[-1]),
axis=0,
)


@torch_op("aten::im2col", trace_only=True)
def aten_im2col(
self: TensorType,
self: TReal,
kernel_size: Sequence[int],
dilation: Sequence[int],
padding: Sequence[int],
stride: Sequence[int],
dilation: Sequence[int] = (1, 1),
padding: Sequence[int] = (0, 0),
stride: Sequence[int] = (1, 1),
) -> TensorType:
"""im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor"""
"""im2col(Tensor self, int[2] kernel_size, int[2] dilation=1, int[2] padding=0, int[2] stride=1) -> Tensor"""

raise NotImplementedError()
input_shape = op.Shape(self)
input_h = op.Gather(input_shape, 2, axis=0)
input_w = op.Gather(input_shape, 3, axis=0)

Check warning on line 734 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L732-L734

Added lines #L732 - L734 were not covered by tests

if not isinstance(kernel_size, Sequence):
kernel_size = (kernel_size, kernel_size)
kernel_sizes = list(kernel_size)

Check warning on line 738 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L737-L738

Added lines #L737 - L738 were not covered by tests

if not isinstance(dilation, Sequence):
dilation = (dilation, dilation)
dilations = list(dilation)

Check warning on line 742 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L741-L742

Added lines #L741 - L742 were not covered by tests

if not isinstance(padding, Sequence):
padding = (padding, padding)
pads = list(padding)

Check warning on line 746 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L745-L746

Added lines #L745 - L746 were not covered by tests

if isinstance(stride, int):
stride = (stride, stride)
strides = list(stride)

Check warning on line 750 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L749-L750

Added lines #L749 - L750 were not covered by tests
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

stride_h, stride_w = strides[0], strides[1]
padding_h, padding_w = pads[0], pads[1]
dilation_h, dilation_w = dilations[0], dilations[1]
kernel_h, kernel_w = kernel_sizes[0], kernel_sizes[1]

Check warning on line 755 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L752-L755

Added lines #L752 - L755 were not covered by tests

blocks_row_indices = _get_im2col_indices_along_dim(

Check warning on line 757 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L757

Added line #L757 was not covered by tests
input_h, kernel_h, dilation_h, padding_h, stride_h
)
blocks_col_indices = _get_im2col_indices_along_dim(

Check warning on line 760 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L760

Added line #L760 was not covered by tests
input_w, kernel_w, dilation_w, padding_w, stride_w
)

output_shape = _get_im2col_output_shape(self, kernel_h, kernel_w)
padded_input = _get_im2col_padded_input(self, padding_h, padding_w)

Check warning on line 765 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L764-L765

Added lines #L764 - L765 were not covered by tests

# For a 4D matrix of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
# [[[[1., 2., 3.,],
# [4., 5., 6.,],
# [7., 8., 9.,]]]]
# First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
# [[[[[1., 2., 3.],
# [4., 5., 6.]],
# [[4., 5., 6.],
# [7., 8., 9.]]]]]
# And then gather along cols (dim=4) with blocks_row_indices = [[0,1], [1,2]] to get:
# [[[[[[1., 2.],
# [4., 5.]],
# [[2., 3.],
# [5., 6]]],
# [[[4., 5.],
# [7., 8.]],
# [[5., 6.],
# [8., 9.]]]]]]
# Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 1, 4, 4) to get:
# [[[1., 2., 4., 5.],
# [2., 3., 5., 6.],
# [4., 5., 7., 8.],
# [5., 6., 8., 9.]]]
output = op.Gather(padded_input, blocks_row_indices, axis=2)
output = op.Gather(output, blocks_col_indices, axis=4)
Comment on lines +790 to +791
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible to use slice, which is faster?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use Slice, however the indices would need to be transformed to starts, ends format adding extra Reshape and Split nodes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Then this lgtm. Thanks for explaining!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the extra-operations can be done at export-time, is that correct? That is, they depend only on export-time values (torch parameters == onnx attributes), and not on run-time values. If so, there is no need to encode them using onnx ops, as it can be done in Python? In other words, using Slice should be doable in the trace-mode without any extra cost?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is worth thinking through.

  • Slice is an operation that has a very regular access pattern that is easier to optimize and parallelize. But Gather is very irregular and random, harder to optimize and parallelize.
  • The cost of operations on a large input tensor dominate overall cost, not cost of constant-time operations like Reshape.
  • If we want to extract a million elements, creating the indices of these million elements seems potentially expensive, when it can be described using a slice-pattern with a few elements.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if the entire model consists of a single op-function call? I wasn't necessarily looking for something visual. Just knowing impl1 takes X time and impl2 takes Y time would be fine. The starting point would be a test-case for an op like im2col, we run its onnxscript impl exported to ORT as a model.

Copy link
Collaborator

@justinchuby justinchuby Aug 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder how correlated a tiny bench is with the e2e performance? Hopefully closely?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. We will need to avoid overheads (like copying tensors, eg. due to conversion, etc.). And not count session creation (which should be easy). May be even warm up. Should be doable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @shubhambhokare1 : I see this has been merged. I am concerned that the strategy used here might not be good, for reasons discussed above. Any thoughts about that? Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @gramalingam,

Agreed with the point about the case with a large number of elements, creating these indices and using gather might be inefficient. I think I must have missed this comment thread pre-merge. Slice might be a better option.
Will add a PR on top of this to remedy this, replacing the gathers ops with slice, I guess models using im2col should be unblocked for now.

In regards to the second point, might be a good idea to create a single-op based evaluator for kernel performance. Will experiment and add that as part of the new PR.

output = op.Transpose(output, perm=[0, 1, 2, 4, 3, 5])
return op.Reshape(output, output_shape)

Check warning on line 793 in onnxscript/function_libs/torch_lib/ops/nn.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/nn.py#L790-L793

Added lines #L790 - L793 were not covered by tests


def aten_infinitely_differentiable_gelu_backward(
Expand Down
38 changes: 38 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,35 @@ def _grid_sample_input_wrangler(
return args, kwargs


def _im2col_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
# Move kernel_size, dilation, padding and stride from args to kwargs
if len(args) == 5:
shubhambhokare1 marked this conversation as resolved.
Show resolved Hide resolved
# Handle stride
stride = args.pop()
if isinstance(stride, np.ndarray): # convert stride to list[int]
stride = stride.tolist()
kwargs["stride"] = stride
# Handle padding
padding = args.pop()
if isinstance(padding, np.ndarray): # convert padding to list[int]
padding = padding.tolist()
kwargs["padding"] = padding
# Handle dilation
dilation = args.pop()
if isinstance(dilation, np.ndarray): # convert dilation to list[int]
dilation = dilation.tolist()
kwargs["dilation"] = dilation
# Handle kernel_size
kernel_size = args.pop()
if isinstance(kernel_size, np.ndarray): # convert kernel_size to list[int]
kernel_size = kernel_size.tolist()
kwargs["kernel_size"] = kernel_size

return args, kwargs


def _linalg_vector_norm_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
Expand Down Expand Up @@ -1895,6 +1924,15 @@ def _where_input_wrangler(
tolerance={torch.float16: (8e-2, 1e-4)},
),
TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu),
TorchLibOpInfo(
"nn.functional.unfold",
nn_ops.aten_im2col,
input_wrangler=_im2col_input_wrangler,
).xfail(
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape)
or not sample.input.shape,
reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
),
TorchLibOpInfo("nn.functional.linear", nn_ops.aten_linear).skip(
# input: input, args: weight, bias; so len(args) == 2 means bias is provided
matcher=lambda sample: len(sample.args) != 1,
Expand Down
Loading