diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
index 3963b80de58a4..d035fd34bd072 100644
--- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -47,8 +47,20 @@ enum COREMLFlags {
 // and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead.
 static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
 static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat";
+// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES
 static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes";
 static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs";
+// See https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property
+// Core ML segments the model's compute graph and specializes each segment for the target compute device.
+// This process can affect the model loading time and the prediction latency.
+// Use this option to tailor the specialization strategy for your model.
+static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy";
+// Profile the Core ML MLComputePlan.
+// This logs the hardware each operator is dispatched to and the estimated execution time.
+// Intended for developer usage, but it provides useful diagnostic information if performance is not as expected.
+static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
+// Please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
+static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
index cc68fa6ec399a..442194cb31cbc 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
@@ -151,7 +151,7 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu
     return false;
   }
 
-#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
+#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
   // To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary
   auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
   if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
index f161b309a2425..d533b867bd454 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
@@ -133,9 +133,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu
     return false;
   }
 
-#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
-  // to pass https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1563483&view=logs&j=f7cc61a9-cc70-56e7-b06c-4668ca17e426
-  // ReductionOpTest.ReduceSum_half_bert
+#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
  +  // skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros
   int32_t input_type;
   GetType(*input_defs[0], input_type, logger);
   if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
index c8df7c1a43f65..a1b3a18265c70 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
@@ -13,6 +13,10 @@
 #include "core/optimizer/initializer.h"
 #include "core/providers/cpu/tensor/unsqueeze.h"
 
+#ifdef __APPLE__
+#include <TargetConditionals.h>
+#endif
+
 namespace onnxruntime {
 namespace coreml {
 
@@ -54,32 +58,50 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
   }
 }
 
+#if defined(COREML_ENABLE_MLPROGRAM)
+void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder,
+                                       const Node& node, const logging::Logger& logger) {
+  const auto& input_defs(node.InputDefs());
+  TensorShapeVector axes;
+  GetAxes(model_builder, node, axes);
+
+  std::vector<int64_t> input_shape;
+  GetShape(*input_defs[0], input_shape, logger);
+  auto op = model_builder.CreateOperation(node, "reshape");
+  AddOperationInput(*op, "x", input_defs[0]->Name());
+  TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
+  AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
+  AddOperationOutput(*op, *node.OutputDefs()[0]);
+  model_builder.AddOperation(std::move(op));
+}
+#endif
+
 Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const Node& node,
                                                [[maybe_unused]] const logging::Logger& logger) const {
   std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
 
-  const auto& input_defs(node.InputDefs());
   auto* coreml_squeeze = layer->mutable_squeeze();
   TensorShapeVector axes;
   GetAxes(model_builder, node, axes);
-  std::vector<int64_t> input_shape;
-  GetShape(*input_defs[0], input_shape, logger);
 #if defined(COREML_ENABLE_MLPROGRAM)
+  const auto& input_defs(node.InputDefs());
   if (model_builder.CreateMLProgram()) {
     using namespace CoreML::Specification::MILSpec;
 
-    std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "reshape";
+#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64
+    // expand_dims has limited requirements for static shape; however, X86_64 has a bug where it can't handle scalar input
+    if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) {
+      HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger);
+      return Status::OK();
+    }
+#endif
+    std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims";
     std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
     AddOperationInput(*op, "x", input_defs[0]->Name());
 
-    if (coreml_op_type == "squeeze") {
-      if (!axes.empty()) {
-        // coreml squeeze op does support negative axes
-        AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
-      }
-    } else {
-      TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
-      AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
+    if (!axes.empty()) {
+      // coreml supports negative axes
+      AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
     }
     AddOperationOutput(*op, *node.OutputDefs()[0]);
     model_builder.AddOperation(std::move(op));
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index 2a02c1f4124f6..6486942199df7 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -408,7 +408,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
     : graph_viewer_(graph_viewer),
       logger_(logger),
       coreml_version_(coreml_version),
-      coreml_compute_unit_(coreml_options.ComputeUnits()),
+      coreml_options_(coreml_options),
       create_ml_program_(coreml_options.CreateMLProgram()),
       model_output_path_(GetModelOutputPath(create_ml_program_)),
       onnx_input_names_(std::move(onnx_input_names)),
@@ -989,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
                                     get_sanitized_io_info(std::move(input_output_info_)),
                                     std::move(scalar_outputs_),
                                     std::move(int64_outputs_),
-                                    logger_, coreml_compute_unit_);
+                                    logger_, coreml_options_);
   } else
 #endif
   {
@@ -999,7 +999,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
                                     std::move(input_output_info_),
                                     std::move(scalar_outputs_),
                                     std::move(int64_outputs_),
-                                    logger_, coreml_compute_unit_);
+                                    logger_, coreml_options_);
   }
 
   return model->LoadModel();  // load using CoreML API, including compilation
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h
index af47869f7e1c3..e19597cf0dc2e 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.h
@@ -7,6 +7,7 @@
 #include "core/graph/graph_viewer.h"
 #include "core/providers/coreml/builders/coreml_spec.h"
 #include "core/providers/coreml/model/model.h"
+#include "core/providers/coreml/coreml_options.h"
 
 #if defined(COREML_ENABLE_MLPROGRAM)
 // coremltools classes
@@ -22,8 +23,6 @@ class StorageWriter;
 #endif
 
 namespace onnxruntime {
-class CoreMLOptions;
-
 namespace coreml {
 
 class IOpBuilder;
@@ -218,7 +217,7 @@ class ModelBuilder {
   const GraphViewer& graph_viewer_;
   const logging::Logger& logger_;
   const int32_t coreml_version_;
-  const uint32_t coreml_compute_unit_;
+  CoreMLOptions coreml_options_;
   const bool create_ml_program_;         // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
   const std::string model_output_path_;  // create_ml_program_ ? dir for mlpackage : filename for mlmodel
diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc
index df78f74383871..4ec780208e528 100644
--- a/onnxruntime/core/providers/coreml/coreml_options.cc
+++ b/onnxruntime/core/providers/coreml/coreml_options.cc
@@ -63,11 +63,14 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
       {"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
       {"NeuralNetwork", COREML_FLAG_USE_NONE},
   };
-  std::unordered_set<std::string_view> valid_options = {
+  const std::unordered_set<std::string_view> valid_options = {
      kCoremlProviderOption_MLComputeUnits,
      kCoremlProviderOption_ModelFormat,
      kCoremlProviderOption_RequireStaticInputShapes,
      kCoremlProviderOption_EnableOnSubgraphs,
+      kCoremlProviderOption_SpecializationStrategy,
+      kCoremlProviderOption_ProfileComputePlan,
+      kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
   };
   // Validate the options
   for (const auto& option : options) {
@@ -90,6 +93,16 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
       require_static_shape_ = option.second == "1";
     } else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) {
       enable_on_subgraph_ = option.second == "1";
+    } else if (kCoremlProviderOption_SpecializationStrategy == option.first) {
+      if (option.second != "Default" && option.second != "FastPrediction") {
+        ORT_THROW("Invalid value for option ", option.first, ": ", option.second,
+                  ". Valid values are Default and FastPrediction.");
+      }
+      strategy_ = option.second;
+    } else if (kCoremlProviderOption_ProfileComputePlan == option.first) {
+      profile_compute_plan_ = option.second == "1";
+    } else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
+      allow_low_precision_accumulation_on_gpu_ = option.second == "1";
     }
   }
 }
diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h
index 8bb748fcd69c9..fd05c96927bd1 100644
--- a/onnxruntime/core/providers/coreml/coreml_options.h
+++ b/onnxruntime/core/providers/coreml/coreml_options.h
@@ -14,6 +14,9 @@ class CoreMLOptions {
   bool create_mlprogram_{false};
   bool enable_on_subgraph_{false};
   uint32_t compute_units_{0};
+  std::string strategy_;
+  bool profile_compute_plan_{false};
+  bool allow_low_precision_accumulation_on_gpu_{false};
 
  public:
  explicit CoreMLOptions(uint32_t coreml_flags);
@@ -25,6 +28,9 @@ class CoreMLOptions {
   bool CreateMLProgram() const { return create_mlprogram_; }
   bool EnableOnSubgraph() const { return enable_on_subgraph_; }
   uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; }
+  bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; }
+  bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
+  bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }
 
  private:
  void ValidateAndParseProviderOption(const ProviderOptions& options);
diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h
index 68ecbe5fb80c4..84b7d741b4714 100644
--- a/onnxruntime/core/providers/coreml/model/model.h
+++ b/onnxruntime/core/providers/coreml/model/model.h
@@ -18,6 +18,7 @@
 #endif
 
 namespace onnxruntime {
+class CoreMLOptions;
 namespace coreml {
 
 class Execution;
@@ -53,7 +54,7 @@ class Model {
         std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
         std::unordered_set<std::string>&& scalar_outputs,
         std::unordered_set<std::string>&& int64_outputs,
-        const logging::Logger& logger, uint32_t coreml_compute_unit);
+        const logging::Logger& logger, const CoreMLOptions& coreml_options);
   ~Model();
   ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model);
diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm
index c8edb64ff55d7..755dbfbd6e68c 100644
--- a/onnxruntime/core/providers/coreml/model/model.mm
+++ b/onnxruntime/core/providers/coreml/model/model.mm
@@ -25,6 +25,7 @@
 #include "core/providers/coreml/model/host_utils.h"
 #include "core/providers/coreml/model/objc_str_utils.h"
 #include "core/providers/coreml/shape_utils.h"
+#include "core/providers/coreml/coreml_options.h"
 
 // force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need
 // to manually do this
@@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
   return Status::OK();
 }
 
+// MLComputePlan is only declared in SDK headers shipped with __clang_major__ >= 15.
+// We already ensure that the macOS/iOS version and the Xcode version are at least `macOS 14.4, iOS 17.4`.
+// The API_AVAILABLE macro would also work here.
+// Otherwise, the compiler will complain that `MLComputePlan` is not defined.
+// __clang_analyzer__ is checked to bypass static analysis.
+void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) {
+#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
+  if (@available(macOS 14.4, iOS 17.4, *)) {
+    [MLComputePlan loadContentsOfURL:compileUrl
+                       configuration:config
+                   completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) {
+                     if (!computePlan) {
+                       NSLog(@"Error loading compute plan: %@", error);
+                       // Handle error.
+                       return;
+                     }
+                     MLModelStructureProgram* program = computePlan.modelStructure.program;
+                     if (!program) {
+                       NSLog(@"Error loading program from compute plan; this is not an mlprogram model");
+                       return;
+                     }
+
+                     MLModelStructureProgramFunction* mainFunction = program.functions[@"main"];
+                     if (!mainFunction) {
+                       NSLog(@"Error loading main function from program");
+                       return;
+                     }
+
+                     NSArray<MLModelStructureProgramOperation*>* operations = mainFunction.block.operations;
+                     NSLog(@"Number of operations (including 'const' nodes): %lu", operations.count);
+                     for (MLModelStructureProgramOperation* operation in operations) {
+                       // Get the compute device usage for the operation.
+                       MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation];
+                       id<MLComputeDeviceProtocol> preferredDevice = computeDeviceUsage.preferredComputeDevice;
+                       // Get the estimated cost of executing the operation.
+                       MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation];
+                       if (![operation.operatorName isEqualToString:@"const"]) {
+                         NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight);
+                       }
+                     }
+                   }];
+  } else {
+    NSLog(@"iOS 17.4+ / macOS 14.4+ is required to use the compute plan API");
+  }
+#endif
+}
+
 // Internal Execution class
 // This class is part of the model class and handles the calls into CoreML. Specifically, it performs
 // 1. Compile the model by given path for execution
@@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
 // 3. The compiled model will be removed in dealloc or removed using cleanup function
 class Execution {
  public:
-  Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags);
+  Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options);
   ~Execution();
 
   Status LoadModel();
@@ -320,13 +368,13 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
   NSString* coreml_model_path_{nil};
   NSString* compiled_model_path_{nil};
   const logging::Logger& logger_;
-  uint32_t coreml_compute_unit_{0};
+  CoreMLOptions coreml_options_;
   MLModel* model_{nil};
 };
 
-Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_compute_unit)
+Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options)
     : logger_(logger),
-      coreml_compute_unit_(coreml_compute_unit) {
+      coreml_options_(coreml_options) {
   @autoreleasepool {
     coreml_model_path_ = util::Utf8StringToNSString(path.c_str());
   }
@@ -395,17 +443,41 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
     compiled_model_path_ = [compileUrl path];
 
     MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
-
-    if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_ONLY) {
+    uint32_t coreml_compute_unit = coreml_options_.ComputeUnits();
+    if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) {
      config.computeUnits = MLComputeUnitsCPUOnly;
-    } else if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_AND_GPU) {
+    } else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) {
      config.computeUnits = MLComputeUnitsCPUAndGPU;
-    } else if (coreml_compute_unit_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
+    } else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
      config.computeUnits = MLComputeUnitsCPUAndNeuralEngine;  // Apple Neural Engine
    } else {
      config.computeUnits = MLComputeUnitsAll;
    }
 
+    if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) {
+      config.allowLowPrecisionAccumulationOnGPU = YES;
+    }
+
+// Set the specialization strategy if one was requested.
+// MLOptimizationHints requires CoreML 8 or later (HAS_COREML8_OR_LATER),
+// and it is only declared in SDK headers shipped with __clang_major__ >= 15.
+// See the comments above ProfileComputePlan for why __clang_major__ and __clang_analyzer__ are checked.
+#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
+    if (HAS_COREML8_OR_LATER) {
+      MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init];
+      if (coreml_options_.UseStrategy("FastPrediction")) {
+        optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction;
+        config.optimizationHints = optimizationHints;
+      } else if (coreml_options_.UseStrategy("Default")) {
+        optimizationHints.specializationStrategy = MLSpecializationStrategyDefault;
+        config.optimizationHints = optimizationHints;
+      }
+    }
+#endif
+    if (coreml_options_.ProfileComputePlan()) {
+      ProfileComputePlan(compileUrl, config);
+    }
+
    model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error];
 
    if (error != nil || model_ == nil) {
@@ -524,8 +596,8 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
              std::unordered_set<std::string>&& scalar_outputs,
              std::unordered_set<std::string>&& int64_outputs,
              const logging::Logger& logger,
-             uint32_t coreml_flags)
-    : execution_(std::make_unique<Execution>(path, logger, coreml_flags)),
+             const CoreMLOptions& coreml_options)
+    : execution_(std::make_unique<Execution>(path, logger, coreml_options)),
      model_input_names_(std::move(model_input_names)),
      model_output_names_(std::move(model_output_names)),
      input_output_info_(std::move(input_output_info)),
diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc
index c6f2e7401ea1e..e9036e2fc7e1a 100644
--- a/onnxruntime/core/providers/coreml/model/model_stub.cc
+++ b/onnxruntime/core/providers/coreml/model/model_stub.cc
@@ -4,6 +4,7 @@
 #include "core/providers/coreml/model/model.h"
 
 namespace onnxruntime {
+class CoreMLOptions;
 namespace coreml {
 
 class Execution {};
@@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/,
              std::unordered_set<std::string>&& scalar_outputs,
              std::unordered_set<std::string>&& int64_outputs,
              const logging::Logger& /*logger*/,
-             uint32_t /*coreml_flags*/)
+             const CoreMLOptions& /*coreml_flags*/)
     : execution_(std::make_unique<Execution>()),
      model_input_names_(std::move(model_input_names)),
      model_output_names_(std::move(model_output_names)),
diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc
index 3f2c2cb7f761c..23c3812ebd025 100644
--- a/onnxruntime/test/perftest/command_args_parser.cc
+++ b/onnxruntime/test/perftest/command_args_parser.cc
@@ -135,6 +135,9 @@ namespace perftest {
       "\t    [CoreML only] [MLComputeUnits]:[CPUAndNeuralEngine CPUAndGPU ALL CPUOnly] Specify to limit the backend device used to run the model.\n"
       "\t    [CoreML only] [AllowStaticInputShapes]:[0 1].\n"
       "\t    [CoreML only] [EnableOnSubgraphs]:[0 1].\n"
+      "\t    [CoreML only] [SpecializationStrategy]:[Default FastPrediction].\n"
+      "\t    [CoreML only] [ProfileComputePlan]:[0 1].\n"
+      "\t    [CoreML only] [AllowLowPrecisionAccumulationOnGPU]:[0 1].\n"
       "\t    [Example] [For CoreML EP] -e coreml -i \"ModelFormat|MLProgram MLComputeUnits|CPUAndGPU\"\n"
       "\n"
       "\t    [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n"
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 5db1894a5074b..a96028ed3903e 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -346,7 +346,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
       static const std::unordered_set<std::string> available_keys = {kCoremlProviderOption_MLComputeUnits,
                                                                      kCoremlProviderOption_ModelFormat,
                                                                      kCoremlProviderOption_RequireStaticInputShapes,
-                                                                     kCoremlProviderOption_EnableOnSubgraphs};
+                                                                     kCoremlProviderOption_EnableOnSubgraphs,
+                                                                     kCoremlProviderOption_SpecializationStrategy,
+                                                                     kCoremlProviderOption_ProfileComputePlan,
+                                                                     kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU};
       ParseSessionConfigs(ov_string, provider_options, available_keys);
 
       std::unordered_map<std::string, std::string> available_options = {
@@ -364,6 +367,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
                    (provider_option.second == "1" || provider_option.second == "0")) {
         } else if (provider_option.first == kCoremlProviderOption_EnableOnSubgraphs &&
                    (provider_option.second == "0" || provider_option.second == "1")) {
+        } else if (provider_option.first == kCoremlProviderOption_SpecializationStrategy &&
+                   (provider_option.second == "Default" || provider_option.second == "FastPrediction")) {
+        } else if (provider_option.first == kCoremlProviderOption_ProfileComputePlan &&
+                   (provider_option.second == "0" || provider_option.second == "1")) {
+        } else if (provider_option.first == kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU &&
+                   (provider_option.second == "0" || provider_option.second == "1")) {
         } else {
           ORT_THROW("Invalid value for option ", provider_option.first, ": ", provider_option.second);
         }