Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rachguo authored and rachguo committed Nov 18, 2023
1 parent 1f1119f commit 0bd59d6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,12 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger))
if (!GetStaticShape(*input_defs[0], input_shape, logger))
return false;

const TensorShape shape(input_shape);
if (shape.Size() == 0) {
LOGS(logger, VERBOSE) << "Cases that input data being empty due to a dimension with value of 0 is not supported";
LOGS(logger, VERBOSE) << "Empty input data is not supported.";
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
num_outputs = node.OutputDefs().size();
coreml_splitnd->set_numsplits(num_outputs);
} else {
num_outputs = SafeInt<uint64_t>(helper.Get("num_outputs", -1));
num_outputs = SafeInt<uint64_t>(helper.Get("num_outputs", 2));
auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())];
uint64_t chunk_size = narrow<uint64_t>(std::ceil(float(split_dim_size) / num_outputs));
uint64_t chunk_size = SafeInt<uint64_t>((split_dim_size + num_outputs - 1) / num_outputs);
uint64_t remainder = split_dim_size % chunk_size;
if (remainder) {
// uneven
Expand Down

0 comments on commit 0bd59d6

Please sign in to comment.