From 0cccdd700e552cb70ba4f6ffe0f805011a6e66de Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 19 Jul 2024 21:36:47 +1000 Subject: [PATCH 1/9] Add ConvTranspose - some limitations to simplify the implementation for now - some limitations due to flaky CoreML output Added support for non-contiguous MLMultiArray output as we see that with some unit tests when the CPU-only flag is not set (e.g. innermost dim has min size of 16 but test output only has 8 values). - support only one non-contiguous dim to keep it simple - manually tested as we don't have a setup that can test objective-c code - test code is in model.mm and can be enabled via ifdef if we need to validate any future changes --- .../builders/impl/convtranspose_op_builder.cc | 217 +++++++++++++++++ .../coreml/builders/op_builder_factory.cc | 164 ++++--------- .../coreml/builders/op_builder_factory.h | 1 + .../core/providers/coreml/model/model.mm | 229 +++++++++++++++--- .../providers/xnnpack/nn/conv_transpose.cc | 2 +- .../cpu/nn/conv_transpose_op_test.cc | 8 +- .../apple/coreml_supported_mlprogram_ops.md | 1 + 7 files changed, 469 insertions(+), 153 deletions(-) create mode 100644 onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc diff --git a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc new file mode 100644 index 0000000000000..c3f143a362b95 --- /dev/null +++ b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc @@ -0,0 +1,217 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" + +using namespace CoreML::Specification; + +namespace onnxruntime { +namespace coreml { + +class ConvTransposeOpBuilder : public BaseOpBuilder { + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override; + + bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, + const logging::Logger& /* logger */) const override; + + bool SupportsMLProgram() const override { return true; } +}; + +Status ConvTransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& /*logger*/) const { +#if defined(COREML_ENABLE_MLPROGRAM) + using namespace CoreML::Specification::MILSpec; // NOLINT + const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); + const auto& input_name = input_defs[0]->Name(); + + NodeAttrHelper helper(node); + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.conv.conv_transpose + std::unique_ptr op = model_builder.CreateOperation(node, "conv_transpose"); + const auto& op_type = op->type(); + + AddOperationInput(*op, "x", input_name); + AddOperationInput(*op, "weight", input_defs[1]->Name()); + + if (input_defs.size() > 2) { + AddOperationInput(*op, "bias", input_defs[2]->Name()); + } + + // we know this input has a valid shape due to the 
check in IsOpSupportedImpl. ignore N and C dims. + const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; + + // Spec says strides/dilations/pads are optional but reality is they're required for at least the iOS15 target + // which is CoreML5. Due to that we just add everything for simplicity. + const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + const auto dilations = helper.Get("dilations", std::vector(num_spatial_dims, 1)); + + AddOperationInput(*op, "strides", model_builder.AddConstant(op_type, "strides", strides)); + AddOperationInput(*op, "dilations", model_builder.AddConstant(op_type, "dilations", dilations)); + + const std::optional groups = helper.GetInt64("group"); + if (groups) { + AddOperationInput(*op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); + } + + // if we can enable output_shape, this code works. see IsOpSupportedImpl for the reason it's disabled. + // const auto output_shape = helper.GetInt64s("output_shape"); + // if (output_shape) { + // AddOperationInput(*op, "output_shape", model_builder.AddConstant(op_type, "output_shape", *output_shape)); + // // these are required despite the spec saying + // AddOperationInput(*op, "pad_type", model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + // std::vector pads(num_spatial_dims * 2, 0); + // AddOperationInput(*op, "pad", model_builder.AddConstant(op_type, "pad", pads)); + //} else { + // AddPadTypeAndPads(*op, model_builder, op_type, helper, num_spatial_dims); + //} + + AddPadTypeAndPads(*op, model_builder, op_type, helper, num_spatial_dims); + + AddOperationOutput(*op, *output_defs[0]); + + model_builder.AddOperation(std::move(op)); +#endif // defined(COREML_ENABLE_MLPROGRAM) + + return Status::OK(); +} + +bool ConvTransposeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (!input_params.create_mlprogram) { + LOGS(logger, VERBOSE) << "ConvTranspose: ML Program required"; + return false; + } + + // ML Program + // - const weight until CoreML7 (iOS17) + // - require constant for now as non-const would be unusual and we rely on the shape of W to be known to validate + // the kernel_shape can be used + // - const bias + // - const pad + // - if auto_pad is same_upper or same_lower the output[i] - (input[i] * strides[i]) must be divisible by 2 + // as the pads must be equally split as there's no upper/lower option in CoreML + // - punting on supporting this for now + // - must be symmetric for CoreML to do the right thing + // - const strides/dilations/groups + // - output_shape CoreML output is inconsistent so disabled for now + // + // NOTE: need to test with/without the COREML_FLAG_USE_CPU_ONLY flag being set to get an idea of how flaky the CoreML + // behaviour is. 
+ // Update /onnxruntime/test/util/default_providers.cc:DefaultCoreMLExecutionProvider to do so + + const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + // requires the rank at least to be known + LOGS(logger, VERBOSE) << "ConvTranspose: failed to get input shape"; + return false; + } + + // for simplicity require weight to be constant + const auto& weight_arg = *input_defs[1]; + const auto& weight_name = input_defs[1]->Name(); + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name); + if (!weight) { + LOGS(logger, VERBOSE) << "ConvTranspose: weight must be constant"; + return false; + } + + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "ConvTranspose: bias must be constant"; + return false; + } + + std::vector weight_shape; + if (!GetShape(weight_arg, weight_shape, logger)) { + // impossible as it's a constant initializer + LOGS(logger, VERBOSE) << "ConvTranspose: failed to get weight shape"; + return false; + } + + int64_t num_spatial_dims = narrow(weight_shape.size()) - 2; + + NodeAttrHelper helper(node); + + // Punt on SAME_UPPER/SAME_LOWER for now. + // We could infer that 'same' -> 'same_upper' based on the CoreML conv spec having 'same' and 'same_lower' but + // need to validate that assertion. + // Additionally, if the pads size is equal, there's no difference between same_upper and same_lower. + // To do that we'd need the 'output_shape' attribute to check against. + // Can add this handling if/when needed. + auto autopad = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + if (autopad == AutoPadType::SAME_LOWER || autopad == AutoPadType::SAME_UPPER) { + LOGS(logger, VERBOSE) << "ConvTranspose: support for SAME_LOWER/SAME_UPPER is not implemented yet"; + return false; + } else if (autopad == AutoPadType::NOTSET) { + // CoreML output is inconsistent if pads are asymmetric. + // CPU works. Other devices don't seem to (at least on macOS). + auto onnx_pads = *helper.GetInt64s("pads"); // 'pads' are requred if auto_pad is NOTSET + const auto pad_value = onnx_pads[0]; + if (!std::all_of(onnx_pads.begin() + 1, onnx_pads.end(), + [pad_value](auto value) { return value == pad_value; })) { + LOGS(logger, VERBOSE) << "ConvTranspose: pads must be symmetric for CoreML to return consistent results"; + return false; + } + } + + // there's no input to specify a kernel shape in CoreML. + // it's OK if a specified kernel_shape matches kH and kW dims of the weight input. + auto kernel_shape = helper.GetInt64s("kernel_shape"); + if (kernel_shape) { + bool valid = true; + + if (static_cast(kernel_shape->size()) == num_spatial_dims) { + for (int i = 0; i < num_spatial_dims; ++i) { + // check the specified kernel shape matches the weight shape. skip the initial N and C dims in the latter. + if ((*kernel_shape)[i] != weight_shape[i + 2]) { + valid = false; + break; + } + } + } else { + valid = false; + } + + if (!valid) { + LOGS(logger, VERBOSE) << "ConvTranspose: kernel_shape attribute does not match the weight shape"; + return false; + } + } + + // In theory this can be supported, but running with COREML_FLAG_USE_CPU_ONLY produces output that doesn't match + // ONNX. Running without that flag produces the expected output. Madness... 
+ auto output_shape = helper.GetInt64s("output_shape"); + if (output_shape) { + // there is an output_shape input, but the padding seems to be different so results don't + LOGS(logger, VERBOSE) << "ConvTranspose: output_shape is not supported as the CoreML output is inconsistent"; + return false; + } + + // output_padding, if specified, must be the default value of all zeros as there's no equivalent in CoreML. + auto output_padding = helper.GetInt64s("output_padding"); + if (output_padding && + std::any_of(output_padding->begin(), output_padding->end(), [](auto value) { return value != 0; })) { + LOGS(logger, VERBOSE) << "ConvTranspose: output_padding is not supported"; + return false; + } + + return true; +} + +void CreateConvTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc index 2c06659852134..a2725c227e78c 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.cc @@ -15,120 +15,56 @@ namespace coreml { static OpBuilderRegistrations CreateOpBuilderRegistrations() { OpBuilderRegistrations op_registrations; - { // Add/Mul/Pow/Sub/Div - CreateBinaryOpBuilder("Add", op_registrations); - CreateBinaryOpBuilder("Mul", op_registrations); - CreateBinaryOpBuilder("Pow", op_registrations); - CreateBinaryOpBuilder("Sub", op_registrations); - CreateBinaryOpBuilder("Div", op_registrations); - } - - { // Activations - CreateActivationOpBuilder("Sigmoid", op_registrations); - CreateActivationOpBuilder("Tanh", op_registrations); - CreateActivationOpBuilder("Relu", op_registrations); - CreateActivationOpBuilder("PRelu", op_registrations); - CreateActivationOpBuilder("LeakyRelu", op_registrations); - } - - { // Transpose - CreateTransposeOpBuilder("Transpose", op_registrations); - } - - { // Conv - CreateConvOpBuilder("Conv", op_registrations); - } - - { // Batch Normalization - CreateBatchNormalizationOpBuilder("BatchNormalization", op_registrations); - } - - { // Reshape - CreateReshapeOpBuilder("Reshape", op_registrations); - } - - { // DepthToSpace - CreateDepthToSpaceOpBuilder("DepthToSpace", op_registrations); - } - - { // Pool - CreatePoolOpBuilder("GlobalAveragePool", op_registrations); - CreatePoolOpBuilder("GlobalMaxPool", op_registrations); - CreatePoolOpBuilder("AveragePool", op_registrations); - CreatePoolOpBuilder("MaxPool", op_registrations); - } - - { // Concat - CreateConcatOpBuilder("Concat", op_registrations); - } - - { // Resize - CreateResizeOpBuilder("Resize", op_registrations); - } - - { // Gemm/MatMul - CreateGemmOpBuilder("Gemm", op_registrations); - CreateGemmOpBuilder("MatMul", op_registrations); - } - - { // Clip - CreateClipOpBuilder("Clip", op_registrations); - } - - { // Squeeze - CreateSqueezeOpBuilder("Squeeze", op_registrations); - } - - { // ArgMax - CreateArgMaxOpBuilder("ArgMax", op_registrations); - } - - { // Cast - CreateCastOpBuilder("Cast", op_registrations); - } - - { // Flatten - CreateFlattenOpBuilder("Flatten", op_registrations); - } - - { // LRN - CreateLRNOpBuilder("LRN", op_registrations); - } - - { // Pad - CreatePadOpBuilder("Pad", op_registrations); - } - - { // Unary - 
CreateUnaryOpBuilder("Sqrt", op_registrations); - CreateUnaryOpBuilder("Reciprocal", op_registrations); - } - - { // Reduction - // ReduceMean is used in layer normalization which seems to be problematic in Python tests. - CreateReductionOpBuilder("ReduceMean", op_registrations); - CreateReductionOpBuilder("ReduceSum", op_registrations); - } - - { // Shape - CreateShapeOpBuilder("Shape", op_registrations); - } - - { // Gather - CreateGatherOpBuilder("Gather", op_registrations); - } - - { // Slice - CreateSliceOpBuilder("Slice", op_registrations); - } - - { // Softmax - CreateSoftmaxOpBuilder("Softmax", op_registrations); - } - - { // Split - CreateSplitOpBuilder("Split", op_registrations); - } + // Unary ops + CreateUnaryOpBuilder("Sqrt", op_registrations); + CreateUnaryOpBuilder("Reciprocal", op_registrations); + + // Binary elementwise ops + CreateBinaryOpBuilder("Add", op_registrations); + CreateBinaryOpBuilder("Mul", op_registrations); + CreateBinaryOpBuilder("Pow", op_registrations); + CreateBinaryOpBuilder("Sub", op_registrations); + CreateBinaryOpBuilder("Div", op_registrations); + + // Activations + CreateActivationOpBuilder("Sigmoid", op_registrations); + CreateActivationOpBuilder("Tanh", op_registrations); + CreateActivationOpBuilder("Relu", op_registrations); + CreateActivationOpBuilder("PRelu", op_registrations); + CreateActivationOpBuilder("LeakyRelu", op_registrations); + + // Pooling ops + CreatePoolOpBuilder("GlobalAveragePool", op_registrations); + CreatePoolOpBuilder("GlobalMaxPool", op_registrations); + CreatePoolOpBuilder("AveragePool", op_registrations); + CreatePoolOpBuilder("MaxPool", op_registrations); + + // Reduction ops + CreateReductionOpBuilder("ReduceMean", op_registrations); + CreateReductionOpBuilder("ReduceSum", op_registrations); + + CreateArgMaxOpBuilder("ArgMax", op_registrations); + CreateBatchNormalizationOpBuilder("BatchNormalization", op_registrations); + CreateCastOpBuilder("Cast", op_registrations); + CreateClipOpBuilder("Clip", op_registrations); + CreateConcatOpBuilder("Concat", op_registrations); + CreateConvOpBuilder("Conv", op_registrations); + CreateConvTransposeOpBuilder("ConvTranspose", op_registrations); + CreateDepthToSpaceOpBuilder("DepthToSpace", op_registrations); + CreateFlattenOpBuilder("Flatten", op_registrations); + CreateGatherOpBuilder("Gather", op_registrations); + CreateGemmOpBuilder("Gemm", op_registrations); + CreateLRNOpBuilder("LRN", op_registrations); + CreateGemmOpBuilder("MatMul", op_registrations); + CreatePadOpBuilder("Pad", op_registrations); + CreateReshapeOpBuilder("Reshape", op_registrations); + CreateResizeOpBuilder("Resize", op_registrations); + CreateShapeOpBuilder("Shape", op_registrations); + CreateSliceOpBuilder("Slice", op_registrations); + CreateSplitOpBuilder("Split", op_registrations); + CreateSoftmaxOpBuilder("Softmax", op_registrations); + CreateSqueezeOpBuilder("Squeeze", op_registrations); + CreateTransposeOpBuilder("Transpose", op_registrations); return op_registrations; } diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index 6469b4cefa5ea..6ad590f2c73e3 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -24,6 +24,7 @@ void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void 
CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConvTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 1d506099b4367..e29448f24a6a3 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -26,6 +26,13 @@ #include "core/providers/coreml/model/objc_str_utils.h" #include "core/providers/coreml/shape_utils.h" +// manually enable to test logic for handling non-contiguous MLMultiArray as we don't have a unit test setup +// that can hit that. +// #define TEST_MLMULTIARRAY_HANDLING +#ifdef TEST_MLMULTIARRAY_HANDLING +#include +#endif + // force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need // to manually do this asm(".linker_option \"-framework\", \"CoreML\""); @@ -174,51 +181,195 @@ Status CreateInputFeatureProvider(const std::unordered_map= *block_size, "Logic error calculating copy info"); + ORT_ENFORCE(*stride * *num_blocks == total_elems, "Logic error calculating copy info"); + + return Status::OK(); } +#ifdef TEST_MLMULTIARRAY_HANDLING +void ValidateGetInfo(MLMultiArray* array, + int64_t expected_num_blocks, int64_t expected_block_size, int64_t expected_stride, bool valid) { + + int64_t num_blocks = 0; + int64_t block_size = 0; + int64_t stride = 0; + auto status = GetMLMultiArrayCopyInfo(array, &num_blocks, &block_size, &stride); + + if (!valid) { + assert(!status.IsOK()); + return; + } + + assert(status.IsOK()); + assert(num_blocks == expected_num_blocks); + assert(block_size == expected_block_size); + assert(stride == expected_stride); +} + +void ValidateMLMultiArrayHandling() { + void* data = reinterpret_cast(0xfeedf00d); + + // dim -1 with stride + { + NSArray *shape = @[@1, @1, @8, @8]; + NSArray *strides = @[@128, @128, @16, @2]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) {} + error:nil]; + ValidateGetInfo(array, 64, 1, 2, true); + } + + // dim -2 with stride + { + NSArray *shape = @[@1, @1, @8, @8]; + NSArray *strides = @[@128, @128, @16, @1]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) {} + error:nil]; + ValidateGetInfo(array, 8, 8, 16, true); + } + + // dim -3 with stride + { + NSArray *shape = @[@1, @2, @4, @4]; + NSArray *strides = @[@48, @24, @4, @1]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) {} + error:nil]; + + ValidateGetInfo(array, 2, 16, 24, true); + } + + // two non-contiguous dims + { + // dim + NSArray *shape = @[@1, @2, @4, @4]; + NSArray *strides = @[@96, @48, @8, @1]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + 
strides:strides + deallocator:^(void* /* bytes */) {} + error:nil]; + + ValidateGetInfo(array, 0, 0, 0, false); + } +} +#endif + Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buffer, - const MLMultiArray* array_info, - const OnnxTensorInfo* tensor_info, - const std::optional mlmultiarray_buffer_size) { + const MLMultiArray* array, + const int64_t num_blocks, const int64_t block_size, const int64_t stride, + const OnnxTensorInfo* tensor_info) { if (mlmultiarray_buffer == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "mlmultiarray_buffer has no data"); } - const size_t num_elements = array_info.count; + // total including non-contiguous space + + int64_t array_total_elements = [array.strides[0] longLongValue] * [array.shape[0] longLongValue]; + const int64_t num_elements = array.count; + + ORT_RETURN_IF(array_total_elements != num_blocks * stride || + num_elements != num_blocks * block_size, + "MLMultiArray size does not match the copy info"); + const auto onnx_data_type = tensor_info->data_type; switch (onnx_data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - const auto output_data_byte_size = num_elements * sizeof(float); - ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, - "CoreML output buffer size and expected output size differ"); - memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + + const auto *src_buffer = static_cast(mlmultiarray_buffer); + auto *dst_buffer = static_cast(tensor_buffer); + const auto block_byte_size = block_size * sizeof(float); + + for (int64_t idx = 0; idx < num_blocks; ++idx) { + memcpy(dst_buffer, src_buffer, block_byte_size); + src_buffer += stride; + dst_buffer += block_size; + } break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto output_data_byte_size = num_elements * sizeof(int32_t); - ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == output_data_byte_size, - "CoreML output buffer size and expected output size differ"); - memcpy(tensor_buffer, mlmultiarray_buffer, output_data_byte_size); + const auto *src_buffer = static_cast(mlmultiarray_buffer); + auto *dst_buffer = static_cast(tensor_buffer); + const auto block_byte_size = block_size * sizeof(int32_t); + + for (int64_t idx = 0; idx < num_blocks; ++idx) { + memcpy(dst_buffer, src_buffer, block_byte_size); + src_buffer += stride; + dst_buffer += block_size; + } + break; } // For this case, since Coreml Spec only uses int32 for model output while onnx provides // int64 for model output data type. 
We are doing a type casting (int32 -> int64) here // when copying the model to ORT case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_NOT(array_info.dataType == MLMultiArrayDataTypeInt32, - "CoreML output data type is not MLMultiArrayDataTypeInt32"); - ORT_RETURN_IF_NOT(!mlmultiarray_buffer_size || mlmultiarray_buffer_size == num_elements * sizeof(int32_t), - "CoreML output buffer size and expected output size differ"); - const auto model_output_span = gsl::span{static_cast(mlmultiarray_buffer), num_elements}; - const auto output_span = gsl::span{static_cast(tensor_buffer), num_elements}; - std::transform(model_output_span.begin(), model_output_span.end(), output_span.begin(), - [](int32_t v) { return static_cast(v); }); + ORT_RETURN_IF(array.dataType != MLMultiArrayDataTypeInt32, + "CoreML output data type is not MLMultiArrayDataTypeInt32"); + + const int32_t *src_buffer = static_cast(mlmultiarray_buffer); + int64_t *dst_buffer = static_cast(tensor_buffer); + + for (int64_t idx = 0; idx < num_blocks; ++idx) { + auto input_span = gsl::span{src_buffer, static_cast(block_size)}; + auto output_span = gsl::span{dst_buffer, static_cast(block_size)}; + std::transform(input_span.begin(), input_span.end(), output_span.begin(), + [](int32_t v) { return static_cast(v); }); + + src_buffer += stride; + dst_buffer += block_size; + } break; } default: @@ -250,8 +401,7 @@ - (void)dealloc; - (Status)loadModel API_AVAILABLE_COREML3; - (Status)predict:(const std::unordered_map&)inputs outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&) - get_output_tensor_mutable_raw_data_fn + getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn API_AVAILABLE_COREML3; @property(nullable) MLModel* model API_AVAILABLE_COREML3; @@ -397,21 +547,27 @@ - (Status)predict:(const std::unordered_map&)inputs ") do not match"); } - ORT_RETURN_IF_NOT(IsArrayContiguous(data), - "Non-contiguous output MLMultiArray is not currently supported"); + // support a non-contiguous array, provided only one dimension is not contiguous + int64_t num_blocks = 0; + int64_t block_size = 0; + int64_t stride = 0; + + ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, &num_blocks, &block_size, &stride)); + __block Status copy_status; const auto* tensor_info = &output_tensor_info; // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions if (@available(macOS 12.3, iOS 15.4, *)) { [data getBytesWithHandler:^(const void* bytes, NSInteger size) { - copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, tensor_info, size); + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, + num_blocks, block_size, stride, tensor_info); }]; } else { - // disable size check as old API does not return buffer length - copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, tensor_info, std::nullopt); + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, + num_blocks, block_size, stride, tensor_info); } - if (!copy_status.IsOK()) - return copy_status; + + ORT_RETURN_IF_ERROR(copy_status); } } } @@ -508,6 +664,11 @@ Status Predict(const std::unordered_map& inputs, Model::~Model() {} Status Model::LoadModel() { + // arbitrary place to run this when manually enabled for temporary testing +#ifdef TEST_MLMULTIARRAY_HANDLING + ValidateMLMultiArrayHandling(); +#endif + return execution_->LoadModel(); } diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc 
b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc index c136385f12476..01c8119fea79d 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_transpose.cc @@ -24,7 +24,7 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr const auto rank = orig_shape.NumDimensions(); if (conv_transpose_attrs_.group > 1) { - // Xnnpack [G, Oc, H, W Ic/G] + // Xnnpack [G, Oc, H, W, Ic/G] // (ref: https://github.com/google/XNNPACK/blob/ecd8311c8fd3d9ab47edbc3df5f2b5de7dabe75f/test/deconvolution-operator-tester.h#L678) if (rank == 4) { // split C (dim 0) into {group, C/group} diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 81191e9b48c3c..2bf53ce5b5986 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -27,7 +27,7 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, const vector>& input_shapes, const std::initializer_list& expected_output, const vector& expected_output_shape, - bool is_filter_initializer = false, + bool is_weight_and_bias_initializer = false, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider}) { @@ -58,10 +58,10 @@ void TestConvTransposeOpInitializer(const ConvTransposeOpAttributes& attributes, } ORT_ENFORCE(inputs.size() <= 3, "Our name array is only setup to handle 3 inputs"); - const char* szNames[] = {"X", "W", "B"}; - bool isInitializers[] = {false, is_filter_initializer, false}; + const char* input_names[] = {"X", "W", "B"}; + bool is_initializers[] = {false, is_weight_and_bias_initializer, is_weight_and_bias_initializer}; for (size_t i = 0; i < inputs.size(); i++) { - test.AddInput(szNames[i], input_shapes[i], inputs[i], isInitializers[i]); + test.AddInput(input_names[i], input_shapes[i], inputs[i], is_initializers[i]); } test.AddOutput("Y", expected_output_shape, expected_output); diff --git a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md index 1bbb933f66ba4..60a2e2ec929b6 100644 --- a/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md +++ b/tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md @@ -7,6 +7,7 @@ Keep in sync with doco generated from /docs/execution-providers/CoreML-Execution |ai.onnx:AveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| |ai.onnx:Clip|| |ai.onnx:Conv|Only 1D/2D Conv is supported.
Bias if provided must be constant.|
+|ai.onnx:ConvTranspose|Weight and bias must be constant.<br/>padding_type of SAME_UPPER/SAME_LOWER is not supported.<br/>kernel_shape must have default values.<br/>output_shape is not supported.<br/>
output_padding must have default values.| |ai.onnx:Div|| |ai.onnx:Gemm|Input B must be constant.| |ai.onnx:GlobalAveragePool|Only 2D Pool is supported currently. 3D and 5D support can be added if needed.| From fcb90c0aeefc5296deac7b35d0aeca3c4f497196 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 19 Jul 2024 22:54:10 +1000 Subject: [PATCH 2/9] Fix build --- .../providers/coreml/builders/impl/convtranspose_op_builder.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc index c3f143a362b95..e571fca23c5fb 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc @@ -25,7 +25,8 @@ class ConvTransposeOpBuilder : public BaseOpBuilder { bool SupportsMLProgram() const override { return true; } }; -Status ConvTransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, +Status ConvTransposeOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder, + [[maybe_unused]] const Node& node, const logging::Logger& /*logger*/) const { #if defined(COREML_ENABLE_MLPROGRAM) using namespace CoreML::Specification::MILSpec; // NOLINT From 9d8049759c1f5c917220d333b3cf58962c445d87 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 09:54:08 +1000 Subject: [PATCH 3/9] Move test of MLMultiArray handling to a unit test. --- .../core/providers/coreml/model/model.h | 11 + .../core/providers/coreml/model/model.mm | 205 +++++------------- .../test/providers/coreml/utils_test.mm | 110 ++++++++++ 3 files changed, 175 insertions(+), 151 deletions(-) create mode 100644 onnxruntime/test/providers/coreml/utils_test.mm diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index c4c3b38bba516..57d4b23eeb234 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -13,6 +13,13 @@ #include "core/common/status.h" #include "core/platform/ort_mutex.h" +#if defined(__APPLE__) +#ifdef __OBJC__ +@class MLMultiArray; +#else +typedef struct objc_object MLMultiArray; +#endif + namespace onnxruntime { namespace coreml { @@ -32,6 +39,10 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; +// helper function that we unit test +Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, int64_t* block_size, int64_t* stride); +#endif + class Model { public: Model(const std::string& path, diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index e29448f24a6a3..33c5545cbb920 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -26,13 +26,6 @@ #include "core/providers/coreml/model/objc_str_utils.h" #include "core/providers/coreml/shape_utils.h" -// manually enable to test logic for handling non-contiguous MLMultiArray as we don't have a unit test setup -// that can hit that. 
-// #define TEST_MLMULTIARRAY_HANDLING -#ifdef TEST_MLMULTIARRAY_HANDLING -#include -#endif - // force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need // to manually do this asm(".linker_option \"-framework\", \"CoreML\""); @@ -181,131 +174,6 @@ Status CreateInputFeatureProvider(const std::unordered_map= *block_size, "Logic error calculating copy info"); - ORT_ENFORCE(*stride * *num_blocks == total_elems, "Logic error calculating copy info"); - - return Status::OK(); -} - -#ifdef TEST_MLMULTIARRAY_HANDLING -void ValidateGetInfo(MLMultiArray* array, - int64_t expected_num_blocks, int64_t expected_block_size, int64_t expected_stride, bool valid) { - - int64_t num_blocks = 0; - int64_t block_size = 0; - int64_t stride = 0; - auto status = GetMLMultiArrayCopyInfo(array, &num_blocks, &block_size, &stride); - - if (!valid) { - assert(!status.IsOK()); - return; - } - - assert(status.IsOK()); - assert(num_blocks == expected_num_blocks); - assert(block_size == expected_block_size); - assert(stride == expected_stride); -} - -void ValidateMLMultiArrayHandling() { - void* data = reinterpret_cast(0xfeedf00d); - - // dim -1 with stride - { - NSArray *shape = @[@1, @1, @8, @8]; - NSArray *strides = @[@128, @128, @16, @2]; - - auto* array = [[MLMultiArray alloc] initWithDataPointer:data - shape:shape - dataType:MLMultiArrayDataTypeInt32 - strides:strides - deallocator:^(void* /* bytes */) {} - error:nil]; - ValidateGetInfo(array, 64, 1, 2, true); - } - - // dim -2 with stride - { - NSArray *shape = @[@1, @1, @8, @8]; - NSArray *strides = @[@128, @128, @16, @1]; - - auto* array = [[MLMultiArray alloc] initWithDataPointer:data - shape:shape - dataType:MLMultiArrayDataTypeInt32 - strides:strides - deallocator:^(void* /* bytes */) {} - error:nil]; - ValidateGetInfo(array, 8, 8, 16, true); - } - - // dim -3 with stride - { - NSArray *shape = @[@1, @2, @4, @4]; - NSArray *strides = @[@48, @24, @4, @1]; - - auto* array = [[MLMultiArray alloc] initWithDataPointer:data - shape:shape - dataType:MLMultiArrayDataTypeInt32 - strides:strides - deallocator:^(void* /* bytes */) {} - error:nil]; - - ValidateGetInfo(array, 2, 16, 24, true); - } - - // two non-contiguous dims - { - // dim - NSArray *shape = @[@1, @2, @4, @4]; - NSArray *strides = @[@96, @48, @8, @1]; - - auto* array = [[MLMultiArray alloc] initWithDataPointer:data - shape:shape - dataType:MLMultiArrayDataTypeInt32 - strides:strides - deallocator:^(void* /* bytes */) {} - error:nil]; - - ValidateGetInfo(array, 0, 0, 0, false); - } -} -#endif - Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buffer, const MLMultiArray* array, const int64_t num_blocks, const int64_t block_size, const int64_t stride, @@ -320,33 +188,32 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff const int64_t num_elements = array.count; ORT_RETURN_IF(array_total_elements != num_blocks * stride || - num_elements != num_blocks * block_size, + num_elements != num_blocks * block_size, "MLMultiArray size does not match the copy info"); const auto onnx_data_type = tensor_info->data_type; switch (onnx_data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - - const auto *src_buffer = static_cast(mlmultiarray_buffer); - auto *dst_buffer = static_cast(tensor_buffer); + const auto* src_buffer = static_cast(mlmultiarray_buffer); + auto* dst_buffer = static_cast(tensor_buffer); const auto block_byte_size = block_size * sizeof(float); for (int64_t idx = 0; idx < num_blocks; ++idx) { - 
memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; + memcpy(dst_buffer, src_buffer, block_byte_size); + src_buffer += stride; + dst_buffer += block_size; } break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - const auto *src_buffer = static_cast(mlmultiarray_buffer); - auto *dst_buffer = static_cast(tensor_buffer); + const auto* src_buffer = static_cast(mlmultiarray_buffer); + auto* dst_buffer = static_cast(tensor_buffer); const auto block_byte_size = block_size * sizeof(int32_t); for (int64_t idx = 0; idx < num_blocks; ++idx) { - memcpy(dst_buffer, src_buffer, block_byte_size); - src_buffer += stride; - dst_buffer += block_size; + memcpy(dst_buffer, src_buffer, block_byte_size); + src_buffer += stride; + dst_buffer += block_size; } break; @@ -358,8 +225,8 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff ORT_RETURN_IF(array.dataType != MLMultiArrayDataTypeInt32, "CoreML output data type is not MLMultiArrayDataTypeInt32"); - const int32_t *src_buffer = static_cast(mlmultiarray_buffer); - int64_t *dst_buffer = static_cast(tensor_buffer); + const int32_t* src_buffer = static_cast(mlmultiarray_buffer); + int64_t* dst_buffer = static_cast(tensor_buffer); for (int64_t idx = 0; idx < num_blocks; ++idx) { auto input_span = gsl::span{src_buffer, static_cast(block_size)}; @@ -587,6 +454,47 @@ - (Status)predict:(const std::unordered_map&)inputs namespace onnxruntime { namespace coreml { +Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, int64_t* block_size, int64_t* stride) { + const auto* shape = array.shape; + const auto rank = shape.count; + + int64_t array_total_elements = [array.strides[0] longLongValue] * [shape[0] longLongValue]; + + int64_t data_elems = 1; // actual values + int64_t total_elems = 1; // elems including empty slots if non-contiguous + for (unsigned long i = 1; i <= rank; i++) { + int64_t this_stride = [array.strides[rank - i] longLongValue]; + if (this_stride != total_elems) { + // non-contigous if we have to move more than batch_elems for each entry + if (*block_size != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Multiple non-contiguous dimensions in MLMultiArray are not supported."); + } + + *block_size = data_elems; + *stride = this_stride; + } + + const auto elems_this_dim = [shape[rank - i] longLongValue]; + data_elems *= elems_this_dim; + total_elems = elems_this_dim * this_stride; + } + + if (*block_size == 0) { + // contiguous + *block_size = data_elems; + *stride = array_total_elements; + } + + *num_blocks = data_elems / *block_size; + + ORT_ENFORCE(array_total_elements == total_elems, "Logic error calculating copy info"); + ORT_ENFORCE(*stride >= *block_size, "Logic error calculating copy info"); + ORT_ENFORCE(*stride * *num_blocks == total_elems, "Logic error calculating copy info"); + + return Status::OK(); +} + // Internal Execution class // This class will bridge Model (c++) with CoreMLExecution (objective c++) class Execution { @@ -664,11 +572,6 @@ Status Predict(const std::unordered_map& inputs, Model::~Model() {} Status Model::LoadModel() { - // arbitrary place to run this when manually enabled for temporary testing -#ifdef TEST_MLMULTIARRAY_HANDLING - ValidateMLMultiArrayHandling(); -#endif - return execution_->LoadModel(); } diff --git a/onnxruntime/test/providers/coreml/utils_test.mm b/onnxruntime/test/providers/coreml/utils_test.mm new file mode 100644 index 0000000000000..2ea3d229c0d8b --- /dev/null +++ 
b/onnxruntime/test/providers/coreml/utils_test.mm @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(__APPLE__) +#import + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/providers/coreml/model/model.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { +namespace { +auto ValidateGetInfo(MLMultiArray* array, + int64_t expected_num_blocks, int64_t expected_block_size, int64_t expected_stride, + bool expect_valid) { + int64_t num_blocks = 0; + int64_t block_size = 0; + int64_t stride = 0; + auto status = coreml::GetMLMultiArrayCopyInfo(array, &num_blocks, &block_size, &stride); + + if (!expect_valid) { + ASSERT_STATUS_NOT_OK(status); + return; + } + + ASSERT_STATUS_OK(status); + ASSERT_EQ(num_blocks, expected_num_blocks); + ASSERT_EQ(block_size, expected_block_size); + ASSERT_EQ(stride, expected_stride); +} +} // namespace + +TEST(CoreMLUtils, GetMLMultiArrayReadInfo) { + // fake pointer. we don't read any data but initWithDataPointer requires a non-null address + void* data = reinterpret_cast(0xfeedf00d); + + // a dim is non-contiguous if the stride is > the total number of elements in its inner dimensions + + // dim -1 with non-contiguous data. 1 element (as it's the inner-most dimension) but the stride is 2. + { + NSArray* shape = @[ @1, @1, @8, @8 ]; + NSArray* strides = @[ @128, @128, @16, @2 ]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) { + } + error:nil]; + ValidateGetInfo(array, 64, 1, 2, true); + } + + // dim -2 with non-contiguous data. 8 elements in the inner dimension but the stride is 16. + { + NSArray* shape = @[ @1, @1, @8, @8 ]; + NSArray* strides = @[ @128, @128, @16, @1 ]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) { + } + error:nil]; + ValidateGetInfo(array, 8, 8, 16, true); + } + + // dim -3 with non-contiguous data. 16 elements in the innder dimensions but stride is 24. 
+ { + NSArray* shape = @[ @1, @2, @4, @4 ]; + NSArray* strides = @[ @48, @24, @4, @1 ]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) { + } + error:nil]; + + ValidateGetInfo(array, 2, 16, 24, true); + } + + // two non-contiguous dims (dim -2 and dim -3) + // dim -2 has 4 elements in the inner dimension and stride of 8 + // dim -3 has 32 elements in the inner dimensions (we need to include the empty elements from the non-contiguous data + // in dim -2) and stride of 48 + { + // dim + NSArray* shape = @[ @1, @2, @4, @4 ]; + NSArray* strides = @[ @96, @48, @8, @1 ]; + + auto* array = [[MLMultiArray alloc] initWithDataPointer:data + shape:shape + dataType:MLMultiArrayDataTypeInt32 + strides:strides + deallocator:^(void* /* bytes */) { + } + error:nil]; + + ValidateGetInfo(array, 0, 0, 0, false); + } +} +} // namespace test +} // namespace onnxruntime +#endif From ca70df252c21a9522d46833483a1a7597bb02f1d Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 10:06:45 +1000 Subject: [PATCH 4/9] Add naming update from Resize changes --- .../core/providers/coreml/builders/impl/resize_op_builder.cc | 4 ++-- .../nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc | 4 ++-- onnxruntime/core/providers/utils.cc | 2 +- onnxruntime/core/providers/utils.h | 2 +- onnxruntime/core/providers/xnnpack/tensor/resize.cc | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 65b5c17f2c6a6..7ff66e4a79e37 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -427,13 +427,13 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa auto h_in = input_shape[input_rank - 2]; auto w_in = input_shape[input_rank - 1]; - if (!utils::IsScalingByAFactorOfN(h_in, scale_h)) { + if (!utils::ReciprocalIsAFactorOfN(h_in, scale_h)) { LOGS(logger, VERBOSE) << "Resize: downsampling scale " << scale_h << " is not a factor of input height: " << h_in; return false; } - if (!utils::IsScalingByAFactorOfN(w_in, scale_w)) { + if (!utils::ReciprocalIsAFactorOfN(w_in, scale_w)) { LOGS(logger, VERBOSE) << "Resize: downsampling scale " << scale_w << " is not a factor of input width: " << w_in; return false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index ef27f6c942f44..44403010c936c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -274,8 +274,8 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const N return false; } - if (!utils::IsScalingByAFactorOfN(h_in, scale_h) || - !utils::IsScalingByAFactorOfN(w_in, scale_w)) { + if (!utils::ReciprocalIsAFactorOfN(h_in, scale_h) || + !utils::ReciprocalIsAFactorOfN(w_in, scale_w)) { LOGS_DEFAULT(VERBOSE) << "Input size must be evenly divisible by output size when downsampling"; return false; } diff --git a/onnxruntime/core/providers/utils.cc b/onnxruntime/core/providers/utils.cc index 747b09e42aa21..2725af95e0959 100644 --- a/onnxruntime/core/providers/utils.cc +++ 
b/onnxruntime/core/providers/utils.cc @@ -24,7 +24,7 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& } #endif -bool IsScalingByAFactorOfN(int64_t n, float scale) { +bool ReciprocalIsAFactorOfN(int64_t n, float scale) { bool is_factor = false; if (scale > 0.f && scale < 1.f) { const double factor = 1.0 / scale; diff --git a/onnxruntime/core/providers/utils.h b/onnxruntime/core/providers/utils.h index 9ea8496a02f85..cfd71d9b838b3 100644 --- a/onnxruntime/core/providers/utils.h +++ b/onnxruntime/core/providers/utils.h @@ -19,6 +19,6 @@ common::Status OutputOptionalWithoutDataHelper(const ONNX_NAMESPACE::TypeProto& /// Check if the reciprocal of 'scale' is a factor of 'n'. /// e.g. a scale of 0.5 is 1/2, the reciprocal is 2, and 2 is a factor of any even number. /// -bool IsScalingByAFactorOfN(int64_t n, float scale); +bool ReciprocalIsAFactorOfN(int64_t n, float scale); } // namespace utils } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index c752b5f849808..cf874796ba169 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -85,8 +85,8 @@ bool Resize::IsOnnxNodeSupported(const NodeUnit& node_unit, float scale_h = scales[2]; float scale_w = scales[3]; - if (!utils::IsScalingByAFactorOfN(h_in, scale_h) || - !utils::IsScalingByAFactorOfN(w_in, scale_w)) { + if (!utils::ReciprocalIsAFactorOfN(h_in, scale_h) || + !utils::ReciprocalIsAFactorOfN(w_in, scale_w)) { break; } } From 3bd6db30e2e71dc65f1e31806128d40b85fab990 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 11:41:50 +1000 Subject: [PATCH 5/9] fix ifdefs --- onnxruntime/core/providers/coreml/model/model.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 57d4b23eeb234..5bce7cf996e4e 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -19,6 +19,7 @@ #else typedef struct objc_object MLMultiArray; #endif +#endif namespace onnxruntime { namespace coreml { @@ -39,6 +40,7 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; +#if defined(__APPLE__) // helper function that we unit test Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, int64_t* block_size, int64_t* stride); #endif From 99211879a95029adb3025bbc0b46d8dbc7cd7a67 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 12:29:51 +1000 Subject: [PATCH 6/9] Apply suggestions from code review Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- .../coreml/builders/impl/convtranspose_op_builder.cc | 4 ++-- onnxruntime/core/providers/coreml/model/model.mm | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc index e571fca23c5fb..db33240b8a7cc 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc @@ -67,7 +67,7 @@ Status ConvTransposeOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuild // const auto output_shape = helper.GetInt64s("output_shape"); // if (output_shape) { // AddOperationInput(*op, "output_shape", 
model_builder.AddConstant(op_type, "output_shape", *output_shape)); - // // these are required despite the spec saying + // // these are required despite the spec saying otherwise // AddOperationInput(*op, "pad_type", model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); // std::vector pads(num_spatial_dims * 2, 0); // AddOperationInput(*op, "pad", model_builder.AddConstant(op_type, "pad", pads)); @@ -156,7 +156,7 @@ bool ConvTransposeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilder } else if (autopad == AutoPadType::NOTSET) { // CoreML output is inconsistent if pads are asymmetric. // CPU works. Other devices don't seem to (at least on macOS). - auto onnx_pads = *helper.GetInt64s("pads"); // 'pads' are requred if auto_pad is NOTSET + auto onnx_pads = *helper.GetInt64s("pads"); // 'pads' are required if auto_pad is NOTSET const auto pad_value = onnx_pads[0]; if (!std::all_of(onnx_pads.begin() + 1, onnx_pads.end(), [pad_value](auto value) { return value == pad_value; })) { diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 33c5545cbb920..194b5225d7f4f 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -465,7 +465,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, i for (unsigned long i = 1; i <= rank; i++) { int64_t this_stride = [array.strides[rank - i] longLongValue]; if (this_stride != total_elems) { - // non-contigous if we have to move more than batch_elems for each entry + // non-contiguous if we have to move more than batch_elems for each entry if (*block_size != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Multiple non-contiguous dimensions in MLMultiArray are not supported."); From d866225bc3dd4d43296f6e76bbb6b26a0c5734b8 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 12:42:39 +1000 Subject: [PATCH 7/9] Address PR comments --- .../builders/impl/convtranspose_op_builder.cc | 10 ++++---- .../core/providers/coreml/model/model.h | 2 +- .../core/providers/coreml/model/model.mm | 23 ++++++++++--------- .../test/providers/coreml/utils_test.mm | 2 +- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc index db33240b8a7cc..eca996597f76d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc @@ -106,7 +106,7 @@ bool ConvTransposeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilder // - output_shape CoreML output is inconsistent so disabled for now // // NOTE: need to test with/without the COREML_FLAG_USE_CPU_ONLY flag being set to get an idea of how flaky the CoreML - // behaviour is. + // behavior is. // Update /onnxruntime/test/util/default_providers.cc:DefaultCoreMLExecutionProvider to do so const auto& input_defs = node.InputDefs(); @@ -154,13 +154,14 @@ bool ConvTransposeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilder LOGS(logger, VERBOSE) << "ConvTranspose: support for SAME_LOWER/SAME_UPPER is not implemented yet"; return false; } else if (autopad == AutoPadType::NOTSET) { - // CoreML output is inconsistent if pads are asymmetric. - // CPU works. Other devices don't seem to (at least on macOS). 
+ // CoreML output is inconsistent between CPU_ONLY and ALL if the pads aren't all the same value. + // CPU matches the expected output, but other devices don't seem to (at least on macOS). auto onnx_pads = *helper.GetInt64s("pads"); // 'pads' are required if auto_pad is NOTSET const auto pad_value = onnx_pads[0]; if (!std::all_of(onnx_pads.begin() + 1, onnx_pads.end(), [pad_value](auto value) { return value == pad_value; })) { - LOGS(logger, VERBOSE) << "ConvTranspose: pads must be symmetric for CoreML to return consistent results"; + LOGS(logger, VERBOSE) << "ConvTranspose: all pad values must be the same for CoreML to return " + "consistent results"; return false; } } @@ -193,7 +194,6 @@ bool ConvTransposeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilder // ONNX. Running without that flag produces the expected output. Madness... auto output_shape = helper.GetInt64s("output_shape"); if (output_shape) { - // there is an output_shape input, but the padding seems to be different so results don't LOGS(logger, VERBOSE) << "ConvTranspose: output_shape is not supported as the CoreML output is inconsistent"; return false; } diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 5bce7cf996e4e..5ea0521d973fa 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -42,7 +42,7 @@ using GetOutputTensorMutableRawDataFn = std::function&)inputs int64_t block_size = 0; int64_t stride = 0; - ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, &num_blocks, &block_size, &stride)); + ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); __block Status copy_status; const auto* tensor_info = &output_tensor_info; @@ -454,7 +454,8 @@ - (Status)predict:(const std::unordered_map&)inputs namespace onnxruntime { namespace coreml { -Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, int64_t* block_size, int64_t* stride) { +Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, + int64_t& num_blocks, int64_t& block_size, int64_t& stride) { const auto* shape = array.shape; const auto rank = shape.count; @@ -466,13 +467,13 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, i int64_t this_stride = [array.strides[rank - i] longLongValue]; if (this_stride != total_elems) { // non-contiguous if we have to move more than batch_elems for each entry - if (*block_size != 0) { + if (block_size != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Multiple non-contiguous dimensions in MLMultiArray are not supported."); } - *block_size = data_elems; - *stride = this_stride; + block_size = data_elems; + stride = this_stride; } const auto elems_this_dim = [shape[rank - i] longLongValue]; @@ -480,17 +481,17 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t* num_blocks, i total_elems = elems_this_dim * this_stride; } - if (*block_size == 0) { + if (block_size == 0) { // contiguous - *block_size = data_elems; - *stride = array_total_elements; + block_size = data_elems; + stride = array_total_elements; } - *num_blocks = data_elems / *block_size; + num_blocks = data_elems / block_size; ORT_ENFORCE(array_total_elements == total_elems, "Logic error calculating copy info"); - ORT_ENFORCE(*stride >= *block_size, "Logic error calculating copy info"); - ORT_ENFORCE(*stride * *num_blocks == total_elems, "Logic error calculating copy info"); + ORT_ENFORCE(stride >= block_size, "Logic error 
calculating copy info"); + ORT_ENFORCE(stride * num_blocks == total_elems, "Logic error calculating copy info"); return Status::OK(); } diff --git a/onnxruntime/test/providers/coreml/utils_test.mm b/onnxruntime/test/providers/coreml/utils_test.mm index 2ea3d229c0d8b..8e0fc779467f7 100644 --- a/onnxruntime/test/providers/coreml/utils_test.mm +++ b/onnxruntime/test/providers/coreml/utils_test.mm @@ -19,7 +19,7 @@ auto ValidateGetInfo(MLMultiArray* array, int64_t num_blocks = 0; int64_t block_size = 0; int64_t stride = 0; - auto status = coreml::GetMLMultiArrayCopyInfo(array, &num_blocks, &block_size, &stride); + auto status = coreml::GetMLMultiArrayCopyInfo(array, num_blocks, block_size, stride); if (!expect_valid) { ASSERT_STATUS_NOT_OK(status); From 37e95ec6fbd83a8b314b1ae01fbd6868e6ecb933 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 20 Jul 2024 17:20:42 +1000 Subject: [PATCH 8/9] Exclude objective-c test on non-apple platform builds. --- cmake/onnxruntime_unittests.cmake | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 38ed0b1640192..0c1e5e93c6844 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -679,7 +679,10 @@ if(onnxruntime_USE_RKNPU) endif() if(onnxruntime_USE_COREML) - list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*.cc) + if(APPLE) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*.mm) + endif() list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) From efe57ee9c50486eca93455b3055ea64bac7427c1 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 23 Jul 2024 14:31:04 +1000 Subject: [PATCH 9/9] Address PR comments --- .../builders/impl/convtranspose_op_builder.cc | 4 ++-- onnxruntime/core/providers/coreml/model/model.h | 14 +++++++------- onnxruntime/core/providers/coreml/model/model.mm | 5 +++-- onnxruntime/test/providers/coreml/utils_test.mm | 4 +--- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc index eca996597f76d..5b6d9d72ab3c9 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/convtranspose_op_builder.cc @@ -30,8 +30,8 @@ Status ConvTransposeOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuild const logging::Logger& /*logger*/) const { #if defined(COREML_ENABLE_MLPROGRAM) using namespace CoreML::Specification::MILSpec; // NOLINT - const auto& input_defs = node.InputDefs(); - const auto& output_defs = node.OutputDefs(); + const auto input_defs = node.InputDefs(); + const auto output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); NodeAttrHelper helper(node); diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 5ea0521d973fa..75b9aaf2185c9 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -13,12 +13,8 @@ #include "core/common/status.h" 
#include "core/platform/ort_mutex.h" -#if defined(__APPLE__) -#ifdef __OBJC__ +#if defined(__OBJC__) @class MLMultiArray; -#else -typedef struct objc_object MLMultiArray; -#endif #endif namespace onnxruntime { @@ -40,8 +36,12 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; -#if defined(__APPLE__) -// helper function that we unit test +#if defined(__OBJC__) +// helper function that we unit test. +// Handles an MLMultiArray that is contiguous, or has one non-contiguous dimension. +// The output values can be used to copy the array data to a contiguous buffer. +// Loop num_blocks times, copying block_size elements each time, moving stride elements between copies. +// A contiguous array will have num_blocks == 1, block_size == total_size (i.e. can be copied in a single operation) Status GetMLMultiArrayCopyInfo(const MLMultiArray* array, int64_t& num_blocks, int64_t& block_size, int64_t& stride); #endif diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 1bae8d63baf77..4fd822f0d0d15 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -466,7 +466,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, for (unsigned long i = 1; i <= rank; i++) { int64_t this_stride = [array.strides[rank - i] longLongValue]; if (this_stride != total_elems) { - // non-contiguous if we have to move more than batch_elems for each entry + // non-contiguous if (block_size != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Multiple non-contiguous dimensions in MLMultiArray are not supported."); @@ -482,9 +482,10 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, } if (block_size == 0) { - // contiguous + // all data is contiguous block_size = data_elems; stride = array_total_elements; + assert(block_size == stride); } num_blocks = data_elems / block_size; diff --git a/onnxruntime/test/providers/coreml/utils_test.mm b/onnxruntime/test/providers/coreml/utils_test.mm index 8e0fc779467f7..f55f108494e3e 100644 --- a/onnxruntime/test/providers/coreml/utils_test.mm +++ b/onnxruntime/test/providers/coreml/utils_test.mm @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if defined(__APPLE__) #import #include "gtest/gtest.h" @@ -69,7 +68,7 @@ auto ValidateGetInfo(MLMultiArray* array, ValidateGetInfo(array, 8, 8, 16, true); } - // dim -3 with non-contiguous data. 16 elements in the innder dimensions but stride is 24. + // dim -3 with non-contiguous data. 16 elements in the inner dimensions but stride is 24. { NSArray* shape = @[ @1, @2, @4, @4 ]; NSArray* strides = @[ @48, @24, @4, @1 ]; @@ -107,4 +106,3 @@ auto ValidateGetInfo(MLMultiArray* array, } } // namespace test } // namespace onnxruntime -#endif