From ca710687c32d55a4d90d7de81a8f022288e2b59c Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 14 Sep 2023 09:07:13 +0800 Subject: [PATCH 01/10] [js/webgpu] Support where Supported type: float. int32_t, uint32_t, bool. Case where_broadcast.jsonc is not enabled due to https://github.com/microsoft/onnxruntime/issues/17405. --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/common.ts | 37 ++++ js/web/lib/wasm/jsep/webgpu/ops/where.ts | 111 +++++++++++ js/web/test/data/ops/where.jsonc | 172 ++++++++++++++++++ js/web/test/data/ops/where_broadcast.jsonc | 84 +++++++++ js/web/test/suite-test-list.jsonc | 5 +- .../providers/js/js_execution_provider.cc | 6 + .../core/providers/js/operators/where.cc | 41 +++++ 9 files changed, 458 insertions(+), 1 deletion(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/where.ts create mode 100644 js/web/test/data/ops/where.jsonc create mode 100644 js/web/test/data/ops/where_broadcast.jsonc create mode 100644 onnxruntime/core/providers/js/operators/where.cc diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..8da03c3cb967a 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -93,3 +93,4 @@ Do not modify directly.* | Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Where | ai.onnx(9-15,16+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9a78..658e8442c3614 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -24,6 +24,7 @@ import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; +import {where} from './ops/where'; import * as unaryOps from './ops/unary-op'; import {ComputeContext} from './types'; @@ -110,4 +111,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], + ['Where', [where]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0ab777bfbdee9..2df59f8ae467d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -524,6 +524,43 @@ export const outputVariable = (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shape, false, components); +/** + * A helper class for generating WGSL code for manipulating broadcast indices for a shader's input. + */ +export interface BroadcastHelper { + /** + * WGSL code for getting offset from broadcast indices. + * + */ + broadcastIndicesToOffset(): string; +} + +class BroadcastHelperImpl implements BroadcastHelper { + constructor(private inputs: IndicesHelper[], private output: IndicesHelper) {} + + broadcastIndicesToOffset(): string { + let implementation = ''; + for (let j = 0; j < this.inputs.length; j++) { + const dims = this.inputs[j].shape; + const name = this.inputs[j].name.substring(0, 1).toUpperCase(); + const strides = ShapeUtil.computeStrides(dims); + const offsets: string[] = []; + for (let i = dims.length - 1; i >= 0; i--) { + const idx = this.output.indicesGet('outputIndices', i + this.output.shape.length - dims.length); + offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); + } + implementation += `fn broadcastIndicesToOffset${name}(outputIndices: ${this.output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + } + `; + } + return implementation; + } +} + +export const createBroadcastHelper = (inputs: IndicesHelper[], output: IndicesHelper): BroadcastHelper => + new BroadcastHelperImpl(inputs, output); + /** * A ShaderHelper is a helper class for generating WGSL code. */ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts new file mode 100644 index 0000000000000..4d5a3db3d3498 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +const createWhereOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, + typeOutput: number) => { + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + const output = outputVariable('outputData', typeOutput, dimsOutput, 4); + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + const broadcastImpl = isBroadcast ? createBroadcastHelper([a, b, c], output).broadcastIndicesToOffset() : ''; + + let assignment: string; + const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; + if (!isBroadcast) { + assignment = output.setByOffset( + 'global_idx', + expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); + } else { + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + return ` + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offsetA${x} = broadcastIndicesToOffsetA(outputIndices${x}); + let offsetB${x} = broadcastIndicesToOffsetB(outputIndices${x}); + let offsetC${x} = broadcastIndicesToOffsetC(outputIndices${x}); + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let indexC${x} = offsetC${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); + `; + }; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } + } + + return ` + ${shaderHelper.declareVariables(c, a, b, output)} + ${broadcastImpl} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const dimsA = inputs[1].dims; + const dimsB = inputs[2].dims; + const dimsC = inputs[0].dims; + const outputDataType = inputs[1].dataType; + + const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); + let outputShape = dimsA; + let outputSize = ShapeUtil.size(dimsA); + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); + if (!calculatedShape) { + throw new Error('Can\'t perform where op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + } + + return { + ...metadata, + getShaderSource: (shaderHelper) => + createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), + outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (isBroadcast ? 1 : 4) /* vec size */)}) + }; +}; + +const createWhereOpProgramInfoLoader = (inputs: readonly TensorView[], name: string): ProgramInfoLoader => { + const inputTypes = [GpuDataType.default, GpuDataType.default, GpuDataType.default]; + const metadata: ProgramMetadata = {name, inputTypes}; + return {...metadata, get: () => createWhereOpProgramInfo(metadata, inputs)}; +}; + +export const where = (context: ComputeContext): void => { + context.compute(createWhereOpProgramInfoLoader(context.inputs, 'Where')); +}; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc new file mode 100644 index 0000000000000..047fd6fd7511b --- /dev/null +++ b/js/web/test/data/ops/where.jsonc @@ -0,0 +1,172 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] float32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0], + "dims": [8], + "type": "float32" + }, + { + "data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] int32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "int32" + }, + { + "data": [1, 3, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 3, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] uint32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "uint32" + }, + { + "data": [1, 4294967295, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [4, 4294967295, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] bool T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [true, true, true, true, true, true, true, true], + "dims": [8], + "type": "float32" + }, + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3 3] T[3 3] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, true, true, true, true, false, false, false, false], + "dims": [3, 3], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where_broadcast.jsonc b/js/web/test/data/ops/where_broadcast.jsonc new file mode 100644 index 0000000000000..ad97177bb101b --- /dev/null +++ b/js/web/test/data/ops/where_broadcast.jsonc @@ -0,0 +1,84 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 6] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [ + true, + true, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + true, + true, + true, + true, + true, + true + ], + "dims": [3, 6], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + }, + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 1] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, false, true], + "dims": [3, 1], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4756..bec48b3ac2ad7 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1386,7 +1386,10 @@ "tan.jsonc", "tile.jsonc", "transpose.jsonc", - "transpose_int32_uint32.jsonc" + "transpose_int32_uint32.jsonc", + "where.jsonc" + // Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed. + //"where_broadcast.jsonc", //"xor.jsonc" ] }, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 72e36a161e9aa..fa462323192ef 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -229,6 +229,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); @@ -585,6 +588,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc new file mode 100644 index 0000000000000..2f8f5e275aa98 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/where.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +JSEP_KERNEL_IMPL(Where, Where) +REG_ELEMENTWISE_VERSIONED_KERNEL(Where, 9, 15, Where); +REG_ELEMENTWISE_KERNEL(Where, 16, Where); +} // namespace js +} // namespace onnxruntime From 2a7863097b74671305efc7e85b9a68a8ea2138a9 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 21 Sep 2023 16:45:44 +0800 Subject: [PATCH 02/10] Fix dispatch number --- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 4d5a3db3d3498..2813e5769a005 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -96,7 +96,7 @@ const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly Te getShaderSource: (shaderHelper) => createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (isBroadcast ? 1 : 4) /* vec size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}) }; }; From 9b7efd2c9ac6f4117f6254dbf569b7b8d4d6ef19 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Mon, 25 Sep 2023 08:50:33 +0800 Subject: [PATCH 03/10] Fix npm run format --- js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 658e8442c3614..3a3090213016e 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -24,8 +24,8 @@ import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; -import {where} from './ops/where'; import * as unaryOps from './ops/unary-op'; +import {where} from './ops/where'; import {ComputeContext} from './types'; export type RunFunction = (context: ComputeContext, attribute?: unknown) => void; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 2813e5769a005..6917c09caddc3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -5,6 +5,7 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common'; const createWhereOpProgramShader = From 85f8a9d579d40e3e0185dccc5f3d46024fd8098d Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 26 Sep 2023 13:34:13 +0800 Subject: [PATCH 04/10] Remove BroadcastHelper --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 81 +++++++++++------------ js/web/lib/wasm/jsep/webgpu/ops/where.ts | 16 ++--- 2 files changed, 46 insertions(+), 51 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 2df59f8ae467d..59210ef5e16cc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -102,6 +102,15 @@ export interface IndicesHelper { */ readonly indicesToOffset: (varIndices: string) => string; + /** + * WGSL code of an `u32` expression for getting original offset from broadcast indices. + * + * @param varIndices - a `type.indices` expression representing the output indices. + * + * @returns an `u32` expression + */ + readonly broadcastIndicesToOffset: (varIndices: string) => string; + /** * WGSL code of generating an indices literal * @@ -248,8 +257,8 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shape: readonly number[], isInput: boolean, - components: 1|2|3|4): IndicesHelper => { + (name: string, tensorType: number, shape: readonly number[], isInput: boolean, components: 1|2|3|4, + output?: IndicesHelper): IndicesHelper => { const rank = shape.length; const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; const mappedType = getWgslMappedType(tensorType, components); @@ -262,6 +271,7 @@ const createIndicesHelper = const implementationUsed = { offsetToIndices: false, indicesToOffset: false, + broadcastIndicesToOffset: false, set: false, setByIndices: false, get: false, @@ -293,7 +303,7 @@ const createIndicesHelper = return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; }; - const offsets: string[] = []; + let offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { offsets.push(`${strides[i]}u * (indices[${i}])`); @@ -310,6 +320,25 @@ const createIndicesHelper = return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; }; + let broadcastIndicesToOffsetImplementation = ''; + // Currently output is only used when there is broadcasting. + if (output) { + offsets = []; + for (let i = shape.length - 1; i >= 0; i--) { + const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); + offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); + } + broadcastIndicesToOffsetImplementation = + `fn broadcastIndicesToOffset${name}(outputIndices: ${output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + }`; + } + + const broadcastIndicesToOffset = (varIndices: string) => { + implementationUsed.broadcastIndicesToOffset = true; + return `broadcastIndicesToOffset${name}(${varIndices});`; + }; + const indices = (...init: ReadonlyArray) => rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; @@ -462,6 +491,9 @@ const createIndicesHelper = if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); } + if (implementationUsed.broadcastIndicesToOffset) { + impls.push(broadcastIndicesToOffsetImplementation); + } if (implementationUsed.set) { impls.push(setImplementation); } @@ -482,6 +514,7 @@ const createIndicesHelper = type, offsetToIndices, indicesToOffset, + broadcastIndicesToOffset, indices, indicesGet, indicesSet, @@ -505,11 +538,12 @@ const createIndicesHelper = * @param type - the tensor type of the input. * @param shape - the tensor shape of the input. * @param components - the number of components of the input. available values are 1, 2, 3, 4. default is 1. + * @param output - output IndicesHelper, used when there is broadcasting. * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => - createIndicesHelper(name, type, shape, true, components); + (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1, output?: IndicesHelper): + IndicesHelper => createIndicesHelper(name, type, shape, true, components, output); /** * Create a IndicesHelper for an output. @@ -524,43 +558,6 @@ export const outputVariable = (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shape, false, components); -/** - * A helper class for generating WGSL code for manipulating broadcast indices for a shader's input. - */ -export interface BroadcastHelper { - /** - * WGSL code for getting offset from broadcast indices. - * - */ - broadcastIndicesToOffset(): string; -} - -class BroadcastHelperImpl implements BroadcastHelper { - constructor(private inputs: IndicesHelper[], private output: IndicesHelper) {} - - broadcastIndicesToOffset(): string { - let implementation = ''; - for (let j = 0; j < this.inputs.length; j++) { - const dims = this.inputs[j].shape; - const name = this.inputs[j].name.substring(0, 1).toUpperCase(); - const strides = ShapeUtil.computeStrides(dims); - const offsets: string[] = []; - for (let i = dims.length - 1; i >= 0; i--) { - const idx = this.output.indicesGet('outputIndices', i + this.output.shape.length - dims.length); - offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); - } - implementation += `fn broadcastIndicesToOffset${name}(outputIndices: ${this.output.type.indices}) -> u32 { - return ${offsets.length > 0 ? offsets.join('+') : '0u'}; - } - `; - } - return implementation; - } -} - -export const createBroadcastHelper = (inputs: IndicesHelper[], output: IndicesHelper): BroadcastHelper => - new BroadcastHelperImpl(inputs, output); - /** * A ShaderHelper is a helper class for generating WGSL code. */ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 6917c09caddc3..745b3f8cdb20b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper} from './common'; const createWhereOpProgramShader = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, @@ -15,10 +15,9 @@ const createWhereOpProgramShader = const vecSize = Math.ceil(outputSize / 4); const output = outputVariable('outputData', typeOutput, dimsOutput, 4); - const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); - const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); - const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); - const broadcastImpl = isBroadcast ? createBroadcastHelper([a, b, c], output).broadcastIndicesToOffset() : ''; + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4, output); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4, output); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4, output); let assignment: string; const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; @@ -34,9 +33,9 @@ const createWhereOpProgramShader = const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = broadcastIndicesToOffsetA(outputIndices${x}); - let offsetB${x} = broadcastIndicesToOffsetB(outputIndices${x}); - let offsetC${x} = broadcastIndicesToOffsetC(outputIndices${x}); + let offsetA${x} = ${a.broadcastIndicesToOffset(`outputIndices${x}`)}; + let offsetB${x} = ${b.broadcastIndicesToOffset(`outputIndices${x}`)}; + let offsetC${x} = ${c.broadcastIndicesToOffset(`outputIndices${x}`)}; let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let indexC${x} = offsetC${x} / 4u; @@ -65,7 +64,6 @@ const createWhereOpProgramShader = return ` ${shaderHelper.declareVariables(c, a, b, output)} - ${broadcastImpl} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} ${assignment} From 7ad3181cf42eb3890bb2a2e228a58c70907d9a5e Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 26 Sep 2023 14:32:50 +0800 Subject: [PATCH 05/10] Remove duplicated comma --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 59210ef5e16cc..af152a742fa3d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -336,7 +336,7 @@ const createIndicesHelper = const broadcastIndicesToOffset = (varIndices: string) => { implementationUsed.broadcastIndicesToOffset = true; - return `broadcastIndicesToOffset${name}(${varIndices});`; + return `broadcastIndicesToOffset${name}(${varIndices})`; }; const indices = (...init: ReadonlyArray) => From 392d0528e81e9514c06d369149f8b2572ea73a6a Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 27 Sep 2023 08:38:08 +0800 Subject: [PATCH 06/10] Remove output --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 20 ++++++++------------ js/web/lib/wasm/jsep/webgpu/ops/where.ts | 12 ++++++------ 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index af152a742fa3d..ac58904025dcf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -106,10 +106,11 @@ export interface IndicesHelper { * WGSL code of an `u32` expression for getting original offset from broadcast indices. * * @param varIndices - a `type.indices` expression representing the output indices. + * @param output - output IndicesHelper. * * @returns an `u32` expression */ - readonly broadcastIndicesToOffset: (varIndices: string) => string; + readonly broadcastIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; /** * WGSL code of generating an indices literal @@ -257,8 +258,8 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shape: readonly number[], isInput: boolean, components: 1|2|3|4, - output?: IndicesHelper): IndicesHelper => { + (name: string, tensorType: number, shape: readonly number[], isInput: boolean, + components: 1|2|3|4): IndicesHelper => { const rank = shape.length; const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; const mappedType = getWgslMappedType(tensorType, components); @@ -321,9 +322,8 @@ const createIndicesHelper = }; let broadcastIndicesToOffsetImplementation = ''; - // Currently output is only used when there is broadcasting. - if (output) { - offsets = []; + const broadcastIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + const offsets = []; for (let i = shape.length - 1; i >= 0; i--) { const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); @@ -332,9 +332,6 @@ const createIndicesHelper = `fn broadcastIndicesToOffset${name}(outputIndices: ${output.type.indices}) -> u32 { return ${offsets.length > 0 ? offsets.join('+') : '0u'}; }`; - } - - const broadcastIndicesToOffset = (varIndices: string) => { implementationUsed.broadcastIndicesToOffset = true; return `broadcastIndicesToOffset${name}(${varIndices})`; }; @@ -538,12 +535,11 @@ const createIndicesHelper = * @param type - the tensor type of the input. * @param shape - the tensor shape of the input. * @param components - the number of components of the input. available values are 1, 2, 3, 4. default is 1. - * @param output - output IndicesHelper, used when there is broadcasting. * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1, output?: IndicesHelper): - IndicesHelper => createIndicesHelper(name, type, shape, true, components, output); + (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shape, true, components); /** * Create a IndicesHelper for an output. diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 745b3f8cdb20b..5543254e903c7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -15,9 +15,9 @@ const createWhereOpProgramShader = const vecSize = Math.ceil(outputSize / 4); const output = outputVariable('outputData', typeOutput, dimsOutput, 4); - const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4, output); - const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4, output); - const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4, output); + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); let assignment: string; const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; @@ -33,9 +33,9 @@ const createWhereOpProgramShader = const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = ${a.broadcastIndicesToOffset(`outputIndices${x}`)}; - let offsetB${x} = ${b.broadcastIndicesToOffset(`outputIndices${x}`)}; - let offsetC${x} = ${c.broadcastIndicesToOffset(`outputIndices${x}`)}; + let offsetA${x} = ${a.broadcastIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastIndicesToOffset(`outputIndices${x}`, output)}; + let offsetC${x} = ${c.broadcastIndicesToOffset(`outputIndices${x}`, output)}; let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let indexC${x} = offsetC${x} / 4u; From 830582ae9dd6769b63c921aacc7c8afd306488d2 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 27 Sep 2023 08:57:47 +0800 Subject: [PATCH 07/10] Add output name to func name --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index ac58904025dcf..49e9217a2d41f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -329,11 +329,11 @@ const createIndicesHelper = offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); } broadcastIndicesToOffsetImplementation = - `fn broadcastIndicesToOffset${name}(outputIndices: ${output.type.indices}) -> u32 { + `fn ${output.name}broadcastIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 { return ${offsets.length > 0 ? offsets.join('+') : '0u'}; }`; implementationUsed.broadcastIndicesToOffset = true; - return `broadcastIndicesToOffset${name}(${varIndices})`; + return `${output.name}broadcastIndicesTo${name}Offset(${varIndices})`; }; const indices = (...init: ReadonlyArray) => From c7c31c887d007092b2a4d09d64790c005553105b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 27 Sep 2023 09:10:50 +0800 Subject: [PATCH 08/10] broadcast to broadcasted --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 24 +++++++++++------------ js/web/lib/wasm/jsep/webgpu/ops/where.ts | 6 +++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 49e9217a2d41f..729ee78f282b2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -110,7 +110,7 @@ export interface IndicesHelper { * * @returns an `u32` expression */ - readonly broadcastIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; + readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; /** * WGSL code of generating an indices literal @@ -272,7 +272,7 @@ const createIndicesHelper = const implementationUsed = { offsetToIndices: false, indicesToOffset: false, - broadcastIndicesToOffset: false, + broadcastedIndicesToOffset: false, set: false, setByIndices: false, get: false, @@ -304,7 +304,7 @@ const createIndicesHelper = return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; }; - let offsets: string[] = []; + const offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { offsets.push(`${strides[i]}u * (indices[${i}])`); @@ -321,19 +321,19 @@ const createIndicesHelper = return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; }; - let broadcastIndicesToOffsetImplementation = ''; - const broadcastIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + let broadcastedIndicesToOffsetImplementation = ''; + const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { const offsets = []; for (let i = shape.length - 1; i >= 0; i--) { const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); } - broadcastIndicesToOffsetImplementation = - `fn ${output.name}broadcastIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 { + broadcastedIndicesToOffsetImplementation = + `fn ${output.name}broadcastedIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 { return ${offsets.length > 0 ? offsets.join('+') : '0u'}; }`; - implementationUsed.broadcastIndicesToOffset = true; - return `${output.name}broadcastIndicesTo${name}Offset(${varIndices})`; + implementationUsed.broadcastedIndicesToOffset = true; + return `${output.name}broadcastedIndicesTo${name}Offset(${varIndices})`; }; const indices = (...init: ReadonlyArray) => @@ -488,8 +488,8 @@ const createIndicesHelper = if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); } - if (implementationUsed.broadcastIndicesToOffset) { - impls.push(broadcastIndicesToOffsetImplementation); + if (implementationUsed.broadcastedIndicesToOffset) { + impls.push(broadcastedIndicesToOffsetImplementation); } if (implementationUsed.set) { impls.push(setImplementation); @@ -511,7 +511,7 @@ const createIndicesHelper = type, offsetToIndices, indicesToOffset, - broadcastIndicesToOffset, + broadcastedIndicesToOffset, indices, indicesGet, indicesSet, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 5543254e903c7..4c595bb90b4bc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -33,9 +33,9 @@ const createWhereOpProgramShader = const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = ${a.broadcastIndicesToOffset(`outputIndices${x}`, output)}; - let offsetB${x} = ${b.broadcastIndicesToOffset(`outputIndices${x}`, output)}; - let offsetC${x} = ${c.broadcastIndicesToOffset(`outputIndices${x}`, output)}; + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let indexC${x} = offsetC${x} / 4u; From f87aeb9c713a68e2b9245a82794d798c20923d1c Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 27 Sep 2023 09:15:53 +0800 Subject: [PATCH 09/10] broadcast to broadcasted in comment --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 729ee78f282b2..8bd7099e090d4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -103,7 +103,7 @@ export interface IndicesHelper { readonly indicesToOffset: (varIndices: string) => string; /** - * WGSL code of an `u32` expression for getting original offset from broadcast indices. + * WGSL code of an `u32` expression for getting original offset from broadcasted indices. * * @param varIndices - a `type.indices` expression representing the output indices. * @param output - output IndicesHelper. From 480a5b8c051ce49b4cdf2b4a9433611658fa952d Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 28 Sep 2023 10:54:35 +0800 Subject: [PATCH 10/10] Use map for broadcastedIndicesToOffsetImplementation --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 8bd7099e090d4..fb800d66b59a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -321,19 +321,24 @@ const createIndicesHelper = return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; }; - let broadcastedIndicesToOffsetImplementation = ''; + const broadcastedIndicesToOffsetImplementation: {[key: string]: string} = {}; const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + implementationUsed.broadcastedIndicesToOffset = true; + const implKey = `${output.name}broadcastedIndicesTo${name}Offset`; + if (implKey in broadcastedIndicesToOffsetImplementation) { + return `${implKey}(${varIndices})`; + } const offsets = []; for (let i = shape.length - 1; i >= 0; i--) { const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); } - broadcastedIndicesToOffsetImplementation = - `fn ${output.name}broadcastedIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 { + broadcastedIndicesToOffsetImplementation[implKey] = + `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 { return ${offsets.length > 0 ? offsets.join('+') : '0u'}; }`; - implementationUsed.broadcastedIndicesToOffset = true; - return `${output.name}broadcastedIndicesTo${name}Offset(${varIndices})`; + + return `${implKey}(${varIndices})`; }; const indices = (...init: ReadonlyArray) => @@ -489,7 +494,7 @@ const createIndicesHelper = impls.push(indicesToOffsetImplementation); } if (implementationUsed.broadcastedIndicesToOffset) { - impls.push(broadcastedIndicesToOffsetImplementation); + Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl)); } if (implementationUsed.set) { impls.push(setImplementation);