Skip to content

Commit

Permalink
any
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Oct 4, 2023
1 parent fd024c5 commit 259765c
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,32 @@ def aten_any_dim(self: TTensor, dim: int, keepdim: bool = False) -> BOOL:
return result


@torch_op("aten::any.dims", private=True)
def _aten_any_keep_dims(self: TTensor, keepdims: bool) -> BOOL:
"""Private implementation for when keepdims is True."""

self_rank = op.Size(op.Shape(self))

Check warning on line 478 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L478

Added line #L478 was not covered by tests
if self_rank == 0:
result = op.Cast(self, to=BOOL.dtype)

Check warning on line 480 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L480

Added line #L480 was not covered by tests
else:
self_bool = op.Cast(self, to=BOOL.dtype)
self_int = op.Cast(self_bool, to=INT64.dtype)
any_true = op.ReduceMax(self_int, keepdims=keepdims)
result = op.Cast(any_true, to=BOOL.dtype)
return result

Check warning on line 486 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L482-L486

Added lines #L482 - L486 were not covered by tests


@torch_op("aten::any.dims", trace_only=True)
def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) -> BOOL:
"""any.dims(Tensor self, int[1]? dim=None, bool keepdim=False) -> Tensor"""

if not dim:
return _aten_any_keep_dims(self, keepdim)

Check warning on line 494 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L494

Added line #L494 was not covered by tests
for d in dim:
self = aten_any_dim(self, d, keepdim)
return self

Check warning on line 497 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L496-L497

Added lines #L496 - L497 were not covered by tests


def _range_supported(dtype: int) -> bool:
"""Returns true if the dtype is supported by the ONNX Range op."""
return dtype in {
Expand Down

0 comments on commit 259765c

Please sign in to comment.