# Allow optional axes tensor to be null and ignore it as optional (#18423)

### Description
Our function inliner converts call nodes to protos. The `Node::ToProto()` function re-creates optional NodeArgs in the resulting `NodeProto`, and when handling missing input parameters the inliner simply renames them to empty strings.
`Graph::InlineFunctionProto()` then re-creates those missing NodeArgs even though the original call node did not have them.
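
For illustration, here is a minimal sketch (not the inliner code itself; it assumes the generated ONNX protobuf C++ bindings) of how an omitted optional input surfaces in a `NodeProto`: the missing argument is carried as an empty input name.

```cpp
#include "onnx/onnx_pb.h"  // generated protobuf bindings for onnx.proto

// Sketch: a ReduceSum call whose optional second input (axes) was omitted.
// The missing argument is represented by an empty string in the input list.
onnx::NodeProto MakeReduceSumWithMissingAxes() {
  onnx::NodeProto node;
  node.set_op_type("ReduceSum");
  node.add_input("data");  // required data input
  node.add_input("");      // optional 'axes' input omitted -> empty name
  node.add_output("reduced");
  return node;
}
```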

This results in the issue shown below. The inlined model (exported via Dynamo) contains the following entry; notice that the second argument of the `ReduceSum` call is present but has no value.

> InsertedPrecisionFreeCast__inlfunc__aten_linalg_vector_norm_no_dim_onnx_result_12 = ReduceSum <keepdims: int = 0, noop_with_empty_axes: int = 0> (InsertedPrecisionFreeCast__inlfunc_ReduceL1_data_abs, )

We now allow the second input to `ReduceSum` to be nullptr and ignore it, since that input is optional.
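
In kernel terms, the fix boils down to the following null-check pattern; this is a condensed sketch of the change shown in the diff below, with variable names mirroring it.

```cpp
// Sketch of the optional-input handling in the CPU kernel (see diff below).
// ctx->Input<Tensor>(1) returns nullptr when the optional axes input is
// absent (e.g. an empty-name input produced by the inliner), so the kernel
// falls back to the "no axes supplied" behavior instead of failing.
const Tensor* axes_tensor = ctx->Input<Tensor>(1);
if (axes_tensor != nullptr) {
  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
              "An axes tensor must be a vector tensor.");
  const auto axes_span = axes_tensor->DataAsSpan<int64_t>();
  input_axes.assign(axes_span.begin(), axes_span.end());
} else {
  input_axes.clear();  // empty axes -> behavior governed by noop_with_empty_axes
}
```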

### Motivation and Context
This seeks to address
#18338
yuslepukhin authored Nov 16, 2023
1 parent cc840c5 commit 6f863ae
Showing 4 changed files with 896 additions and 10 deletions.
22 changes: 12 additions & 10 deletions onnxruntime/core/providers/cpu/reduction/reduction_ops.cc
```diff
@@ -688,21 +688,23 @@ FastReduceKind OptimizeShapeForFastReduce(gsl::span<const int64_t> input_shape,
   return FastReduceKind::kNone;
 }
 
-void ValidateCommonFastReduce(const Tensor* axes_tensor) {
-  ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
-  ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
-              "An axes tensor must be a vector tensor.");
-}
-
 // template <typename T, typename TVAL>
 bool CommonFastReduceCopy(OpKernelContext* ctx, TensorShapeVector& input_axes, bool noop_with_empty_axes) {
   if (ctx->InputCount() == 2) {
     // second input holds the axes.
+    // the argument is optional
     const Tensor* axes_tensor = ctx->Input<Tensor>(1);
-    ValidateCommonFastReduce(axes_tensor);
-    auto nDims = static_cast<size_t>(axes_tensor->Shape()[0]);
-    const auto* data = axes_tensor->Data<int64_t>();
-    input_axes.insert(input_axes.begin(), data, data + nDims);
+
+    if (axes_tensor != nullptr) {
+      ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1,
+                  "An axes tensor must be a vector tensor.");
+
+      const auto data_span = axes_tensor->DataAsSpan<int64_t>();
+      input_axes.assign(data_span.begin(), data_span.end());
+    } else {
+      input_axes.clear();
+    }
 
     if (input_axes.empty() && noop_with_empty_axes) {
       const Tensor* input = ctx->Input<Tensor>(0);
       auto* output = ctx->Output(0, input->Shape());
```
25 changes: 25 additions & 0 deletions onnxruntime/test/framework/function_test.cc
```diff
@@ -589,5 +589,30 @@ TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) {
 #endif
 }
 
+TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) {
+  // Verify this runs
+  constexpr const ORTCHAR_T* model_uri = ORT_TSTR("testdata/transform/gh_issue_18338.onnx");
+
+  SessionOptions session_options;
+  InferenceSessionWrapper session_object{session_options, GetEnvironment()};
+
+  ASSERT_STATUS_OK(session_object.Load(model_uri));
+  ASSERT_STATUS_OK(session_object.Initialize());
+
+  // Scalar shape for input_0 and output
+  const std::string input_names[] = {"input_0"};
+  const std::string output_names[] = {"_val_3"};
+  TensorShape input_shape;
+  MLFloat16 input_0_data{684.f};
+
+  OrtValue input_0;
+  Tensor::InitOrtValue(DataTypeImpl::GetType<MLFloat16>(), input_shape, &input_0_data, OrtMemoryInfo(), input_0);
+
+  std::vector<OrtValue> fetches(1);
+  RunOptions run_options;
+  ASSERT_STATUS_OK(session_object.Run(run_options, AsSpan(input_names), AsSpan({input_0}),
+                                      AsSpan(output_names), &fetches, 0));
+}
+
 } // namespace test
 } // namespace onnxruntime
```