diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index d97f6da3b..d6c7029f6 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -2542,19 +2542,11 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) ->
     # This is because computing diagonal sum is on dim2 after transpose by perm
     axes = [self_rank - 2]
 
-    return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes)
-
-
-@torch_op("aten::diagonal", private=True, traceable=True)
-def _aten_diagonal_onnx(
-    self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
-) -> TTensor:
     neg_1 = op.Constant(value_ints=[-1])
     dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1)  # row
     dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1)  # col
     mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
-    tmp_tensor = op.ConstantOfShape(mask_shape)
-    mask = op.EyeLike(tmp_tensor, k=offset)
+    mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
     mask = op.CastLike(mask, self)
     self_t = op.Transpose(self, perm=perm)
     result = op.Mul(self_t, mask)
@@ -2580,18 +2572,19 @@ def _aten_diagonal_onnx(
     #  6      0         4          0
     # From above table, we can get the logic below
 
+    offset_val = op.Constant(value_ints=[offset])
     if offset < 0:
         # row + offset
-        length = dim1_size + offset
+        length = op.Add(dim1_size, offset_val)
         start = op.Constant(value_ints=[0])
     else:  # offset >= 0
         # col - offset
-        length = dim2_size - offset
-        start = op.Reshape(op.Constant(value_int=offset), neg_1)
+        length = op.Sub(dim2_size, offset_val)
+        start = offset_val
 
     # max(min(length, min(row, col)), 0)
-    length = op.Max(op.Min(length, min_dim_size), 0)
-    end = start + length
+    length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
+    end = op.Add(start, length)
     result = op.Slice(result, start, end, axes=axes)
 
     return result
@@ -2621,19 +2614,11 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1
     # This is because computing diagonal sum is on dim2 after transpose by perm
     axes = [self_rank - 2]
 
-    return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes)
-
-
-@torch_op("aten::diagonal", private=True)
-def _aten_diagonal_bool_onnx(
-    self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
-) -> BOOL:
     neg_1 = op.Constant(value_ints=[-1])
     dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1)  # row
     dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1)  # col
     mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
-    tmp_tensor = op.ConstantOfShape(mask_shape)
-    mask = op.EyeLike(tmp_tensor, k=offset)
+    mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
     self_int = op.Cast(self, to=INT64.dtype)
     mask_int = op.Cast(mask, to=INT64.dtype)
     self_int_t = op.Transpose(self_int, perm=perm)
@@ -2660,18 +2645,19 @@ def _aten_diagonal_bool_onnx(
     #  6      0         4          0
     # From above table, we can get the logic below
 
+    offset_val = op.Constant(value_ints=[offset])
     if offset < 0:
         # row + offset
-        length = dim1_size + offset
+        length = op.Add(dim1_size, offset_val)
         start = op.Constant(value_ints=[0])
     else:  # offset >= 0
         # col - offset
-        length = dim2_size - offset
-        start = op.Reshape(op.Constant(value_int=offset), neg_1)
+        length = op.Sub(dim2_size, offset_val)
+        start = offset_val
 
     # max(min(length, min(row, col)), 0)
-    length = op.Max(op.Min(length, min_dim_size), 0)
-    end = start + length
+    length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
+    end = op.Add(start, length)
     result = op.Slice(result, start, end, axes=axes)
     result = op.Cast(result, to=BOOL.dtype)
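For reference (not part of the patch), a minimal NumPy sketch of the start/length arithmetic the new op.Add/op.Sub/op.Max graph encodes, checked against numpy.diagonal; the helper name expected_diag_length is illustrative only:

    import numpy as np

    def expected_diag_length(rows: int, cols: int, offset: int) -> int:
        # Mirrors the graph above: length = row + offset when offset < 0,
        # col - offset otherwise, clamped to [0, min(row, col)].
        length = rows + offset if offset < 0 else cols - offset
        return max(min(length, min(rows, cols)), 0)

    data = np.arange(12).reshape(3, 4)
    for offset in range(-4, 6):
        assert expected_diag_length(3, 4, offset) == np.diagonal(data, offset=offset).size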