Skip to content

Commit

Permalink
Updating MLBuffer API
Browse files Browse the repository at this point in the history
  • Loading branch information
egalli committed Aug 22, 2024
1 parent ed2f366 commit f30218c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
4 changes: 3 additions & 1 deletion js/web/lib/wasm/jsep/webnn/buffer-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ class BufferTracker {
}
}
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createBuffer {dataType: ${dataType}, dimensions: ${dimensions}}`);
const buffer = await this.context.createBuffer({ dataType, dimensions });
// eslint-disable-next-line no-bitwise
const usage = MLBufferUsage.READ_FROM | MLBufferUsage.WRITE_TO;
const buffer = await this.context.createBuffer({ dataType, dimensions, usage });
this.bufferEntry = [buffer, dataType, dimensions];
this.bufferCache.push(this.bufferEntry);

Expand Down
18 changes: 17 additions & 1 deletion js/web/lib/wasm/jsep/webnn/webnn.d.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

/* eslint-disable @typescript-eslint/naming-convention */

interface NavigatorML {
readonly ml: ML;
}
Expand Down Expand Up @@ -386,11 +388,25 @@ interface MLBuffer {
}

type MLNamedBuffers = Record<string, MLBuffer>;

type MLBufferUsageFlags = number;

declare const MLBufferUsage: {
readonly WEBGPU_INTEROP: MLBufferUsageFlags;
readonly READ_FROM: MLBufferUsageFlags;
readonly WRITE_TO: MLBufferUsageFlags;
};

interface MLBufferDescriptor extends MLOperandDescriptor {
usage: MLBufferUsageFlags;
}

interface MLContext {
createBuffer(descriptor: MLOperandDescriptor): Promise<MLBuffer>;
createBuffer(descriptor: MLBufferDescriptor): Promise<MLBuffer>;
writeBuffer(
dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number,
srcElementSize?: number): void;
readBuffer(srcBuffer: MLBuffer): Promise<ArrayBuffer>;
readBuffer(srcBuffer: MLBuffer, dstBuffer: ArrayBuffer): Promise<undefined>;
dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void;
}
12 changes: 10 additions & 2 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,11 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty

const dataType = type === 'bool' ? 'uint8' : type;

const mlBuffer = await mlContext.createBuffer({ dataType, dimensions: dims as number[] });
const mlBuffer = await mlContext.createBuffer({
dataType,
dimensions: dims as number[],
usage: MLBufferUsage.READ_FROM,
});

return ort.Tensor.fromMLBuffer(mlBuffer, {
dataType: type,
Expand All @@ -679,7 +683,11 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`);
}
const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type;
const mlBuffer = await mlContext.createBuffer({ dataType, dimensions: cpuTensor.dims as number[] });
const mlBuffer = await mlContext.createBuffer({
dataType,
dimensions: cpuTensor.dims as number[],
usage: MLBufferUsage.WRITE_TO,
});
mlContext.writeBuffer(mlBuffer, cpuTensor.data);
return ort.Tensor.fromMLBuffer(mlBuffer, {
dataType: cpuTensor.type,
Expand Down

0 comments on commit f30218c

Please sign in to comment.