Skip to content

Commit

Permalink
Merge branch 'main' into rama/builder_kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam authored Jul 31, 2024
2 parents 4abed66 + efe674d commit c423795
Showing 1 changed file with 14 additions and 28 deletions.
42 changes: 14 additions & 28 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,19 +2542,11 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) ->
# This is because computing diagonal sum is on dim2 after transpose by perm
axes = [self_rank - 2]

return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes)


@torch_op("aten::diagonal", private=True, traceable=True)
def _aten_diagonal_onnx(
self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
) -> TTensor:
neg_1 = op.Constant(value_ints=[-1])
dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
tmp_tensor = op.ConstantOfShape(mask_shape)
mask = op.EyeLike(tmp_tensor, k=offset)
mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
mask = op.CastLike(mask, self)
self_t = op.Transpose(self, perm=perm)
result = op.Mul(self_t, mask)
Expand All @@ -2580,18 +2572,19 @@ def _aten_diagonal_onnx(
# 6 0 4 0

# From above table, we can get the logic below
offset_val = op.Constant(value_ints=[offset])
if offset < 0:
# row + offset
length = dim1_size + offset
length = op.Add(dim1_size, offset_val)
start = op.Constant(value_ints=[0])
else: # offset >= 0
# col - offset
length = dim2_size - offset
start = op.Reshape(op.Constant(value_int=offset), neg_1)
length = op.Sub(dim2_size, offset_val)
start = offset_val

# max(min(length, min(row, col)), 0)
length = op.Max(op.Min(length, min_dim_size), 0)
end = start + length
length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
end = op.Add(start, length)
result = op.Slice(result, start, end, axes=axes)

return result
Expand Down Expand Up @@ -2621,19 +2614,11 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1
# This is because computing diagonal sum is on dim2 after transpose by perm
axes = [self_rank - 2]

return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes)


@torch_op("aten::diagonal", private=True)
def _aten_diagonal_bool_onnx(
self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
) -> BOOL:
neg_1 = op.Constant(value_ints=[-1])
dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
tmp_tensor = op.ConstantOfShape(mask_shape)
mask = op.EyeLike(tmp_tensor, k=offset)
mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
self_int = op.Cast(self, to=INT64.dtype)
mask_int = op.Cast(mask, to=INT64.dtype)
self_int_t = op.Transpose(self_int, perm=perm)
Expand All @@ -2660,18 +2645,19 @@ def _aten_diagonal_bool_onnx(
# 6 0 4 0

# From above table, we can get the logic below
offset_val = op.Constant(value_ints=[offset])
if offset < 0:
# row + offset
length = dim1_size + offset
length = op.Add(dim1_size, offset_val)
start = op.Constant(value_ints=[0])
else: # offset >= 0
# col - offset
length = dim2_size - offset
start = op.Reshape(op.Constant(value_int=offset), neg_1)
length = op.Sub(dim2_size, offset_val)
start = offset_val

# max(min(length, min(row, col)), 0)
length = op.Max(op.Min(length, min_dim_size), 0)
end = start + length
length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
end = op.Add(start, length)
result = op.Slice(result, start, end, axes=axes)
result = op.Cast(result, to=BOOL.dtype)

Expand Down

0 comments on commit c423795

Please sign in to comment.