Commit
Merge branch 'main' into justinchu/complex-transpose
justinchuby authored Nov 7, 2023
2 parents (41d4812 + 720ea34), commit cb4add2
Showing 3 changed files with 27 additions and 4 deletions.
onnxscript/function_libs/torch_lib/ops/core.py (7 additions, 3 deletions)
@@ -2692,12 +2692,16 @@ def aten_dstack(tensors: Sequence[TensorType]) -> TensorType:
     raise NotImplementedError()


+@torch_op("aten::einsum", trace_only=True)
 def aten_einsum(
-    equation: str, tensors: Sequence[TensorType], path: Optional[int] = None
-) -> TensorType:
+    equation: str,
+    tensors: Sequence[TReal],
+    path: Optional[int] = None,  # pylint: disable=unused-argument
+) -> TReal:
     """einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor"""

-    raise NotImplementedError()
+    # Use trace_only to unpack the `tensors` sequence
+    return op.Einsum(*tensors, equation=equation)


 @torch_op("aten::embedding")
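The new implementation relies on ONNX Einsum accepting a variadic list of inputs plus an equation attribute; because the function is registered as trace_only, `tensors` is a plain Python sequence at conversion time and can be star-unpacked into that variadic node. As a rough sanity check of the intended semantics, here is a small NumPy sketch (the `einsum_reference` helper and the example shapes are illustrative only, not part of the commit):

import numpy as np

def einsum_reference(equation, tensors):
    # np.einsum follows the same contraction semantics the ONNX Einsum node is
    # expected to produce, so it serves as a quick reference for the traced output.
    return np.einsum(equation, *tensors)

a = np.random.rand(2, 3).astype(np.float32)
b = np.random.rand(3, 4).astype(np.float32)
print(einsum_reference("ij,jk->ik", [a, b]).shape)  # (2, 4)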
onnxscript/tests/function_libs/torch_lib/ops_test_common.py (2 additions, 1 deletion)
@@ -460,7 +460,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
             input.value = arg
             onnxscript_args.append(input)
             ort_inputs[input_name] = arg
-        elif isinstance(arg, Sequence):
+        elif isinstance(arg, (list, tuple)):
+            # str is also a sequence but we do not want to treat it as a tensor
             sequence_input = []
             for j, subarg in enumerate(arg):
                 if isinstance(subarg, np.ndarray):
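The narrower check matters for einsum in particular: the equation is a `str`, and `str` is itself a `Sequence`, so the broader isinstance check would route it into the tensor-sequence branch. A standalone illustration of the distinction (not part of the diff):

from collections.abc import Sequence

equation = "ij,jk->ik"
print(isinstance(equation, Sequence))        # True: a str is a Sequence of characters
print(isinstance(equation, (list, tuple)))   # False: the narrower check excludes strings
print(isinstance([1, 2, 3], (list, tuple)))  # True: genuine input sequences still match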
onnxscript/tests/function_libs/torch_lib/ops_test_data.py (18 additions, 0 deletions)
@@ -236,6 +236,13 @@ def _dropout_input_wrangler(
     return args, kwargs


+def _einsum_input_wrangler(
+    args: list[Any], kwargs: dict[str, Any]
+) -> tuple[list[Any], dict[str, Any]]:
+    # Swap the equation and tensors to revert the special handling in the OpInfo
+    return [args[1], args[0]], kwargs
+
+
 def _embedding_input_wrangler(
     args: list[Any], kwargs: dict[str, Any]
 ) -> tuple[list[Any], dict[str, Any]]:
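The wrangler simply swaps the first two positional arguments so that the equation string comes first, matching the `aten_einsum(equation, tensors, ...)` signature above. A hypothetical before/after, with the sample layout assumed purely for illustration:

import numpy as np

tensor_a = np.ones((2, 3), dtype=np.float32)
tensor_b = np.ones((3, 4), dtype=np.float32)

# Assumed OpInfo-style sample where the tensors arrive first and the equation second.
args = [[tensor_a, tensor_b], "ij,jk->ik"]
kwargs = {}

swapped_args, swapped_kwargs = [args[1], args[0]], kwargs
print(swapped_args[0])  # "ij,jk->ik" now leads, as aten_einsum expects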
@@ -736,6 +743,17 @@ def _where_input_wrangler(
         input_wrangler=_empty_input_wrangler,
         nondeterministic=True,
     ),
+    TorchLibOpInfo(
+        "einsum", core_ops.aten_einsum, trace_only=True, input_wrangler=_einsum_input_wrangler
+    )
+    .xfail(
+        reason="fixme: PyTorch produces int64 output with int32 input",
+        dtypes=(torch.int32,),
+    )
+    .xfail(
+        reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739",
+        matcher=lambda sample: sample.args[0] == "...ik, ...j -> ij",
+    ),
     # TorchLibOpInfo("empty_strided", core_ops.aten_empty_strided), # empty_strided is not in OPS_DB
     TorchLibOpInfo("eq", core_ops.aten_eq),
     TorchLibOpInfo("equal", core_ops.aten_equal),
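The second xfail is scoped to a single sample by a matcher predicate rather than by dtype. A minimal stand-in showing how such a predicate selects the problematic equation (the `SimpleNamespace` sample is an assumption for illustration; real samples come from the OpInfo database):

from types import SimpleNamespace

sample = SimpleNamespace(args=("...ik, ...j -> ij", []))  # stand-in sample, equation first (assumed)
matcher = lambda s: s.args[0] == "...ik, ...j -> ij"
print(matcher(sample))  # True: this sample would be marked as an expected failure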
