Skip to content

Commit

Permalink
Implement squeeze.dim for complex inputs | feat(torchlib) (#1139)
Browse files Browse the repository at this point in the history
Discovered through pytorch/pytorch#113067
  • Loading branch information
justinchuby authored Nov 8, 2023
1 parent 662af2a commit 6d6588c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7261,6 +7261,15 @@ def aten_squeeze_dim(self: TTensor, dim: int) -> TTensor:
return result


@torch_op("aten::squeeze.dim", complex=True, trace_only=True)
def aten_squeeze_dim_complex(self: TTensor, dim: int) -> TTensor:
if dim < 0:
# Account for the complex dimension in ONNX
dim = dim - 1

return aten_squeeze_dim(self, dim)


def aten_squeeze_copy(self: TensorType) -> TensorType:
"""squeeze_copy(Tensor self) -> Tensor"""

Expand Down
9 changes: 9 additions & 0 deletions onnxscript/tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,15 @@ def _where_input_wrangler(
matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
reason="this Aten overload only support one tensor as input and one int as args by design",
),
TorchLibOpInfo(
"squeeze_dim",
core_ops.aten_squeeze_dim_complex,
complex=True,
trace_only=True,
).skip(
matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
reason="this Aten overload only support one tensor as input and one int as args by design",
),
TorchLibOpInfo(
"squeeze",
core_ops.aten_squeeze,
Expand Down

0 comments on commit 6d6588c

Please sign in to comment.