From de821f429cc270cbcc2e5ae8f2430f3ba782e49d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 25 Jul 2024 18:09:19 +0000 Subject: [PATCH 1/3] Fix aten::diagonal --- .../function_libs/torch_lib/ops/core.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 1fc122966..71e45443b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2550,19 +2550,11 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True, traceable=True) -def _aten_diagonal_onnx( - self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col + dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row + dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) mask = op.CastLike(mask, self) self_t = op.Transpose(self, perm=perm) result = op.Mul(self_t, mask) @@ -2591,16 +2583,18 @@ def _aten_diagonal_onnx( if offset < 0: # row + offset length = dim1_size + offset - start = op.Constant(value_ints=[0]) + start = 0 else: # offset >= 0 # col - offset length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + start = offset # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = start + length - result = op.Slice(result, start, end, axes=axes) + result = op.Slice( + result, op.Constant(value_ints=[start]), op.Constant(value_ints=[end]), axes=axes + ) return result @@ -2678,7 +2672,7 @@ def _aten_diagonal_bool_onnx( start = op.Reshape(op.Constant(value_int=offset), neg_1) # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = start + length result = op.Slice(result, start, end, axes=axes) result = op.Cast(result, to=BOOL.dtype) From 7ab2005556a17eaf231ecf7b1ac63f9b63ac6c67 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 25 Jul 2024 19:39:51 +0000 Subject: [PATCH 2/3] Fix --- .../function_libs/torch_lib/ops/core.py | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 71e45443b..dcaf6ab84 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2550,7 +2550,6 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - neg_1 = op.Constant(value_ints=[-1]) dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) @@ -2580,20 +2579,21 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset - start = 0 + length = op.Add(dim1_size, offset_val) + start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset + length = op.Sub(dim2_size, offset_val) start = offset # max(min(length, min(row, col)), 0) length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) - end = start + length + end = op.Add(start, length) result = op.Slice( - result, op.Constant(value_ints=[start]), op.Constant(value_ints=[end]), axes=axes + result, start, end, axes=axes ) return result @@ -2623,19 +2623,10 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True) -def _aten_diagonal_bool_onnx( - self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> BOOL: - neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col + dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row + dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) self_int = op.Cast(self, to=INT64.dtype) mask_int = op.Cast(mask, to=INT64.dtype) self_int_t = op.Transpose(self_int, perm=perm) @@ -2662,18 +2653,19 @@ def _aten_diagonal_bool_onnx( # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset + length = op.Add(dim1_size, offset_val) start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + length = op.Sub(dim2_size, offset_val) + start = offset # max(min(length, min(row, col)), 0) length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) - end = start + length + end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) result = op.Cast(result, to=BOOL.dtype) From eaa733aef0b83d47d0355109a03e99d840c54ed1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Jul 2024 22:12:27 +0000 Subject: [PATCH 3/3] Use gather --- onnxscript/function_libs/torch_lib/ops/core.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d7baba735..55f37ca00 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2550,8 +2550,9 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row - dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col + neg_1 = op.Constant(value_ints=[-1]) + dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row + dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) mask = op.CastLike(mask, self) @@ -2587,14 +2588,12 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> else: # offset >= 0 # col - offset length = op.Sub(dim2_size, offset_val) - start = offset + start = offset_val # max(min(length, min(row, col)), 0) length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = op.Add(start, length) - result = op.Slice( - result, start, end, axes=axes - ) + result = op.Slice(result, start, end, axes=axes) return result @@ -2623,8 +2622,9 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row - dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col + neg_1 = op.Constant(value_ints=[-1]) + dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row + dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) self_int = op.Cast(self, to=INT64.dtype) @@ -2661,7 +2661,7 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 else: # offset >= 0 # col - offset length = op.Sub(dim2_size, offset_val) - start = offset + start = offset_val # max(min(length, min(row, col)), 0) length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))