diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index 5d1332315..b3591a0a8 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -131,3 +131,18 @@ def tag_arguments_with_param_schemas( raise TypeError(f"Required input/attribute '{param}' was not provided") return tagged_args, tagged_kwargs + + +def turn_to_kwargs_to_avoid_ordering( + param_schemas: Sequence[values.ParamSchema], + inputs: list[Any], + attributes: dict[str, Any], +) -> dict[str, Any]: + """Return the inputs and attributes to the order of the function signature.""" + for idx, param in enumerate(param_schemas): + if param.name not in attributes: + if param.is_variadic_input: + attributes[param.name] = inputs[idx:] + elif inputs: + attributes[param.name] = inputs.pop(0) + return attributes diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index bef78a799..daf63d86a 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -390,9 +390,6 @@ def eval_function( # type: ignore[override] else: # Python constants are scalars return 0 - elif function.traceable: - # Trace the function call instead of adding the function as a node - return function.function(*args, **kwargs) # args/kwargs are TorchScriptTensor/python built-in based param_schemas = function.param_schemas() @@ -422,6 +419,15 @@ def eval_function( # type: ignore[override] value, float ): attributes[name] = (value,) + if function.traceable: + inputs = self._graph.preprocess_inputs(inputs) + inputs = _wrap_torch_value_to_tensor(inputs) # type: ignore[assignment] + # The args and kwargs matters, as it's traced onnx function + kwargs = param_manipulation.turn_to_kwargs_to_avoid_ordering( + param_schemas, inputs, attributes + ) + # Trace the function call instead of adding the function as a node + return function.function(**kwargs) return self._graph.add_function_call(function, inputs, attributes) @@ -730,14 +736,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value: value.setDebugName(_rename_intermediate_value(value.debugName())) return value - @runtime_typing.checked - def _add_torchscript_op_call( - self, - name: str, - onnx_inputs: Sequence[ValidInputType], - onnx_attributes: Mapping[str, ValidArgumentType], - n_outputs: int, - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]: unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs) graph_inputs = [] assert isinstance(unwrapped_inputs, Sequence) @@ -761,6 +760,17 @@ def _add_torchscript_op_call( graph_inputs.append(self._add_constant_to_graph(input)) else: graph_inputs.append(input) + return graph_inputs + + @runtime_typing.checked + def _add_torchscript_op_call( + self, + name: str, + onnx_inputs: Sequence[ValidInputType], + onnx_attributes: Mapping[str, ValidArgumentType], + n_outputs: int, + ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + graph_inputs = self.preprocess_inputs(onnx_inputs) for key, value in onnx_attributes.items(): assert not isinstance( value, TorchScriptTensor diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index fab45cc42..c8573c4b4 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5752,14 +5752,9 @@ def aten_nansum( def aten_narrow(self: TTensor, dim: INT64, start: INT64, length: INT64) -> TTensor: """narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)""" - if IsScalar(dim): - dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - - if IsScalar(start): - start = op.Reshape(start, op.Constant(value_ints=[-1])) - - if IsScalar(length): - length = op.Reshape(length, op.Constant(value_ints=[-1])) + dim = op.Reshape(dim, op.Constant(value_ints=[-1])) + start = op.Reshape(start, op.Constant(value_ints=[-1])) + length = op.Reshape(length, op.Constant(value_ints=[-1])) end = op.Add(start, length) return op.Slice(self, start, end, dim) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 35c691109..55e78593a 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1349,6 +1349,7 @@ def _where_input_wrangler( .xfail( variant_name="decimals_0", reason="This variant does not accept decimals", + test_class_name="TestOutputConsistencyEager", ) .xfail( variant_name="decimals_3",