From 0bd59d6e3f625d342078c612bbdfc49e3c4c9483 Mon Sep 17 00:00:00 2001 From: rachguo Date: Fri, 17 Nov 2023 16:30:43 -0800 Subject: [PATCH] address pr comments --- .../core/providers/coreml/builders/impl/softmax_op_builder.cc | 4 ++-- .../core/providers/coreml/builders/impl/split_op_builder.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index d6ea249ee115e..dafc20ca0b95a 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -107,12 +107,12 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) + if (!GetStaticShape(*input_defs[0], input_shape, logger)) return false; const TensorShape shape(input_shape); if (shape.Size() == 0) { - LOGS(logger, VERBOSE) << "Cases that input data being empty due to a dimension with value of 0 is not supported"; + LOGS(logger, VERBOSE) << "Empty input data is not supported."; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index f6ae89a574c77..9448e94699e4e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -81,9 +81,9 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, num_outputs = node.OutputDefs().size(); coreml_splitnd->set_numsplits(num_outputs); } else { - num_outputs = SafeInt(helper.Get("num_outputs", -1)); + num_outputs = SafeInt(helper.Get("num_outputs", 2)); auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = narrow(std::ceil(float(split_dim_size) / num_outputs)); + uint64_t chunk_size = SafeInt((split_dim_size + num_outputs - 1) / num_outputs); uint64_t remainder = split_dim_size % chunk_size; if (remainder) { // uneven