Skip to content

Commit

Permalink
[torchlib] Fix names for registered functions (#1606)
Browse files Browse the repository at this point in the history
Fix the names for the `getitem` and `adaptive*` ops.
  • Loading branch information
justinchuby authored Jun 13, 2024
1 parent 505e154 commit caf22fa
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
11 changes: 5 additions & 6 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3560,7 +3560,7 @@ def aten_full(size: INT64, fill_value: FLOAT, dtype: int = FLOAT.dtype):


@torch_op("aten::full_like")
def aten_full_like(self, fill_value: TTensor) -> TTensor:
def aten_full_like(self: TTensor, fill_value: TTensor) -> TTensor:
"""full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

fill_value = op.CastLike(fill_value, self)
Expand All @@ -3570,7 +3570,7 @@ def aten_full_like(self, fill_value: TTensor) -> TTensor:


@torch_op("aten::full_like")
def aten_full_like_dtype(self, fill_value: TTensor, dtype: int) -> TTensor:
def aten_full_like_dtype(self: TTensor, fill_value: TTensor, dtype: int) -> TTensor:
"""full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""

fill_value = op.Cast(fill_value, to=dtype)
Expand Down Expand Up @@ -3669,8 +3669,7 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType:
raise NotImplementedError()


# NOTE: The name is made up for `getitem` to be included in the registry
@torch_op("aten::getitem")
@torch_op("_operator::getitem")
def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor:
return op.SequenceAt(self, i)

Expand Down Expand Up @@ -8174,7 +8173,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:


@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True)
def aten_transpose(self, dim0: int, dim1: int):
def aten_transpose(self: TTensor, dim0: int, dim1: int) -> TTensor:
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""

# Use trace only to construct the prem attribute in Transpose
Expand All @@ -8194,7 +8193,7 @@ def aten_transpose(self, dim0: int, dim1: int):


@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True, complex=True)
def aten_transpose_complex(self, dim0: int, dim1: int):
def aten_transpose_complex(self: TTensor, dim0: int, dim1: int) -> TTensor:
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""

# Use trace only to construct the prem attribute in Transpose
Expand Down
6 changes: 3 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
TFloatUnlessFloat32 = TypeVar("TFloatUnlessFloat32", bound=Union[BFLOAT16, FLOAT16, DOUBLE])


@torch_op("aten::aten_adaptive_avg_pool1d", traceable=True)
@torch_op("aten::adaptive_avg_pool1d", traceable=True)
def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
"""adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor"""

Expand All @@ -58,7 +58,7 @@ def aten_adaptive_avg_pool1d(self: TFloat, output_size: INT64[1]) -> TFloat:
return result


@torch_op("aten::aten_adaptive_avg_pool2d", traceable=True)
@torch_op("aten::adaptive_avg_pool2d", traceable=True)
def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
"""adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor"""

Expand All @@ -76,7 +76,7 @@ def aten_adaptive_avg_pool2d(self: TFloat, output_size: INT64[2]) -> TFloat:
return result


@torch_op("aten::aten_adaptive_avg_pool3d", traceable=True)
@torch_op("aten::adaptive_avg_pool3d", traceable=True)
def aten_adaptive_avg_pool3d(self: TFloat, output_size: INT64[3]) -> TFloat:
"""adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor"""

Expand Down

0 comments on commit caf22fa

Please sign in to comment.