diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..8da03c3cb967a 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -93,3 +93,4 @@ Do not modify directly.* | Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Where | ai.onnx(9-15,16+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9a78..658e8442c3614 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -24,6 +24,7 @@ import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; +import {where} from './ops/where'; import * as unaryOps from './ops/unary-op'; import {ComputeContext} from './types'; @@ -110,4 +111,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], + ['Where', [where]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c054da51a3098..7b82dc769535f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -524,6 +524,43 @@ export const outputVariable = (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shape, false, components); +/** + * A helper class for generating WGSL code for manipulating broadcast indices for a shader's input. + */ +export interface BroadcastHelper { + /** + * WGSL code for getting offset from broadcast indices. + * + */ + broadcastIndicesToOffset(): string; +} + +class BroadcastHelperImpl implements BroadcastHelper { + constructor(private inputs: IndicesHelper[], private output: IndicesHelper) {} + + broadcastIndicesToOffset(): string { + let implementation = ''; + for (let j = 0; j < this.inputs.length; j++) { + const dims = this.inputs[j].shape; + const name = this.inputs[j].name.substring(0, 1).toUpperCase(); + const strides = ShapeUtil.computeStrides(dims); + const offsets: string[] = []; + for (let i = dims.length - 1; i >= 0; i--) { + const idx = this.output.indicesGet('outputIndices', i + this.output.shape.length - dims.length); + offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); + } + implementation += `fn broadcastIndicesToOffset${name}(outputIndices: ${this.output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + } + `; + } + return implementation; + } +} + +export const createBroadcastHelper = (inputs: IndicesHelper[], output: IndicesHelper): BroadcastHelper => + new BroadcastHelperImpl(inputs, output); + /** * A ShaderHelper is a helper class for generating WGSL code. */ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts new file mode 100644 index 0000000000000..cbfa56bba72a2 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; +import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +const createWhereOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], doBroadcast: boolean, + typeOutput: number) => { + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + const output = outputVariable('outputData', typeOutput, dimsOutput, 4); + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + const broadcastImpl = doBroadcast ? createBroadcastHelper([a, b, c], output).broadcastIndicesToOffset() : ''; + + let assignment: string; + const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; + if (!doBroadcast) { + assignment = output.setByOffset( + 'global_idx', + expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); + } else { + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + return ` + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offsetA${x} = broadcastIndicesToOffsetA(outputIndices${x}); + let offsetB${x} = broadcastIndicesToOffsetB(outputIndices${x}); + let offsetC${x} = broadcastIndicesToOffsetC(outputIndices${x}); + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let indexC${x} = offsetC${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); + `; + }; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } + } + + return ` + ${shaderHelper.declareVariables(c, a, b, output)} + ${broadcastImpl} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const dimsA = inputs[1].dims; + const dimsB = inputs[2].dims; + const dimsC = inputs[0].dims; + const outputDataType = inputs[1].dataType; + + const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); + let outputShape = dimsA; + let outputSize = ShapeUtil.size(dimsA); + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); + if (!calculatedShape) { + throw new Error('Can\'t perform where op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + } + + return { + ...metadata, + getShaderSource: (shaderHelper) => + createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), + outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (isBroadcast ? 1 : 4) /* vec size */)}) + }; +}; + +const createWhereOpProgramInfoLoader = (inputs: readonly TensorView[], name: string): ProgramInfoLoader => { + const inputTypes = [GpuDataType.default, GpuDataType.default, GpuDataType.default]; + const metadata: ProgramMetadata = {name, inputTypes}; + return {...metadata, get: () => createWhereOpProgramInfo(metadata, inputs)}; +}; + +export const where = (context: ComputeContext): void => { + context.compute(createWhereOpProgramInfoLoader(context.inputs, 'Where')); +}; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc new file mode 100644 index 0000000000000..047fd6fd7511b --- /dev/null +++ b/js/web/test/data/ops/where.jsonc @@ -0,0 +1,172 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] float32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0], + "dims": [8], + "type": "float32" + }, + { + "data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] int32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "int32" + }, + { + "data": [1, 3, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 3, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] uint32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "uint32" + }, + { + "data": [1, 4294967295, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [4, 4294967295, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] bool T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [true, true, true, true, true, true, true, true], + "dims": [8], + "type": "float32" + }, + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3 3] T[3 3] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, true, true, true, true, false, false, false, false], + "dims": [3, 3], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where_broadcast.jsonc b/js/web/test/data/ops/where_broadcast.jsonc new file mode 100644 index 0000000000000..ad97177bb101b --- /dev/null +++ b/js/web/test/data/ops/where_broadcast.jsonc @@ -0,0 +1,84 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 6] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [ + true, + true, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + true, + true, + true, + true, + true, + true + ], + "dims": [3, 6], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + }, + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 1] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, false, true], + "dims": [3, 1], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4756..bec48b3ac2ad7 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1386,7 +1386,10 @@ "tan.jsonc", "tile.jsonc", "transpose.jsonc", - "transpose_int32_uint32.jsonc" + "transpose_int32_uint32.jsonc", + "where.jsonc" + // Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed. + //"where_broadcast.jsonc", //"xor.jsonc" ] }, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 0674fe02d093d..7f6e065001488 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -229,6 +229,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); @@ -585,6 +588,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc new file mode 100644 index 0000000000000..2f8f5e275aa98 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/where.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +JSEP_KERNEL_IMPL(Where, Where) +REG_ELEMENTWISE_VERSIONED_KERNEL(Where, 9, 15, Where); +REG_ELEMENTWISE_KERNEL(Where, 16, Where); +} // namespace js +} // namespace onnxruntime