diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 13888fa855ef6..916dec4545af3 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -157,7 +157,7 @@ class TensorIdTracker { // eslint-disable-next-line no-bitwise const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; - this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage); + this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true); if (copyOld && this.activeUpload) { this.wrapper.write(this.activeUpload); @@ -306,6 +306,8 @@ class TensorManagerImpl implements TensorManager { dataType: MLOperandDataType, shape: readonly number[], usage: MLTensorUsageFlags, + writable: boolean, + readable: boolean, ): Promise { const sessionId = this.backend.currentSessionId; for (const [index, tensor] of this.freeTensors.entries()) { @@ -322,6 +324,8 @@ class TensorManagerImpl implements TensorManager { shape, dimensions: shape, usage, + writable, + readable, }); return new TensorWrapper({ sessionId, context, tensor, dataType, shape }); } diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 3505772cd2b73..a2d4e9af23e44 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -392,6 +392,7 @@ type MLNamedTensor = Record; type MLTensorUsageFlags = number; +// TODO(@Honry): Remove this once it is deprecated in Chromium. declare const MLTensorUsage: { readonly WEBGPU_INTEROP: MLTensorUsageFlags; readonly READ: MLTensorUsageFlags; @@ -400,6 +401,9 @@ declare const MLTensorUsage: { interface MLTensorDescriptor extends MLOperandDescriptor { usage: MLTensorUsageFlags; + importableToWebGPU?: boolean; + readable?: boolean; + writable?: boolean; } interface MLContext { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index aa62c8dc22c40..c37c10c781400 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -662,6 +662,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty // Assign both shape and dimensions while transitioning to new API. dimensions: dims as number[], usage: MLTensorUsage.READ, + readable: true, }); return ort.Tensor.fromMLTensor(mlTensor, { @@ -686,6 +687,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso // Assign both shape and dimensions while transitioning to new API. dimensions: cpuTensor.dims as number[], usage: MLTensorUsage.WRITE, + writable: true, }); mlContext.writeTensor(mlTensor, cpuTensor.data); return ort.Tensor.fromMLTensor(mlTensor, {