Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Implement aten::prelu #1728

Merged
merged 5 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2934,7 +2934,7 @@
indices_1d = op.Reshape(indices, neg_1)
# Get weight out according to indices_1d,
new_weight = op.Gather(weight, indices_1d)
# This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)

Check warning on line 2937 in onnxscript/function_libs/torch_lib/ops/core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "happends" is a misspelling of "happens" Raw Output: ./onnxscript/function_libs/torch_lib/ops/core.py:2937:11: "happends" is a misspelling of "happens"
new_weight = op.Mul(new_weight, op.Unsqueeze(per_sample_weights, axes=1))
weight_dim_1 = op.Reshape(op.Shape(weight, start=1), neg_1)
indices_size = op.Shape(indices_1d)
Expand Down Expand Up @@ -3074,7 +3074,7 @@
# Get weight out according to indices,
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
indices_weight = op.Gather(weight, indices)
# This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)

Check warning on line 3077 in onnxscript/function_libs/torch_lib/ops/core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "happends" is a misspelling of "happens" Raw Output: ./onnxscript/function_libs/torch_lib/ops/core.py:3077:11: "happends" is a misspelling of "happens"
indices_weight = op.Mul(indices_weight, op.Unsqueeze(per_sample_weights, axes=1))

# The element in sequence must be FLOAT32 dtype due to ORT bug
Expand Down Expand Up @@ -5564,7 +5564,7 @@
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

# TODO(justinchuby): Handle cases where type reconcilation is not enough,

Check warning on line 5567 in onnxscript/function_libs/torch_lib/ops/core.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "reconcilation" is a misspelling of "reconciliation" Raw Output: ./onnxscript/function_libs/torch_lib/ops/core.py:5567:49: "reconcilation" is a misspelling of "reconciliation"
# since different ONNX operators are used based on different data types.

return op.And(self, other)
Expand Down Expand Up @@ -6565,10 +6565,20 @@
return op.Pow(self, exponent)


def aten_prelu(self: TensorType, weight: TensorType) -> TensorType:
@torch_op(("aten::prelu", "aten::_prelu_kernel"), trace_only=True)
def aten_prelu(self: TReal, weight: TReal) -> TReal:
"""prelu(Tensor self, Tensor weight) -> Tensor"""

raise NotImplementedError()
zero = op.CastLike(0, self)
rank = len(self.shape)
if rank == 0:
if len(weight.shape) == 1:
# e.g. [] * [1]
weight = op.Squeeze(weight, [-1])
elif rank >= 2:
# e.g. [5,10,5] * [10]
weight = op.Reshape(weight, [1, -1] + [1] * (rank - 2))
justinchuby marked this conversation as resolved.
Show resolved Hide resolved
return op.Add(op.Max(self, zero), op.Mul(weight, op.Min(self, zero)))


def aten_prelu_backward(
Expand Down
3 changes: 2 additions & 1 deletion tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import onnxscript
import onnxscript.evaluator
from onnxscript import ir
from onnxscript.function_libs.torch_lib import graph_building
from tests.function_libs.torch_lib import error_reproduction

Expand Down Expand Up @@ -538,7 +539,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
raise AssertionError(
f"ONNX model is invalid. Model:\n{onnx.printer.to_text(onnx_model)}"
f"ONNX model is invalid. Model:\n{ir.serde.deserialize_model(onnx_model)}"
) from e

try:
Expand Down
1 change: 1 addition & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,7 @@ def _where_input_wrangler(
),
TorchLibOpInfo("polar", core_ops.aten_polar),
TorchLibOpInfo("pow", core_ops.aten_pow),
TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
TorchLibOpInfo("ops.aten.randint", core_ops.aten_randint, nondeterministic=True),
Expand Down
Loading