Skip to content

Commit

Permalink
[WebNN EP] Re-implement Unsqueeze, Squeeze, Flatten with WebNN's reshape (#18585)
Browse files Browse the repository at this point in the history

WebNN will not provide `unsqueeze`, `squeeze`, `flatten2d` ops, as they
can be easily implemented with reshape.
  • Loading branch information
Honry authored Nov 29, 2023
1 parent 14a3434 commit 38b640c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 20 deletions.
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Erf", {"erf", false}},
{"Exp", {"exp", false}},
{"Expand", {"expand", false}},
{"Flatten", {"flattenTo2d", false}},
{"Flatten", {"reshape", true}},
{"Floor", {"floor", true}},
{"Gather", {"gather", false}},
{"Gemm", {"gemm", true}},
Expand Down Expand Up @@ -206,12 +206,12 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"Softmax", {"softmax", true}},
{"Split", {"split", true}},
{"Sqrt", {"sqrt", false}},
{"Squeeze", {"squeeze", false}},
{"Squeeze", {"reshape", true}},
{"Sub", {"sub", true}},
{"Tan", {"tan", false}},
{"Tanh", {"tanh", true}},
{"Transpose", {"transpose", true}},
{"Unsqueeze", {"unsqueeze", false}},
{"Unsqueeze", {"reshape", true}},
{"Where", {"elementwiseIf", false}},
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,20 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
int64_t rank = input_shape.size();
NodeAttrHelper helper(node);
int64_t axis = helper.Get("axis", 1);
ORT_ENFORCE(axis >= -rank && axis <= rank, "axis ", axis,
" is not in valid range [-", rank, ",", rank, "]");
if (axis < 0) {
axis += rank;
}
axis = HandleNegativeAxis(axis, rank);

// Use WebNN's reshape to implement Flatten.
int64_t num_pre_axis_elements = std::accumulate(
input_shape.begin(), input_shape.begin() + static_cast<int32_t>(axis), 1, std::multiplies<int64_t>());
int64_t num_post_axis_elements = std::accumulate(
input_shape.begin() + static_cast<int32_t>(axis), input_shape.end(), 1, std::multiplies<int64_t>());

std::vector<uint32_t> new_shape = {SafeInt<uint32_t>(num_pre_axis_elements),
SafeInt<uint32_t>(num_post_axis_elements)};

emscripten::val inputs = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("flattenTo2d", inputs,
static_cast<int32_t>(axis));
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>(
"reshape", inputs, emscripten::val::array(new_shape));

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,42 +56,65 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil

emscripten::val options = emscripten::val::object();
std::vector<int32_t> axes_data;
auto rank = input_rank;

if (node.SinceVersion() >= 13 && input_defs.size() > 1) {
// Input axes is provided, use axes initializer data.
const auto& initializers = model_builder.GetInitializerTensors();
const auto& axes_tensor = *initializers.at(input_defs[1]->Name());
Initializer axes_initializer(axes_tensor);
const auto axes_data_span = axes_initializer.DataAsSpan<int64_t>();
const auto output_rank = input_rank + axes_data_span.size();
if (op_type == "Unsqueeze") {
// Unsqueeze should check the expanded rank.
rank = input_rank + axes_data_span.size();
}
std::transform(
axes_data_span.begin(), axes_data_span.end(), std::back_inserter(axes_data),
[output_rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, output_rank)); });
[rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, rank)); });
} else {
NodeAttrHelper helper(node);
if (helper.HasAttr("axes")) {
auto axes = helper.Get("axes", std::vector<int64_t>{});
const auto output_rank = input_rank + axes.size();
if (op_type == "Unsqueeze") {
// Unsqueeze should check the expanded rank.
rank = input_rank + axes.size();
}
std::transform(
axes.begin(), axes.end(), std::back_inserter(axes_data),
[output_rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, output_rank)); });
[rank](int64_t axis) -> int32_t { return SafeInt<int32_t>(HandleNegativeAxis(axis, rank)); });
}
}

if (axes_data.size() > 0) {
options.set("axes", emscripten::val::array(axes_data));
}

emscripten::val output = emscripten::val::undefined();
// Use WebNN's reshape to implement Squeeze/Unsqueeze.
std::vector<uint32_t> new_shape;
std::transform(
input_shape.begin(), input_shape.end(), std::back_inserter(new_shape),
[](int64_t data) -> uint32_t { return SafeInt<uint32_t>(data); });
// Sort axes_data in ascending order.
std::sort(axes_data.begin(), axes_data.end());
if (op_type == "Squeeze") {
output = model_builder.GetBuilder().call<emscripten::val>("squeeze", input, options);
if (!axes_data.empty()) {
for (auto axis = axes_data.rbegin(); axis != axes_data.rend(); ++axis) {
size_t index = *axis;
new_shape.erase(new_shape.begin() + index);
}
} else {
// Remove all the single dimensions.
new_shape.erase(
std::remove_if(new_shape.begin(), new_shape.end(), [](uint32_t axis) { return axis == 1; }), new_shape.end());
}
} else if (op_type == "Unsqueeze") {
output = model_builder.GetBuilder().call<emscripten::val>("unsqueeze", input, options);
// Expand new_shape according to axes_data.
for (const int32_t& axis : axes_data) {
new_shape.insert(new_shape.begin() + axis, 1);
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type);
}

output = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
}
Expand Down

0 comments on commit 38b640c

Please sign in to comment.