Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten_prod function #1724

Merged
merged 4 commits into from
Jul 9, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2934,7 +2934,7 @@
indices_1d = op.Reshape(indices, neg_1)
# Get weight out according to indices_1d,
new_weight = op.Gather(weight, indices_1d)
# This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)

Check warning on line 2937 in onnxscript/function_libs/torch_lib/ops/core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "happends" is a misspelling of "happens" Raw Output: ./onnxscript/function_libs/torch_lib/ops/core.py:2937:11: "happends" is a misspelling of "happens"
new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
indices_size = op.Shape(indices_1d)
Expand Down Expand Up @@ -3074,7 +3074,7 @@
# Get weight out according to indices,
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
indices_weight = op.Gather(weight, indices)
# This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)

Check warning on line 3077 in onnxscript/function_libs/torch_lib/ops/core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "happends" is a misspelling of "happens" Raw Output: ./onnxscript/function_libs/torch_lib/ops/core.py:3077:11: "happends" is a misspelling of "happens"
indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))

# The element in sequence must be FLOAT32 dtype due to ORT bug
Expand Down Expand Up @@ -5564,7 +5564,7 @@
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

# TODO(justinchuby): Handle cases where type reconcilation is not enough,

Check warning on line 5567 in onnxscript/function_libs/torch_lib/ops/core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "reconcilation" is a misspelling of "reconciliation" Raw Output: ./onnxscript/function_libs/torch_lib/ops/core.py:5567:49: "reconcilation" is a misspelling of "reconciliation"
# since different ONNX operators are used based on different data types.

return op.And(self, other)
Expand Down Expand Up @@ -6583,10 +6583,11 @@
raise NotImplementedError()


def aten_prod(self: TensorType, dtype: Optional[int] = None) -> TensorType:
@torch_op(("aten::prod", "aten::prod.dim_int"), trace_only=True)
xiaowuhu marked this conversation as resolved.
Show resolved Hide resolved
def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
"""prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"""

raise NotImplementedError()
return op.ReduceProd(self, axes=[dim], keepdims=keepdim)

Check warning on line 6590 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6590

Added line #L6590 was not covered by tests
xiaowuhu marked this conversation as resolved.
Show resolved Hide resolved


def aten_promote_types(type1: int, type2: int) -> int:
Expand Down
Loading