[WebNN EP] Remove workaround for scalar #21704

Merged · 1 commit · Aug 16, 2024
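
WebNN originally had no 0-D tensor representation, so the WebNN EP promoted scalar inputs to shape {1} on the way in and recorded scalar output names so their shapes could be restored to {} on the way out. This PR removes that bookkeeping (scalar_outputs_, IsScalarOutput, AddScalarOutput, SetScalarOutputs), presumably because WebNN now represents scalars natively. Condensed from the deleted lines below, the removed round-trip was:

    // On input: promote a scalar's empty shape to {1} (webnn_execution_provider.cc).
    if (shape.empty())
      shape.push_back(1);

    // On output: demote recorded scalar outputs back to {} (same file).
    if (model->IsScalarOutput(output_name))
      output_shape.clear();
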
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/webnn/builders/model.cc
@@ -142,10 +142,6 @@ Status Model::Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
   return Status::OK();
 }
 
-bool Model::IsScalarOutput(const std::string& output_name) const {
-  return Contains(scalar_outputs_, output_name);
-}
-
 const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const {
   return input_output_info_.at(name);
 }
8 changes: 0 additions & 8 deletions onnxruntime/core/providers/webnn/builders/model.h
@@ -34,8 +34,6 @@ class Model {
   onnxruntime::common::Status Predict(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                                       const InlinedHashMap<std::string, OnnxTensorData>& outputs);
 
-  bool IsScalarOutput(const std::string& output_name) const;
-
   // Mutex for exclusive lock to this model object.
   OrtMutex& GetMutex() { return mutex_; }
@@ -65,8 +63,6 @@
   emscripten::val wnn_inputs_ = emscripten::val::object();
   emscripten::val wnn_outputs_ = emscripten::val::object();
 
-  InlinedHashSet<std::string> scalar_outputs_;
-
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
@@ -83,10 +79,6 @@
     input_output_info_ = std::move(input_output_info);
   }
 
-  void SetScalarOutputs(InlinedHashSet<std::string>&& scalar_outputs) {
-    scalar_outputs_ = std::move(scalar_outputs);
-  }
-
   void AllocateInputOutputBuffers();
 };

22 changes: 5 additions & 17 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -104,13 +104,15 @@ Status ModelBuilder::RegisterInitializers()
     emscripten::val operand = emscripten::val::object();
     if (IsSupportedDataType(data_type, webnn_supported_data_types)) {
       ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type");
-      auto num_elements = SafeInt<size_t>(Product(tensor.dims()));
+      auto num_elements = SafeInt<size_t>(Product(shape));
       emscripten::val view = emscripten::val::undefined();
       std::byte* tensor_ptr = nullptr;
       if (tensor.has_raw_data()) {
         tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
       } else {
-        std::vector<uint8_t> unpacked_tensor;
+        // Store temporary unpacked_tensor.
+        unpacked_tensors_.push_back({});
+        std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
         ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
         tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
       }
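
The second change in this hunk is unrelated to scalars: the unpacked initializer bytes used to live in a function-local vector, but the emscripten view created over tensor_ptr aliases that buffer rather than copying it, so the storage must outlive this function. Pushing each buffer into the unpacked_tensors_ member ties its lifetime to the ModelBuilder. A stripped-down sketch of the hazard, with hypothetical names rather than the EP's actual types:

    #include <cstdint>
    #include <vector>

    // Stand-in for an emscripten memory view: it aliases the buffer it is
    // handed instead of copying it.
    struct AliasingView { const std::uint8_t* data; };

    std::vector<std::vector<std::uint8_t>> unpacked_tensors_;  // member-style owner

    AliasingView ViewDangling() {
      std::vector<std::uint8_t> local = {1, 2, 3};
      return AliasingView{local.data()};  // BUG: 'local' is destroyed on return
    }

    AliasingView ViewSafe() {
      unpacked_tensors_.push_back({1, 2, 3});  // owner outlives this function
      // Growing the outer vector may move the inner vectors, but a moved
      // std::vector keeps its heap allocation, so the pointer stays valid.
      return AliasingView{unpacked_tensors_.back().data()};
    }
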
@@ -187,16 +189,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_input)
   ORT_RETURN_IF(shape_proto == nullptr,
                 "shape_proto cannot be null for ", input_output_type, ": ", name);
   const auto& shape = shape_proto->dim();
-  if (shape.empty()) {
-    // If we have an empty shape, this is a scalar input.
-    dims.push_back(1);
-
-    // We need to change the shapes of these scalar outputs back to {}
-    // when WebNN EP returns these values to ORT.
-    if (!is_input) {
-      AddScalarOutput(name);
-    }
-  } else {
+  if (!shape.empty()) {
     dims.reserve(shape.size());
     for (const auto& dim : shape) {
       // dim_param free dimensions should have already been excluded by IsInputSupported().
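
With the workaround gone, a scalar's empty dims flow straight through to WebNN, which describes it as a rank-0 operand. A minimal sketch of what that looks like through embind, with wnn_builder standing in for a bound MLGraphBuilder and using the "dimensions" descriptor key of this PR's era (later spec drafts rename it to "shape"):

    // Hypothetical: declare a float32 scalar model input as a 0-D WebNN operand.
    emscripten::val desc = emscripten::val::object();
    desc.set("dataType", emscripten::val("float32"));
    desc.set("dimensions", emscripten::val::array());  // [] => rank 0, one element
    emscripten::val scalar =
        wnn_builder.call<emscripten::val>("input", std::string("alpha"), desc);
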
@@ -343,7 +336,6 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_));
   model->SetInputs(std::move(input_names_));
   model->SetOutputs(std::move(output_names_));
-  model->SetScalarOutputs(std::move(scalar_outputs_));
   model->SetInputOutputInfo(std::move(input_output_info_));
   // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews
   // for inputs and outputs because they will be transferred after compute() done.
@@ -352,10 +344,6 @@
   return Status::OK();
 }
 
-void ModelBuilder::AddScalarOutput(const std::string& output_name) {
-  scalar_outputs_.insert(output_name);
-}
-
 void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& operand) {
   wnn_operands_.insert(std::make_pair(name, operand));
 }
5 changes: 1 addition & 4 deletions onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -69,8 +69,8 @@
   InlinedHashMap<std::string, emscripten::val> wnn_operands_;
   std::vector<std::string> input_names_;
   std::vector<std::string> output_names_;
+  std::vector<std::vector<uint8_t>> unpacked_tensors_;
 
-  InlinedHashSet<std::string> scalar_outputs_;
   InlinedHashMap<std::string, OnnxTensorInfo> input_output_info_;
 
   InlinedHashSet<std::string> skipped_initializers_;

GitHub Actions / Optional Lint C++ annotation on model_builder.h:72, [cpplint] reported by reviewdog: Add #include <vector> for vector<> [build/include_what_you_use] [4]
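
The warning is easy to act on: the header now names std::vector directly for unpacked_tensors_ (and already did for input_names_ and output_names_), so it should include the header itself rather than rely on transitive includes; presumably just:

    #include <vector>  // for the std::vector members such as unpacked_tensors_
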
@@ -92,9 +92,6 @@
   Status RegisterModelOutputs() ORT_MUST_USE_RESULT;
   Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input) ORT_MUST_USE_RESULT;
 
-  // Record the onnx scalar output names.
-  void AddScalarOutput(const std::string& output_name);
-
   static const IOpBuilder* GetOpBuilder(const Node& node);
 };

10 changes: 0 additions & 10 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -272,10 +272,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph
       auto input_tensor = ctx.GetInput(input_idx);
       auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
       auto shape = tensor_info.GetShape();
-      // If we have an empty shape, this is a scalar input,
-      // Since all the input output of WebNN EP is MultiArray, we will make the scalar input as a {1} MultiArray.
-      if (shape.empty())
-        shape.push_back(1);
       const void* inputBuffer = const_cast<void*>(input_tensor.GetTensorRawData());
       inputs.emplace(
           input_name,
@@ -297,12 +293,6 @@
       const auto& output_info = model->GetInputOutputInfo(output_name);
       auto output_shape = output_info.shape;
       auto output_type = output_info.data_type;
-
-      // Since WebNN EP use {1} tensor as scalar, if the model output should have empty shape.
-      // We are going to replace the {1} shape of the output back to {}.
-      if (model->IsScalarOutput(output_name))
-        output_shape.clear();
-
       auto output_tensor =
           ctx.GetOutput(i, output_shape.data(), output_shape.size());
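
No translation is needed on the ORT side either: passing ctx.GetOutput an empty dims array already yields a rank-0 tensor, which is why the IsScalarOutput()/clear() dance above could go. A minimal sketch, assuming the same KernelContext API used in this file:

    // For a scalar, GetInputOutputInfo() now reports an empty shape:
    std::vector<int64_t> output_shape;  // {} => rank 0
    auto output_tensor = ctx.GetOutput(i, output_shape.data(), output_shape.size());
    // A dims count of 0 allocates a 0-D tensor holding a single element, so no
    // {1} -> {} fix-up is required when returning results to ORT.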
