Skip to content

Commit

Permalink
[WebNN] Check split's output name (#22884)
Browse files Browse the repository at this point in the history
Chromium will rename split's output name from "output" to "outputs" in
`OpSupportLimits` to align with spec, the EP should check which name is
available to make it compatible.
  • Loading branch information
Honry authored Nov 19, 2024
1 parent 8a06f13 commit 5b78712
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class SplitOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -163,6 +165,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}

bool SplitOpBuilder::HasSupportedOutputsImpl(const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& output_defs = node.OutputDefs();
const auto& op_type = node.OpType();
int32_t output_type = 0;

if (GetType(*output_defs[0], output_type, logger)) {
// Chromium has changed the output name of split from 'output' to 'outputs',
// to avoid breaking the existing API, we need to check both names.
std::string wnn_output_name = wnn_limits["split"]["output"].isUndefined() ? "outputs" : "output";
return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, wnn_output_name, "outputs", logger);
}

return false;
}

void CreateSplitOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.builders.push_back(std::make_unique<SplitOpBuilder>());
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
Expand Down

0 comments on commit 5b78712

Please sign in to comment.