Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN EP] Support external data #22263

Merged
merged 3 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,69 @@ export class WebNNBackend {
return id;
}

// Register WebNN Constant operands from external data.
// Registers a WebNN Constant operand whose weights live in an external data
// file that was preloaded into the Wasm module's mounted-files map.
//
// @param externalFilePath Path of the external data file as recorded in the
//     model; a leading './' is stripped before the mounted-files lookup.
// @param dataOffset Byte offset of this tensor's data within the file.
// @param dataLength Byte length of this tensor's data.
// @param builder MLGraphBuilder used to create the constant.
// @param desc Operand descriptor (dataType + shape) for the constant.
// @param mountedFiles Map of all preloaded files keyed by path, or undefined
//     when no files were mounted.
// @returns The MLOperand created by builder.constant().
// @throws Error when no files are mounted, the file is not found, the
//     requested [offset, offset+length) range exceeds the file size, or the
//     data type has no typed-array mapping.
public registerMLConstant(
  externalFilePath: string,
  dataOffset: number,
  dataLength: number,
  builder: MLGraphBuilder,
  desc: MLOperandDescriptor,
  mountedFiles: Map<string, Uint8Array> | undefined,
): MLOperand {
  // If available, "Module.MountedFiles" is a Map for all preloaded files.
  if (!mountedFiles) {
    throw new Error('External mounted files are not available.');
  }

  let filePath = externalFilePath;
  if (externalFilePath.startsWith('./')) {
    filePath = externalFilePath.substring(2);
  }
  const fileData = mountedFiles.get(filePath);
  if (!fileData) {
    throw new Error(`File with name ${filePath} not found in preloaded files.`);
  }

  if (dataOffset + dataLength > fileData.byteLength) {
    throw new Error('Out of bounds: data offset and length exceed the external file data size.');
  }

  // slice() copies the range into a fresh buffer so the typed-array view below
  // starts at byte 0 and is correctly aligned for multi-byte element types.
  const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
  let bufferView: ArrayBufferView;
  switch (desc.dataType) {
    case 'float32':
      bufferView = new Float32Array(buffer);
      break;
    case 'float16':
      // No Float16Array in JS; reinterpret the raw 16-bit payloads as uint16.
      bufferView = new Uint16Array(buffer);
      break;
    case 'int32':
      bufferView = new Int32Array(buffer);
      break;
    case 'uint32':
      bufferView = new Uint32Array(buffer);
      break;
    case 'int64':
      bufferView = new BigInt64Array(buffer);
      break;
    case 'uint64':
      bufferView = new BigUint64Array(buffer);
      break;
    case 'int8':
      bufferView = new Int8Array(buffer);
      break;
    case 'uint8':
      bufferView = new Uint8Array(buffer);
      break;
    default:
      throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
  }

  // Fixed: the template previously ended with '}}}' which printed a stray '}'.
  LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}`);

  return builder.constant(desc, bufferView);
}

public flush(): void {
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
}
Expand Down
62 changes: 31 additions & 31 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,37 +165,6 @@
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)

// Resolves an initializer's external-data reference. On success the output
// parameters receive: the resolved file path (or the special in-memory address
// tag verbatim), the byte offset of the data within that file, and the tensor
// byte size computed from the proto's type and shape.
static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
const std::filesystem::path& tensor_proto_dir,
std::basic_string<ORTCHAR_T>& external_file_path,
onnxruntime::FileOffsetType& file_offset,
SafeInt<size_t>& tensor_byte_size) {
// Fail fast if the proto has no external_data entry at all.
ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
"Tensor does not have external data to read from.");

ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
"External data type cannot be UNDEFINED or STRING.");

std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));

const auto& location = external_data_info->GetRelPath();

// The memory-address tag is passed through unchanged; any other location is a
// path relative to the directory containing the tensor proto.
external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
: (tensor_proto_dir / location);

ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
// A recorded length of 0 means "unspecified"; otherwise it must match the
// size computed from the tensor's type and shape.
const size_t external_data_length = external_data_info->GetLength();
ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
"TensorProto: ", tensor_proto.name(),
" external data size mismatch. Computed size: ", *&tensor_byte_size,
", external_data.length: ", external_data_length);

file_offset = external_data_info->GetOffset();

return Status::OK();
}

// Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully.
// Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr
// then uses the current directory instead.
Expand Down Expand Up @@ -261,6 +230,37 @@

namespace utils {

// Resolves where an initializer's external data lives. Outputs the resolved
// file path (the in-memory address tag is kept verbatim), the byte offset of
// the data within that file, and the tensor byte size computed from the
// proto's type and shape. Returns a failed Status when the proto has no
// external data, has an UNDEFINED/STRING type, or the recorded length
// disagrees with the computed size.
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
                           const std::filesystem::path& tensor_proto_dir,
                           std::basic_string<ORTCHAR_T>& external_file_path,
                           onnxruntime::FileOffsetType& file_offset,
                           SafeInt<size_t>& tensor_byte_size) {
  ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
                    "Tensor does not have external data to read from.");

  ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
                "External data type cannot be UNDEFINED or STRING.");

  std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
  ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));

  // The special memory-address tag is passed through unchanged; any other
  // location is a path relative to the directory containing the tensor proto.
  const auto& rel_path = external_data_info->GetRelPath();
  if (rel_path == onnxruntime::utils::kTensorProtoMemoryAddressTag) {
    external_file_path = std::filesystem::path(rel_path);
  } else {
    external_file_path = tensor_proto_dir / rel_path;
  }

  ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
  // A recorded length of 0 means "unspecified"; otherwise it must match the
  // byte size computed from the tensor's type and shape.
  const size_t recorded_length = external_data_info->GetLength();
  ORT_RETURN_IF_NOT(recorded_length == 0 || recorded_length == tensor_byte_size,
                    "TensorProto: ", tensor_proto.name(),
                    " external data size mismatch. Computed size: ", tensor_byte_size,
                    ", external_data.length: ", recorded_length);

  file_offset = external_data_info->GetOffset();

  return Status::OK();
}

void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
tensor_proto.set_raw_data(std::move(param));
}
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@

namespace onnxruntime {
namespace utils {
/**
 * This function is used to get the external data info from the given tensor proto.
 * @param tensor_proto given initializer tensor
 * @param tensor_proto_dir directory of the tensor proto file
 * @param external_file_path output external file path
 * @param file_offset output tensor offset
 * @param tensor_byte_size output tensor byte size
 * @returns Status::OK() if the function is executed successfully
 */
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
                           const std::filesystem::path& tensor_proto_dir,
                           std::basic_string<ORTCHAR_T>& external_file_path,
                           onnxruntime::FileOffsetType& file_offset,
                           SafeInt<size_t>& tensor_byte_size);

Check warning on line 39 in onnxruntime/core/framework/tensorprotoutils.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not indent within a namespace. [whitespace/indent_namespace] [4] Raw Output: onnxruntime/core/framework/tensorprotoutils.h:39: Do not indent within a namespace. [whitespace/indent_namespace] [4]
/**
* This function is used to convert the endianess of Tensor data.
* Mostly, will be used in big endian system to support the model file
Expand Down
25 changes: 0 additions & 25 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,6 @@

namespace onnxruntime {
namespace webnn {

// Shared functions.
// Returns true if any input of |node| is an initializer whose data location is
// TensorProto_DataLocation_EXTERNAL (i.e. stored outside the model file),
// logging the first such initializer at VERBOSE level.
bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
                            const logging::Logger& logger) {
  for (const auto* input_def : node.InputDefs()) {
    const auto& name = input_def->Name();
    if (!Contains(initializers, name)) {
      continue;
    }

    const auto& initializer = *initializers.at(name);
    const bool is_external = initializer.has_data_location() &&
                             initializer.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL;
    if (is_external) {
      LOGS(logger, VERBOSE) << "Initializer [" << name
                            << "] with external data location are not currently supported";
      return true;
    }
  }

  return false;
}

// Add operator related.

Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
Expand All @@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
return false;

// We do not support external initializers for now.
if (HasExternalInitializer(initializers, node, logger))
return false;

if (!HasSupportedOpSet(node, logger))
return false;

Expand Down
113 changes: 65 additions & 48 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() {
auto num_elements = SafeInt<size_t>(Product(shape));
emscripten::val view = emscripten::val::undefined();
std::byte* tensor_ptr = nullptr;
if (tensor.has_raw_data()) {
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));

if (utils::HasExternalData(tensor)) {
// Create WebNN Constant from external data.
std::basic_string<ORTCHAR_T> external_file_path;
onnxruntime::FileOffsetType data_offset;
SafeInt<size_t> tensor_byte_size;
ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(
tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size));

auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant");
operand = jsepRegisterMLConstant(emscripten::val(external_file_path),
static_cast<int32_t>(data_offset),
static_cast<int32_t>(tensor_byte_size),
wnn_builder_,
desc);
} else {
// Store temporary unpacked_tensor.
unpacked_tensors_.push_back({});
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<float*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int64_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint64_t*>(tensor_ptr))};
break;
default:
break;
if (tensor.has_raw_data()) {
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
} else {
// Store temporary unpacked_tensor.
unpacked_tensors_.push_back({});
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<float*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int64_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint64_t*>(tensor_ptr))};
break;
default:
break;
}

// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
Honry marked this conversation as resolved.
Show resolved Hide resolved
// buffers in JS side. Simply create a copy to fix it.
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
}

// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
// buffers in JS side. Simply create a copy to fix it.
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
} else {
// TODO: support other type.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/wasm/pre-jsep.js
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,10 @@ Module['jsepInit'] = (name, params) => {
// Forwards MLTensor registration straight to the WebNN backend.
Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) =>
  backend['registerMLTensor'](tensor, dataType, shape);

// Forwards WebNN Constant registration (for external-data initializers) to the
// backend, supplying the module's preloaded-files map.
// Use bracket notation — matching the other jsep* registrations in this file —
// so the property names survive minifier/closure-compiler renaming; the Wasm
// side looks them up by string via emscripten's module_property().
Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => {
  return backend['registerMLConstant'](
    externalFilePath, dataOffset, dataLength, builder, desc, Module['MountedFiles']);
}
}
};
Loading