Skip to content

Commit

Permalink
Convert scalars to 1D to satisfy ML Program requirements. Fixes test …
Browse files Browse the repository at this point in the history
  • Loading branch information
skottmckay committed Jun 25, 2024
1 parent 4743803 commit e4110f9
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 24 deletions.
48 changes: 31 additions & 17 deletions onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,30 +140,44 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span<c

namespace {
void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
std::optional<gsl::span<const int64_t>> shape) {
std::optional<gsl::span<const int64_t>> shape, bool convert_scalar = false) {
tensor_type.set_datatype(data_type);
if (shape) {
tensor_type.set_rank(shape->size());
for (const auto& dim : *shape) {
if (dim >= 0) {
tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim));
} else {
tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
auto rank = shape->size();
if (convert_scalar && rank == 0) {
// CoreML scalar has shape {1}
tensor_type.set_rank(1);
tensor_type.add_dimensions()->mutable_constant()->set_size(1);
} else {
tensor_type.set_rank(rank);
for (const auto& dim : *shape) {
if (dim >= 0) {
tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim));
} else {
tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
}
}
}
}
}

void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type,
const ONNX_NAMESPACE::TensorShapeProto* shape) {
const ONNX_NAMESPACE::TensorShapeProto* shape, bool convert_scalar = false) {
tensor_type.set_datatype(data_type);
if (shape) {
tensor_type.set_rank(shape->dim_size());
for (const auto& dim : shape->dim()) {
if (dim.has_dim_value()) {
tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim.dim_value()));
} else {
tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
auto rank = shape->dim_size();
if (convert_scalar && rank == 0) {
// CoreML scalar has shape {1}
tensor_type.set_rank(1);
tensor_type.add_dimensions()->mutable_constant()->set_size(1);
} else {
tensor_type.set_rank(rank);
for (const auto& dim : shape->dim()) {
if (dim.has_dim_value()) {
tensor_type.add_dimensions()->mutable_constant()->set_size(narrow<int32_t>(dim.dim_value()));
} else {
tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false);
}
}
}
}
Expand Down Expand Up @@ -281,13 +295,13 @@ template MILSpec::Value CreateScalarTensorValue(const int32_t& data);
template MILSpec::Value CreateScalarTensorValue(const std::string& data);
template MILSpec::Value CreateScalarTensorValue(const bool& data);

COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) {
COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar) {
MILSpec::NamedValueType nvt;
nvt.set_name(node_arg.Name());
MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype();

SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()),
node_arg.Shape());
node_arg.Shape(), convert_scalar);

return nvt;
}
Expand All @@ -308,7 +322,7 @@ void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& outp
MILSpec::TensorType& tensor_type = *value.mutable_tensortype();

SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()),
output.Shape());
output.Shape(), /*convert_scalar*/ true);
}

void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,10 @@ template <typename T>
COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data);

/// <summary>Create a NamedValueType from an ONNX tensor NodeArg.</summary>
/// <param name="node_arg">NodeArg to create NamedValueType from.</param>
/// <param name="convert_scalar">If true, scalar shapes are converted to 1D.</param>
/// <remarks>Used to create inputs for the 'main' function in an ML Program.</remarks>
COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg);
COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg, bool convert_scalar = false);

/// <summary>
/// Add an input argument to a MILSpec::Operation
Expand Down
7 changes: 1 addition & 6 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -838,13 +838,8 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
if (create_ml_program_) {
if (is_input) {
// the model inputs need to be wired up as args to the 'main' function.
auto tensor_value_type = CreateNamedTensorValueType(node_arg);
auto tensor_value_type = CreateNamedTensorValueType(node_arg, /*convert_scalar*/ true);
tensor_value_type.set_name(name);
if (node_arg.Shape()->dim_size() == 0) {
// update shape from {} to {1} (same change we made at the model input level above).
tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1);
tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1);
}

mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type));
} else {
Expand Down

0 comments on commit e4110f9

Please sign in to comment.