Skip to content

Commit

Permalink
[WebNN EP] Support external data (#22263)
Browse files Browse the repository at this point in the history
### Description
This PR introduces support for registering external data inside WebNN
EP.

### Motivation and Context

- The WebNN EP needs to register the initializers at graph compilation
stage, for initializers from external data, it can't leverage the
general external data loader framework because the graph compilation of
WebNN EP is executed before external data loader called.
- Exposes the `utils::GetExternalDataInfo`, it is useful for WebNN EP to
read the external tensor's infomation.
- Define a new `registerMLConstant` in JSEP to create WebNN constants
from external data in WebNN backend, with the info of tensor as
parameters, as well as the `Module.MountedFiles`, which holds all
preloaded external files.
  • Loading branch information
Honry authored Oct 23, 2024
1 parent ffaddea commit 33e2f6a
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 104 deletions.
63 changes: 63 additions & 0 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,69 @@ export class WebNNBackend {
return id;
}

// Register WebNN Constant operands from external data.
public registerMLConstant(
externalFilePath: string,
dataOffset: number,
dataLength: number,
builder: MLGraphBuilder,
desc: MLOperandDescriptor,
mountedFiles: Map<string, Uint8Array> | undefined,
): MLOperand {
// If available, "Module.MountedFiles" is a Map for all preloaded files.
if (!mountedFiles) {
throw new Error('External mounted files are not available.');
}

let filePath = externalFilePath;
if (externalFilePath.startsWith('./')) {
filePath = externalFilePath.substring(2);
}
const fileData = mountedFiles.get(filePath);
if (!fileData) {
throw new Error(`File with name ${filePath} not found in preloaded files.`);
}

if (dataOffset + dataLength > fileData.byteLength) {
throw new Error('Out of bounds: data offset and length exceed the external file data size.');
}

const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer;
let bufferView: ArrayBufferView;
switch (desc.dataType) {
case 'float32':
bufferView = new Float32Array(buffer);
break;
case 'float16':
bufferView = new Uint16Array(buffer);
break;
case 'int32':
bufferView = new Int32Array(buffer);
break;
case 'uint32':
bufferView = new Uint32Array(buffer);
break;
case 'int64':
bufferView = new BigInt64Array(buffer);
break;
case 'uint64':
bufferView = new BigUint64Array(buffer);
break;
case 'int8':
bufferView = new Int8Array(buffer);
break;
case 'uint8':
bufferView = new Uint8Array(buffer);
break;
default:
throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`);
}

LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`);

return builder.constant(desc, bufferView);
}

public flush(): void {
// Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations.
}
Expand Down
62 changes: 31 additions & 31 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,37 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2)
DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2)

static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
const std::filesystem::path& tensor_proto_dir,
std::basic_string<ORTCHAR_T>& external_file_path,
onnxruntime::FileOffsetType& file_offset,
SafeInt<size_t>& tensor_byte_size) {
ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
"Tensor does not have external data to read from.");

ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
"External data type cannot be UNDEFINED or STRING.");

std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));

const auto& location = external_data_info->GetRelPath();

external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
: (tensor_proto_dir / location);

ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
const size_t external_data_length = external_data_info->GetLength();
ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
"TensorProto: ", tensor_proto.name(),
" external data size mismatch. Computed size: ", *&tensor_byte_size,
", external_data.length: ", external_data_length);

file_offset = external_data_info->GetOffset();

return Status::OK();
}

// Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully.
// Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr
// then uses the current directory instead.
Expand Down Expand Up @@ -261,6 +230,37 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo

namespace utils {

Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
const std::filesystem::path& tensor_proto_dir,
std::basic_string<ORTCHAR_T>& external_file_path,
onnxruntime::FileOffsetType& file_offset,
SafeInt<size_t>& tensor_byte_size) {
ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto),
"Tensor does not have external data to read from.");

ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto),
"External data type cannot be UNDEFINED or STRING.");

std::unique_ptr<onnxruntime::ExternalDataInfo> external_data_info;
ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info));

const auto& location = external_data_info->GetRelPath();

external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location)
: (tensor_proto_dir / location);

ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size));
const size_t external_data_length = external_data_info->GetLength();
ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size,
"TensorProto: ", tensor_proto.name(),
" external data size mismatch. Computed size: ", *&tensor_byte_size,
", external_data.length: ", external_data_length);

file_offset = external_data_info->GetOffset();

return Status::OK();
}

void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) {
tensor_proto.set_raw_data(std::move(param));
}
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@

namespace onnxruntime {
namespace utils {
/**
* This function is used to get the external data info from the given tensor proto.
* @param tensor_proto given initializer tensor
* @param tensor_proto_dir directory of the tensor proto file
* @param external_file_path output external file path
* @param file_offset output tensor offset
* @param tensor_byte_size output tensor byte size
* @returns Status::OK() if the function is executed successfully
*/
Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto,
const std::filesystem::path& tensor_proto_dir,
std::basic_string<ORTCHAR_T>& external_file_path,
onnxruntime::FileOffsetType& file_offset,
SafeInt<size_t>& tensor_byte_size);
/**
* This function is used to convert the endianess of Tensor data.
* Mostly, will be used in big endian system to support the model file
Expand Down
25 changes: 0 additions & 25 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,6 @@

namespace onnxruntime {
namespace webnn {

// Shared functions.
bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node,
const logging::Logger& logger) {
for (const auto* node_arg : node.InputDefs()) {
const auto& input_name(node_arg->Name());
if (!Contains(initializers, input_name))
continue;

const auto& tensor = *initializers.at(input_name);
if (tensor.has_data_location() &&
tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
LOGS(logger, VERBOSE) << "Initializer [" << input_name
<< "] with external data location are not currently supported";
return true;
}
}

return false;
}

// Add operator related.

Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
Expand All @@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
if (!HasSupportedOutputsImpl(node, wnn_limits, logger))
return false;

// We do not support external initializers for now.
if (HasExternalInitializer(initializers, node, logger))
return false;

if (!HasSupportedOpSet(node, logger))
return false;

Expand Down
113 changes: 65 additions & 48 deletions onnxruntime/core/providers/webnn/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() {
auto num_elements = SafeInt<size_t>(Product(shape));
emscripten::val view = emscripten::val::undefined();
std::byte* tensor_ptr = nullptr;
if (tensor.has_raw_data()) {
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));

if (utils::HasExternalData(tensor)) {
// Create WebNN Constant from external data.
std::basic_string<ORTCHAR_T> external_file_path;
onnxruntime::FileOffsetType data_offset;
SafeInt<size_t> tensor_byte_size;
ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(
tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size));

auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant");
operand = jsepRegisterMLConstant(emscripten::val(external_file_path),
static_cast<int32_t>(data_offset),
static_cast<int32_t>(tensor_byte_size),
wnn_builder_,
desc);
} else {
// Store temporary unpacked_tensor.
unpacked_tensors_.push_back({});
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<float*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int64_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint64_t*>(tensor_ptr))};
break;
default:
break;
if (tensor.has_raw_data()) {
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
} else {
// Store temporary unpacked_tensor.
unpacked_tensors_.push_back({});
std::vector<uint8_t>& unpacked_tensor = unpacked_tensors_.back();
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor));
tensor_ptr = reinterpret_cast<std::byte*>(unpacked_tensor.data());
}
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int8_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint16_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<float*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<int64_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint32_t*>(tensor_ptr))};
break;
case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
view = emscripten::val{emscripten::typed_memory_view(num_elements,
reinterpret_cast<uint64_t*>(tensor_ptr))};
break;
default:
break;
}

// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
// buffers in JS side. Simply create a copy to fix it.
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
}

// Wasm memory grow will cause all array buffers reallocation, which will be treated as detached
// buffers in JS side. Simply create a copy to fix it.
operand = wnn_builder_.call<emscripten::val>("constant", desc, view.call<emscripten::val>("slice"));
} else {
// TODO: support other type.
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/wasm/pre-jsep.js
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,10 @@ Module['jsepInit'] = (name, params) => {
Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
return backend['registerMLTensor'](tensor, dataType, shape);
}

Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => {
return backend['registerMLConstant'](
externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles);
}
}
};

0 comments on commit 33e2f6a

Please sign in to comment.