Skip to content

Commit

Permalink
[WebNN] Improve data type check of slice op (microsoft#22988)
Browse files Browse the repository at this point in the history
A follow-up of [[WebNN] Support negative steps for
slice](microsoft#22871 (comment)).
Slice op is emulated by reverse+slice when steps < 0 so
`SliceOpBuilder::HasSupportedInputsImpl()` should also check the
supported data types of reverse.

---------

Co-authored-by: Wanming Lin <[email protected]>
  • Loading branch information
2 people authored and ankitm3k committed Dec 11, 2024
1 parent 7ba6ed6 commit c06a5c2
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder {

private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;

const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Helper functions, thanks for DML EP's OperatorHelper.
Expand Down Expand Up @@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}

bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

const auto& op_type = node.OpType();
Expand Down Expand Up @@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten:
return false;
} else if (recognized_operator_type == RecognizedOperatorType::Pairwise) {
// Map to WebNN's gemm or matmul
return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger);
} else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) {
return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger);
} else {
return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde

// Operator support related.

bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial
return true;
}

bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class ScatterElementsOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -65,7 +65,8 @@ bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /*
return true;
}

bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class ScatterNDOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -57,7 +57,8 @@ bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
return true;
}

bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
Expand Down

0 comments on commit c06a5c2

Please sign in to comment.