Skip to content

Commit

Permalink
CoreML: Add GridSample ML Program support (#21431)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add GridSample ML Program support

One combination of inputs has diffs between the pytorch generated unit
tests data and CoreML. Disabling until needed as investigation may take
a while.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
High priorities models
  • Loading branch information
skottmckay authored Jul 24, 2024
1 parent 86cedc6 commit 1df9aa2
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/builders/impl/base_op_builder.h"
#include "core/providers/coreml/builders/impl/builder_utils.h"
#include "core/providers/coreml/builders/model_builder.h"
#include "core/providers/coreml/builders/op_builder_factory.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/providers/shared/utils/utils.h"

namespace onnxruntime {
namespace coreml {

namespace {
std::string_view GetMode(const NodeAttrHelper& helper) {
// opset 16 used bilinear, nearest, bicubic
// opset 20+ uses linear, nearest, cubic
// bilinear is what CoreML uses, so prefer that
// bicubic/cubic isn't supported

const auto& mode = helper.Get("mode", "linear");
if (mode == "linear") {
return "bilinear";
}

return mode;
}
} // namespace

class GridSampleOpBuilder : public BaseOpBuilder {
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override;

bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const override;

bool SupportsMLProgram() const override { return true; }
};

Status GridSampleOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model_builder,
[[maybe_unused]] const Node& node,
[[maybe_unused]] const logging::Logger& logger) const {
#if defined(COREML_ENABLE_MLPROGRAM)
using namespace CoreML::Specification::MILSpec; // NOLINT
// https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.image_resizing.resample

const auto input_defs = node.InputDefs();
const auto output_defs = node.OutputDefs();

NodeAttrHelper helper(node);
std::string mode{GetMode(helper)}; // need a std::string for use in AddScalarConstant
std::string padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);
const std::string coordinates_mode = "normalized_minus_one_to_one";

// adjust to coreml equivalents
if (padding_mode == "zeros") {
padding_mode = "constant";
}

auto op = model_builder.CreateOperation(node, "resample");
AddOperationInput(*op, "x", input_defs[0]->Name());
AddOperationInput(*op, "coordinates", input_defs[1]->Name());
AddOperationInput(*op, "sampling_mode", model_builder.AddScalarConstant(op->type(), "sampling_mode", mode));
AddOperationInput(*op, "padding_mode", model_builder.AddScalarConstant(op->type(), "padding_mode", padding_mode));
AddOperationInput(*op, "padding_value", model_builder.AddScalarConstant(op->type(), "padding_value", 0.0f));
AddOperationInput(*op, "coordinates_mode",
model_builder.AddScalarConstant(op->type(), "coordinates_mode", coordinates_mode));
AddOperationInput(*op, "align_corners", model_builder.AddScalarConstant(op->type(), "align_corners", align_corners));

AddOperationOutput(*op, *output_defs[0]);

model_builder.AddOperation(std::move(op));
#endif
return Status::OK();
}

bool GridSampleOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
if (!input_params.create_mlprogram) {
LOGS(logger, VERBOSE) << "GridSample is not supported.";
return false;
}

const auto& input_defs = node.InputDefs();

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, VERBOSE) << "GridSample: failed to get input shape";
return false;
}

const auto input_rank = input_shape.size();
if (input_rank != 4) {
LOGS(logger, VERBOSE) << "GridSample only supports 4D input. Got:" << input_rank << "D";
return false;
}

NodeAttrHelper helper(node);
std::string_view mode = GetMode(helper);

if (mode != "bilinear" && mode != "zeros") {
LOGS(logger, VERBOSE) << "GridSample does not support mode of " << mode;
return false;
}

// there is one combination of settings where the unit test fails.
// The ORT unit test values are generated by pytorch so not clear if it's an issue with CoreML.
// CoreML output is consistent for CPU and non-CPU at least.
// Disabling until there's a use-case that requires this combination.
const auto& padding_mode = helper.Get("padding_mode", "zeros");
const bool align_corners = helper.Get("align_corners", 0);

if (mode == "bilinear" && padding_mode == "reflection" && align_corners == false) {
LOGS(logger, VERBOSE) << "GridSample does not support mode:" << mode << " padding_mode:" << padding_mode
<< " align_corners:" << align_corners
<< " currently due to output diffs that need to be investigated";
return false;
}

return true;
}

void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<GridSampleOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
}

} // namespace coreml
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
CreateSplitOpBuilder("Split", op_registrations);
}

CreateGridSampleOpBuilder("GridSample", op_registrations);

return op_registrations;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void CreateDepthToSpaceOpBuilder(const std::string& op_type, OpBuilderRegistrati
void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateGridSampleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreateLRNOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
Loading

0 comments on commit 1df9aa2

Please sign in to comment.