From efe674d570f794f9bf322536a943c5ca61a232bd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 30 Jul 2024 21:35:23 -0700 Subject: [PATCH] [torchlib] Fix aten::diagonal (#1755) Turn aten::diagonal as trace only and fix its logic by explicitly converting python constants to onnx constants. This was needed because the exporter logic was not handling the type conversion correctly (yet) --- .../function_libs/torch_lib/ops/core.py | 42 +++++++------------ 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index d97f6da3b..d6c7029f6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -2542,19 +2542,11 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True, traceable=True) -def _aten_diagonal_onnx( - self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) mask = op.CastLike(mask, self) self_t = op.Transpose(self, perm=perm) result = op.Mul(self_t, mask) @@ -2580,18 +2572,19 @@ def _aten_diagonal_onnx( # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset + length = op.Add(dim1_size, offset_val) start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + length = op.Sub(dim2_size, offset_val) + start = offset_val # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) - end = start + length + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) + end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) return result @@ -2621,19 +2614,11 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 # This is because computing diagonal sum is on dim2 after transpose by perm axes = [self_rank - 2] - return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes) - - -@torch_op("aten::diagonal", private=True) -def _aten_diagonal_bool_onnx( - self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int] -) -> BOOL: neg_1 = op.Constant(value_ints=[-1]) dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - tmp_tensor = op.ConstantOfShape(mask_shape) - mask = op.EyeLike(tmp_tensor, k=offset) + mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) self_int = op.Cast(self, to=INT64.dtype) mask_int = op.Cast(mask, to=INT64.dtype) self_int_t = op.Transpose(self_int, perm=perm) @@ -2660,18 +2645,19 @@ def _aten_diagonal_bool_onnx( # 6 0 4 0 # From above table, we can get the logic below + offset_val = op.Constant(value_ints=[offset]) if offset < 0: # row + offset - length = dim1_size + offset + length = op.Add(dim1_size, offset_val) start = op.Constant(value_ints=[0]) else: # offset >= 0 # col - offset - length = dim2_size - offset - start = op.Reshape(op.Constant(value_int=offset), neg_1) + length = op.Sub(dim2_size, offset_val) + start = offset_val # max(min(length, min(row, col)), 0) - length = op.Max(op.Min(length, min_dim_size), 0) - end = start + length + length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) + end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) result = op.Cast(result, to=BOOL.dtype)