Commit
Updating MLBuffer specification
* CPU devices now support MLBuffer
* MLContext.createBuffer now returns a Promise
egalli committed Aug 1, 2024
1 parent 25e3a0f commit d4e53ea
Showing 12 changed files with 34 additions and 30 deletions.
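The change is easiest to see from the caller's side: after this commit, `MLContext.createBuffer` resolves asynchronously, so call sites must `await` it before writing data or binding the buffer to a tensor. A minimal sketch of the updated usage, based on the `webnn.d.ts` typings in this diff (the helper name and the descriptor values are illustrative only):

```typescript
// Minimal sketch of the updated MLBuffer flow; names follow webnn.d.ts in this diff.
async function createFilledBuffer(context: MLContext, data: Float32Array): Promise<MLBuffer> {
  // createBuffer now returns Promise<MLBuffer>, so it must be awaited.
  const buffer = await context.createBuffer({dataType: 'float32', dimensions: [1, data.length]});
  // writeBuffer remains synchronous from the caller's point of view.
  context.writeBuffer(buffer, data);
  return buffer;
}
```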
js/web/lib/wasm/jsep/backend-webnn.ts (3 changes: 2 additions & 1 deletion)

@@ -99,7 +99,8 @@ export class WebNNBackend {
this.bufferManager.releaseBufferId(bufferId);
}

- public ensureBuffer(bufferId: BufferId, onnxDataType: number|MLOperandDataType, dimensions: number[]): MLBuffer {
+ public async ensureBuffer(bufferId: BufferId, onnxDataType: number|MLOperandDataType, dimensions: number[]):
+     Promise<MLBuffer> {
let dataType: MLOperandDataType;
if (typeof onnxDataType === 'number') {
const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!;
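Because ensureBuffer on the backend is now async, every caller in the JSEP layer has to await its result. A small, hypothetical call site (the data type and shape are placeholders):

```typescript
// Hypothetical call site: the MLBuffer must be awaited before it can be used.
// WebNNBackend and MLBuffer come from backend-webnn.ts and webnn.d.ts respectively.
async function getBuffer(backend: WebNNBackend, bufferId: number): Promise<MLBuffer> {
  return backend.ensureBuffer(bufferId, 'float32', [1, 224, 224, 3]);
}
```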
js/web/lib/wasm/jsep/init.ts (2 changes: 1 addition & 1 deletion)

@@ -246,7 +246,7 @@ export const init =
// jsepReleaseBufferId,
(bufferId: number) => backend.releaseBufferId(bufferId),
// jsepEnsureBuffer
- (bufferId: number, onnxDataType: number, dimensions: number[]) =>
+ async (bufferId: number, onnxDataType: number, dimensions: number[]) =>
backend.ensureBuffer(bufferId, onnxDataType, dimensions),
// jsepUploadBuffer
(bufferId: number, data: Uint8Array) => {
js/web/lib/wasm/jsep/webnn/buffer-manager.ts (8 changes: 4 additions & 4 deletions)

@@ -25,7 +25,7 @@ export interface BufferManager {
/**
* Ensure a MLBuffer is created for the BufferId.
*/
- ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): MLBuffer;
+ ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer>;
/**
* Upload data to a MLBuffer.
*/
@@ -85,12 +85,12 @@ class BufferTracker {
this.mlBuffer = undefined;
}

- public ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): MLBuffer {
+ public async ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
if (this.mlBuffer) {
return this.mlBuffer;
}

- const buffer = this.context.createBuffer({dataType, dimensions});
+ const buffer = await this.context.createBuffer({dataType, dimensions});
this.mlBuffer = buffer;

if (this.activeUpload) {
@@ -151,7 +151,7 @@ class BufferManagerImpl implements BufferManager {
}
}

- public ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): MLBuffer {
+ public async ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
const buffer = this.buffersById.get(bufferId);
if (!buffer) {
throw new Error('Buffer not found.');
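The BufferTracker change above is the core of the lazy get-or-create pattern: the MLBuffer is created (and now awaited) the first time it is needed, then cached, and any data queued beforehand is flushed once the buffer exists. A condensed sketch of that idea, with the pending-upload handling simplified; any field or class name not shown in the diff is an assumption:

```typescript
// Condensed sketch of the lazily created, cached MLBuffer pattern used by BufferTracker.
// "LazyBuffer" and "pendingUpload" are illustrative names, not the repo's.
class LazyBuffer {
  private mlBuffer?: MLBuffer;
  private pendingUpload?: Uint8Array;

  constructor(private context: MLContext) {}

  async ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
    if (this.mlBuffer) {
      return this.mlBuffer;  // already created: reuse the cached buffer
    }
    // createBuffer is asynchronous after this commit, so creation must be awaited.
    this.mlBuffer = await this.context.createBuffer({dataType, dimensions});
    if (this.pendingUpload) {
      // Flush data that was queued before the buffer existed.
      this.context.writeBuffer(this.mlBuffer, this.pendingUpload);
      this.pendingUpload = undefined;
    }
    return this.mlBuffer;
  }
}
```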
js/web/lib/wasm/jsep/webnn/webnn.d.ts (2 changes: 1 addition & 1 deletion)

@@ -387,7 +387,7 @@ interface MLBuffer {

type MLNamedBuffers = Record<string, MLBuffer>;
interface MLContext {
- createBuffer(descriptor: MLOperandDescriptor): MLBuffer;
+ createBuffer(descriptor: MLOperandDescriptor): Promise<MLBuffer>;
writeBuffer(
dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number,
srcElementSize?: number): void;
js/web/lib/wasm/wasm-core-impl.ts (2 changes: 1 addition & 1 deletion)

@@ -704,7 +704,7 @@ export const run = async(

// If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
// ensureBuffer to get/create the MLBuffer.
- const mlBuffer = ensureBuffer(dataOffset, dataType, dims);
+ const mlBuffer = await ensureBuffer(dataOffset, dataType, dims);

// do not release the tensor right now. it will be released when user calls tensor.dispose().
keepOutputTensor = true;
js/web/lib/wasm/wasm-types.ts (5 changes: 3 additions & 2 deletions)

@@ -25,7 +25,8 @@ export declare namespace JSEP {
type ReplayFunction = () => void;
type ReserveBufferIdFunction = () => number;
type ReleaseBufferIdFunction = (bufferId: number) => void;
- type EnsureBufferFunction = (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => MLBuffer;
+ type EnsureBufferFunction = (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) =>
+     Promise<MLBuffer>;
type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void;
type DownloadBufferFunction = (bufferId: number) => Promise<ArrayBuffer>;

@@ -154,7 +155,7 @@ export declare namespace JSEP {
* @param bufferId - specify the MLBuffer ID.
* @returns the MLBuffer.
*/
- jsepEnsureBuffer: (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => MLBuffer;
+ jsepEnsureBuffer: (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => Promise<MLBuffer>;
/**
* [exported from pre-jsep.js] Upload data to MLBuffer.
* @param bufferId - specify the MLBuffer ID.
js/web/test/test-runner.ts (7 changes: 3 additions & 4 deletions)

@@ -257,8 +257,7 @@ export class ModelTestContext {
const executionProviderConfig =
modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || {name: 'webnn'}) : modelTest.backend!;
let mlContext: MLContext|undefined;
- if(['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
-
+ if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption;
const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType;
const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads;
@@ -593,7 +592,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty

const dataType = type === 'bool' ? 'uint8' : type;

- const mlBuffer = mlContext.createBuffer({dataType, dimensions: dims as number[]});
+ const mlBuffer = await mlContext.createBuffer({dataType, dimensions: dims as number[]});

return ort.Tensor.fromMLBuffer(mlBuffer, {
dataType: type,
@@ -611,7 +610,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`);
}
const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type;
- const mlBuffer = mlContext.createBuffer({dataType, dimensions: cpuTensor.dims as number[]});
+ const mlBuffer = await mlContext.createBuffer({dataType, dimensions: cpuTensor.dims as number[]});
mlContext.writeBuffer(mlBuffer, cpuTensor.data);
return ort.Tensor.fromMLBuffer(
mlBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => mlBuffer.destroy()});
onnxruntime/core/providers/webnn/builders/helper.cc (5 changes: 2 additions & 3 deletions)

@@ -211,10 +211,9 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
}
}

- bool IsMLBufferSupported(WebnnDeviceType device_type) {
+ bool IsMLBufferSupported() {
static bool is_supported = !emscripten::val::global("MLBuffer").isUndefined();
- // The current MLBuffer implementation only supports GPU and NPU devices.
- return is_supported && device_type != WebnnDeviceType::CPU;
+ return is_supported;
}

} // namespace webnn
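IsMLBufferSupported now reduces to a pure feature check: MLBuffer is usable whenever the MLBuffer global exists, with no device-type restriction. A hedged TypeScript equivalent of the same check (the function name is illustrative):

```typescript
// Illustrative equivalent of the C++ check above: MLBuffer support is purely
// feature-detected; the previous CPU-device exclusion is gone.
function isMLBufferSupported(): boolean {
  return typeof (globalThis as {MLBuffer?: unknown}).MLBuffer !== 'undefined';
}
```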
onnxruntime/core/providers/webnn/builders/helper.h (2 changes: 1 addition & 1 deletion)

@@ -285,7 +285,7 @@ bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type);

- bool IsMLBufferSupported(WebnnDeviceType device_type);
+ bool IsMLBufferSupported();

} // namespace webnn
} // namespace onnxruntime
onnxruntime/core/providers/webnn/builders/model.cc (20 changes: 12 additions & 8 deletions)

@@ -155,27 +155,31 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
onnxruntime::common::Status Model::Dispatch(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
const InlinedHashMap<std::string, OnnxTensorData>& outputs) {
auto jsepEnsureBuffer = emscripten::val::module_property("jsepEnsureBuffer");
- for (const auto& input : inputs) {
- const std::string& name = input.first;
- const struct OnnxTensorData tensor = input.second;
+ auto promises = emscripten::val::array();
+ for (const auto& [_, tensor] : inputs) {
emscripten::val shape = emscripten::val::array();
for (const auto& dim : tensor.tensor_info.shape) {
uint32_t dim_val = SafeInt<uint32_t>(dim);
shape.call<void>("push", dim_val);
}
auto buffer = jsepEnsureBuffer(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape);
- wnn_inputs_.set(name, buffer);
+ promises.call<void>("push", buffer);
}
- for (const auto& output : outputs) {
- const std::string& name = output.first;
- const struct OnnxTensorData tensor = output.second;
+ for (const auto& [_, tensor] : outputs) {
emscripten::val shape = emscripten::val::array();
for (const auto& dim : tensor.tensor_info.shape) {
uint32_t dim_val = SafeInt<uint32_t>(dim);
shape.call<void>("push", dim_val);
}
auto buffer = jsepEnsureBuffer(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape);
- wnn_outputs_.set(name, buffer);
+ promises.call<void>("push", buffer);
}
+ auto buffers = emscripten::val::global("Promise").call<emscripten::val>("all", promises).await();
+ for (const auto& [name, _] : inputs) {
+   wnn_inputs_.set(name, buffers.call<emscripten::val>("shift"));
+ }
+ for (const auto& [name, _] : outputs) {
+   wnn_outputs_.set(name, buffers.call<emscripten::val>("shift"));
+ }
wnn_context_.call<void>("dispatch", wnn_graph_, wnn_inputs_, wnn_outputs_);

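Dispatch now collects every jsepEnsureBuffer promise for the inputs and outputs, resolves them all in a single Promise.all, and only then binds the resulting MLBuffers by name. A hedged TypeScript sketch of the same batching pattern; ensureBuffer mirrors the jsepEnsureBuffer declaration above, while TensorInfo and the function name are illustrative:

```typescript
// Sketch of the batching pattern used by Model::Dispatch, expressed in TypeScript.
// TensorInfo and bindBuffers are illustrative names, not part of the repo.
interface TensorInfo { bufferId: number; dataType: number; dims: number[]; }

async function bindBuffers(
    ensureBuffer: (id: number, dataType: number, dims: number[]) => Promise<MLBuffer>,
    inputs: Map<string, TensorInfo>, outputs: Map<string, TensorInfo>):
    Promise<{inputs: Record<string, MLBuffer>; outputs: Record<string, MLBuffer>}> {
  // Start every buffer creation first so they resolve concurrently...
  const pending = [...inputs.values(), ...outputs.values()]
      .map(t => ensureBuffer(t.bufferId, t.dataType, t.dims));
  const buffers = await Promise.all(pending);
  // ...then assign them back to the named input/output maps in the same order.
  const boundInputs: Record<string, MLBuffer> = {};
  const boundOutputs: Record<string, MLBuffer> = {};
  let i = 0;
  for (const name of inputs.keys()) { boundInputs[name] = buffers[i++]; }
  for (const name of outputs.keys()) { boundOutputs[name] = buffers[i++]; }
  return {inputs: boundInputs, outputs: boundOutputs};
}
```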
onnxruntime/core/providers/webnn/builders/model_builder.cc (2 changes: 1 addition & 1 deletion)

@@ -340,7 +340,7 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
}
// Explicitly release the WebNN builder to free memory.
wnn_builder_ = emscripten::val::undefined();
- model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported(wnn_device_type_)));
+ model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported()));
model->SetInputs(std::move(input_names_));
model->SetOutputs(std::move(output_names_));
model->SetScalarOutputs(std::move(scalar_outputs_));
onnxruntime/core/providers/webnn/webnn_execution_provider.cc (6 changes: 3 additions & 3 deletions)

@@ -24,7 +24,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
onnxruntime::kWebNNExecutionProvider,
// If MLBuffer is supported, we force all the tensors to be allocated as MLBuffer.
OrtDevice(
- webnn::IsMLBufferSupported(webnn::DeviceTypeFromString(webnn_device_flags)) ? OrtDevice::GPU : OrtDevice::CPU,
+ webnn::IsMLBufferSupported() ? OrtDevice::GPU : OrtDevice::CPU,
OrtDevice::MemType::DEFAULT,
0)},
wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) {
@@ -381,14 +381,14 @@ WebNNExecutionProvider::GetKernelRegistry() const {
}

std::unique_ptr<onnxruntime::IDataTransfer> WebNNExecutionProvider::GetDataTransfer() const {
- if (!webnn::IsMLBufferSupported(wnn_device_type_)) {
+ if (!webnn::IsMLBufferSupported()) {
return nullptr;
}
return std::make_unique<webnn::DataTransfer>();
}

std::vector<AllocatorPtr> WebNNExecutionProvider::CreatePreferredAllocators() {

[GitHub Actions / Optional Lint C++, cpplint warning on line 390: Add #include <vector> for vector<> [build/include_what_you_use] [4]]
- if (!webnn::IsMLBufferSupported(wnn_device_type_)) {
+ if (!webnn::IsMLBufferSupported()) {
return {};
}
AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) {
