Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Fix aten::diagonal #1755

Merged
merged 6 commits into from
Jul 31, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 21 additions & 35 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,19 +2550,10 @@
# This is because computing diagonal sum is on dim2 after transpose by perm
axes = [self_rank - 2]

return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes)


@torch_op("aten::diagonal", private=True, traceable=True)
def _aten_diagonal_onnx(
self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
) -> TTensor:
neg_1 = op.Constant(value_ints=[-1])
dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can dim1 be negative? Especially -1? Same for dim2.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it can. Thanks for pointing this out!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted to use gather. Added a comment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, but it's handled in
image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, using Shape(start, end) does not work for some reason.

dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col
mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
tmp_tensor = op.ConstantOfShape(mask_shape)
mask = op.EyeLike(tmp_tensor, k=offset)
mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
mask = op.CastLike(mask, self)
self_t = op.Transpose(self, perm=perm)
result = op.Mul(self_t, mask)
Expand All @@ -2588,19 +2579,22 @@
# 6 0 4 0

# From above table, we can get the logic below
offset_val = op.Constant(value_ints=[offset])

Check warning on line 2582 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2582

Added line #L2582 was not covered by tests
if offset < 0:
# row + offset
length = dim1_size + offset
length = op.Add(dim1_size, offset_val)

Check warning on line 2585 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2585

Added line #L2585 was not covered by tests
start = op.Constant(value_ints=[0])
else: # offset >= 0
# col - offset
length = dim2_size - offset
start = op.Reshape(op.Constant(value_int=offset), neg_1)
length = op.Sub(dim2_size, offset_val)
start = offset

Check warning on line 2590 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2589-L2590

Added lines #L2589 - L2590 were not covered by tests

# max(min(length, min(row, col)), 0)
length = op.Max(op.Min(length, min_dim_size), 0)
end = start + length
result = op.Slice(result, start, end, axes=axes)
length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
end = op.Add(start, length)
result = op.Slice(

Check warning on line 2595 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2593-L2595

Added lines #L2593 - L2595 were not covered by tests
result, start, end, axes=axes
)

return result

Expand Down Expand Up @@ -2629,19 +2623,10 @@
# This is because computing diagonal sum is on dim2 after transpose by perm
axes = [self_rank - 2]

return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes)


@torch_op("aten::diagonal", private=True)
def _aten_diagonal_bool_onnx(
self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
) -> BOOL:
neg_1 = op.Constant(value_ints=[-1])
dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
dim1_size = op.Shape(self, end=dim1, start=dim1 + 1) # row
dim2_size = op.Shape(self, end=dim2, start=dim2 + 1) # col
mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
tmp_tensor = op.ConstantOfShape(mask_shape)
mask = op.EyeLike(tmp_tensor, k=offset)
mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
self_int = op.Cast(self, to=INT64.dtype)
mask_int = op.Cast(mask, to=INT64.dtype)
self_int_t = op.Transpose(self_int, perm=perm)
Expand All @@ -2668,18 +2653,19 @@
# 6 0 4 0

# From above table, we can get the logic below
offset_val = op.Constant(value_ints=[offset])

Check warning on line 2656 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2656

Added line #L2656 was not covered by tests
if offset < 0:
# row + offset
length = dim1_size + offset
length = op.Add(dim1_size, offset_val)

Check warning on line 2659 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2659

Added line #L2659 was not covered by tests
start = op.Constant(value_ints=[0])
else: # offset >= 0
# col - offset
length = dim2_size - offset
start = op.Reshape(op.Constant(value_int=offset), neg_1)
length = op.Sub(dim2_size, offset_val)
start = offset

Check warning on line 2664 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2663-L2664

Added lines #L2663 - L2664 were not covered by tests

# max(min(length, min(row, col)), 0)
length = op.Max(op.Min(length, min_dim_size), 0)
end = start + length
length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
end = op.Add(start, length)

Check warning on line 2668 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L2667-L2668

Added lines #L2667 - L2668 were not covered by tests
result = op.Slice(result, start, end, axes=axes)
result = op.Cast(result, to=BOOL.dtype)

Expand Down
Loading