From 8406185904b15920d875d287e4b14647c71e5356 Mon Sep 17 00:00:00 2001
From: Xu Xing
Date: Thu, 14 Sep 2023 09:07:13 +0800
Subject: [PATCH] [js/webgpu] Support where

Supported types: float, int32_t, uint32_t, bool.

Case where_broadcast.jsonc is not enabled due to
https://github.com/microsoft/onnxruntime/issues/17405.
---
 js/web/docs/webgpu-operators.md              |   1 +
 .../lib/wasm/jsep/webgpu/op-resolve-rules.ts |   2 +
 .../wasm/jsep/webgpu/ops/binary-like-util.ts | 120 ++++++++++++
 js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 175 ++++--------------
 js/web/lib/wasm/jsep/webgpu/ops/common.ts    |  37 ++++
 js/web/lib/wasm/jsep/webgpu/ops/where.ts     |  94 ++++++++++
 js/web/test/data/ops/where.jsonc             | 172 +++++++++++++++++
 js/web/test/data/ops/where_broadcast.jsonc   |  56 ++++++
 js/web/test/suite-test-list.jsonc            |   5 +-
 .../providers/js/js_execution_provider.cc    |   5 +
 .../core/providers/js/operators/where.cc     |  41 ++++
 11 files changed, 568 insertions(+), 140 deletions(-)
 create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/binary-like-util.ts
 create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/where.ts
 create mode 100644 js/web/test/data/ops/where.jsonc
 create mode 100644 js/web/test/data/ops/where_broadcast.jsonc
 create mode 100644 onnxruntime/core/providers/js/operators/where.cc

diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index de53f943bc9ef..10f2c797d2342 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -91,3 +91,4 @@ Do not modify directly.*
 | Tile | ai.onnx(6-12,13+) | |
 | Transpose | ai.onnx(1-12,13+) | need perf optimization |
 | Unsqueeze | ai.onnx(1-10,11-12,13+) | |
+| Where | ai.onnx(9-15,16+) | |
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 9c46b97694903..5d16f88c0cb9e 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -23,6 +23,7 @@ import {parseSoftmaxAttributes, softmax} from './ops/softmax';
 import {parseSplitAttributes, split} from './ops/split';
 import {tile} from './ops/tile';
 import {parseTransposeAttributes, transpose} from './ops/transpose';
+import {where} from './ops/where';
 import * as unaryOps from './ops/unary-op';
 
 import {ComputeContext} from './types';
@@ -108,4 +109,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
   ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
   ['Tile', [tile]],
   ['Transpose', [transpose, parseTransposeAttributes]],
+  ['Where', [where]],
 ]);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-like-util.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-like-util.ts
new file mode 100644
index 0000000000000..53cdf2d72f76d
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-like-util.ts
@@ -0,0 +1,120 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; + +type BuiltinFunctionName = string; +export type BinaryCustomExpression = (expressionA: string, expressionB: string, expressionC?: string) => string; +export type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ + scalar: BinaryCustomExpression; + vector: BinaryCustomExpression; +}; + +type CreateOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], vectorize: boolean, + doBroadcast: boolean, funcCall: BinaryFunctionCall, typeOutput: number, additionalImplementation?: string) => + string; + +/* eslint-disable no-param-reassign */ +const createOpProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], funcCall: BinaryFunctionCall, + createOpProgramShader: CreateOpProgramShader, additionalImplementation?: string, + outputDataType?: number): ProgramInfo => { + const a = inputs.length === 3 ? inputs[1] : inputs[0]; + const b = inputs.length === 3 ? inputs[2] : inputs[1]; + if (outputDataType == null) { + outputDataType = inputs.length === 3 ? inputs[1].dataType : inputs[0].dataType; + } + + const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); + let outputShape = a.dims; + let outputSize = ShapeUtil.size(a.dims); + + let vectorize = false; + + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); + if (!calculatedShape) { + throw new Error('Can\'t perform binary op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + + // check whether vectorize can be enabled + let sharedDimension = 1; + for (let i = 1; i < outputShape.length; i++) { + const dimA = a.dims[a.dims.length - i] ?? 1; + const dimB = b.dims[b.dims.length - i] ?? 1; + if (dimA === dimB) { + sharedDimension *= dimA; + } else { + break; + } + } + if (sharedDimension % 4 === 0) { + vectorize = true; + } + } else { + // element-wise + vectorize = true; + } + + return { + ...metadata, + getShaderSource: (shaderHelper) => createOpProgramShader( + shaderHelper, inputs, outputShape, vectorize, isBroadcast, funcCall, outputDataType as number, + additionalImplementation), + outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => + ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) + }; + }; + +// This is used for ops like binary, where. +export const createOpProgramInfoLoader = + (inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, + createOpProgramShader: CreateOpProgramShader, additionalImplementation?: string, cacheKey?: string, + outputDataType?: number): ProgramInfoLoader => { + const inputTypes = inputs.length === 3 ? 
[GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default]; + const metadata: ProgramMetadata = {name, inputTypes, cacheHint: cacheKey}; + return { + ...metadata, + get: () => createOpProgramInfo( + metadata, inputs, funcCall, createOpProgramShader, additionalImplementation, outputDataType) + }; + }; + +export const getBroadcastIndexComponent = (name: string, x: number) => (` + let offset${name}${x} = broadcastIndicesToOffset${name}(outputIndices${x}); + let index${name}${x} = offset${name}${x} / 4u; + let component${name}${x} = offset${name}${x} % 4u; + `); + +type SingleAssignmentFuncCall = (resStr: string, x: number, typeCast?: string) => string; +export const fourAssignment = (singleAssignment: SingleAssignmentFuncCall, typeOutput: number) => { + let assignment = ''; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } + return assignment; +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index b004ca37a2ea8..1a7b3da3c9af9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -3,22 +3,19 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; -import {BroadcastUtil, ShapeUtil} from '../../util'; -import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; - -import {inputVariable, outputVariable, ShaderHelper} from './common'; - -type BuiltinFunctionName = string; -type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; -type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ - scalar: BinaryCustomExpression; - vector: BinaryCustomExpression; -}; - -const createBinaryOpProgramShader = - (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], - vectorize: boolean, doBroadcast: boolean, funcCall: BinaryFunctionCall, typeA: number, typeB: number, - typeOutput: number, additionalImplementation?: string) => { +import {ShapeUtil} from '../../util'; +import {ComputeContext} from '../types'; + +import {BinaryCustomExpression, BinaryFunctionCall, createOpProgramInfoLoader, fourAssignment, getBroadcastIndexComponent} from './binary-like-util'; +import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +const createOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], vectorize: boolean, + doBroadcast: boolean, funcCall: BinaryFunctionCall, typeOutput: number, additionalImplementation?: string) => { + const typeA = inputs[0].dataType; + const typeB = inputs[1].dataType; + const dimsA = inputs[0].dims; + const dimsB = inputs[1].dims; const outputSize = ShapeUtil.size(dimsOutput); const vecSize = Math.ceil(outputSize / 4); @@ -33,39 +30,18 @@ const createBinaryOpProgramShader = expressionVector = funcCall.vector; } - let broadcastImpl = ''; const output = 
outputVariable('outputData', typeOutput, dimsOutput, 4); const a = inputVariable('aData', typeA, dimsA, 4); const b = inputVariable('bData', typeB, dimsB, 4); - if (doBroadcast) { - const calcOffsetImpl = (dims: readonly number[]) => { - const strides = ShapeUtil.computeStrides(dims); - const offsets: string[] = []; - for (let i = dims.length - 1; i >= 0; i--) { - const idx = output.indicesGet('outputIndices', i + dimsOutput.length - dims.length); - offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); - } - return offsets.length > 0 ? offsets.join('+') : '0u'; - }; - - broadcastImpl = ` - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsA)}; - } - - fn calcOffsetB(outputIndices: ${output.type.indices}) -> u32 { - return ${calcOffsetImpl(dimsB)}; - } - `; - } + const broadcastImpl = doBroadcast ? createBroadcastHelper([a, b], output).broadcastIndicesToOffset() : ''; let assignment: string; if (vectorize) { if (doBroadcast) { assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; - let offsetA = calcOffsetA(outputIndices); - let offsetB = calcOffsetB(outputIndices); + let offsetA = broadcastIndicesToOffsetA(outputIndices); + let offsetB = broadcastIndicesToOffsetB(outputIndices); ${ output.setByOffset( 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} @@ -84,31 +60,12 @@ const createBinaryOpProgramShader = const expressionB = `bData[indexB${x}][componentB${x}]`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = calcOffsetA(outputIndices${x}); - let offsetB${x} = calcOffsetB(outputIndices${x}); - let indexA${x} = offsetA${x} / 4u; - let indexB${x} = offsetB${x} / 4u; - let componentA${x} = offsetA${x} % 4u; - let componentB${x} = offsetB${x} % 4u; + ${getBroadcastIndexComponent('A', x)} + ${getBroadcastIndexComponent('B', x)} ${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB)}); `; }; - if (typeOutput === DataType.bool) { - assignment = ` - var data = vec4(0); - ${singleAssignment('data', 0, 'u32')} - ${singleAssignment('data', 1, 'u32')} - ${singleAssignment('data', 2, 'u32')} - ${singleAssignment('data', 3, 'u32')} - outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; - } else { - assignment = ` - ${singleAssignment('outputData[global_idx]', 0)} - ${singleAssignment('outputData[global_idx]', 1)} - ${singleAssignment('outputData[global_idx]', 2)} - ${singleAssignment('outputData[global_idx]', 3)} - `; - } + assignment = fourAssignment(singleAssignment, typeOutput); } return ` @@ -123,91 +80,31 @@ const createBinaryOpProgramShader = }`; }; -const createBinaryOpProgramInfo = - (metadata: ProgramMetadata, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall, - additionalImplementation?: string, outputDataType: number = a.dataType): ProgramInfo => { - const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); - let outputShape = a.dims; - let outputSize = ShapeUtil.size(a.dims); - - let vectorize = false; - - // TODO: deal with zero-sized tensors (eg. 
dims=[1,0]) - - if (isBroadcast) { - const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); - if (!calculatedShape) { - throw new Error('Can\'t perform binary op on the given tensors'); - } - outputShape = calculatedShape; - outputSize = ShapeUtil.size(outputShape); - - // check whether vectorize can be enabled - let sharedDimension = 1; - for (let i = 1; i < outputShape.length; i++) { - const dimA = a.dims[a.dims.length - i] ?? 1; - const dimB = b.dims[b.dims.length - i] ?? 1; - if (dimA === dimB) { - sharedDimension *= dimA; - } else { - break; - } - } - if (sharedDimension % 4 === 0) { - vectorize = true; - } - } else { - // element-wise - vectorize = true; - } - - return { - ...metadata, - getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, - outputDataType, additionalImplementation), - outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => - ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) - }; - }; - -const createBinaryOpProgramInfoLoader = - (inputs: readonly TensorView[], name: string, funcCall: BinaryFunctionCall, additionalImplementation?: string, - cacheKey?: string, outputDataType?: number): ProgramInfoLoader => { - const metadata: - ProgramMetadata = {name, inputTypes: [GpuDataType.default, GpuDataType.default], cacheHint: cacheKey}; - return { - ...metadata, - get: () => createBinaryOpProgramInfo( - metadata, inputs[0], inputs[1], funcCall, additionalImplementation, outputDataType) - }; - }; - export const add = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`)); + context.compute(createOpProgramInfoLoader(context.inputs, 'Add', (a, b) => `${a}+${b}`, createOpProgramShader)); }; export const div = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`)); + context.compute(createOpProgramInfoLoader(context.inputs, 'Div', (a, b) => `${a}/${b}`, createOpProgramShader)); }; export const equal = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( + context.compute(createOpProgramInfoLoader( context.inputs, 'Equal', ({scalar: (a, b) => `u32(${a}==${b})`, vector: (a, b) => `vec4(${a}==${b})`}), - undefined, undefined, DataType.bool)); + createOpProgramShader, undefined, undefined, DataType.bool)); }; export const mul = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`)); + context.compute(createOpProgramInfoLoader(context.inputs, 'Mul', (a, b) => `${a}*${b}`, createOpProgramShader)); }; export const pow = (context: ComputeContext): void => { const type = inputVariable('input', context.inputs[0].dataType, context.inputs[0].dims).type.value; const roundStr = type === 'i32' ? 
'round' : ''; - context.compute(createBinaryOpProgramInfoLoader( + context.compute(createOpProgramInfoLoader( context.inputs, 'Pow', ({scalar: (a, b) => `pow_custom(${a},${b})`, vector: (a, b) => `pow_vector_custom(${a},${b})`}), + createOpProgramShader, ` fn pow_custom(a : ${type}, b : ${type}) -> ${type} { if (b == ${type}(0.0)) { @@ -226,30 +123,30 @@ export const pow = (context: ComputeContext): void => { }; export const sub = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`)); + context.compute(createOpProgramInfoLoader(context.inputs, 'Sub', (a, b) => `${a}-${b}`, createOpProgramShader)); }; export const greater = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( + context.compute(createOpProgramInfoLoader( context.inputs, 'Greater', ({scalar: (a, b) => `u32(${a}>${b})`, vector: (a, b) => `vec4(${a}>${b})`}), - undefined, undefined, DataType.bool)); + createOpProgramShader, undefined, undefined, DataType.bool)); }; export const less = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( + context.compute(createOpProgramInfoLoader( context.inputs, 'Less', ({scalar: (a, b) => `u32(${a}<${b})`, vector: (a, b) => `vec4(${a}<${b})`}), - undefined, undefined, DataType.bool)); + createOpProgramShader, undefined, undefined, DataType.bool)); }; export const greaterOrEqual = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( + context.compute(createOpProgramInfoLoader( context.inputs, 'GreaterOrEqual', - ({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})`}), undefined, undefined, - DataType.bool)); + ({scalar: (a, b) => `u32(${a}>=${b})`, vector: (a, b) => `vec4(${a}>=${b})`}), createOpProgramShader, + undefined, undefined, DataType.bool)); }; export const lessOrEqual = (context: ComputeContext): void => { - context.compute(createBinaryOpProgramInfoLoader( + context.compute(createOpProgramInfoLoader( context.inputs, 'LessOrEqual', ({scalar: (a, b) => `u32(${a}<=${b})`, vector: (a, b) => `vec4(${a}<=${b})`}), - undefined, undefined, DataType.bool)); + createOpProgramShader, undefined, undefined, DataType.bool)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index f3845e3110905..2ad9921907503 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -524,6 +524,43 @@ export const outputVariable = (name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper => createIndicesHelper(name, type, shape, false, components); +/** + * A helper class for generating WGSL code for manipulating broadcast indices for a shader's input. + */ +export interface BroadcastHelper { + /** + * WGSL code for getting offset from broadcast indices. 
+ * + */ + broadcastIndicesToOffset(): string; +} + +class BroadcastHelperImpl implements BroadcastHelper { + constructor(private inputs: IndicesHelper[], private output: IndicesHelper) {} + + broadcastIndicesToOffset(): string { + let implementation = ''; + for (let j = 0; j < this.inputs.length; j++) { + const dims = this.inputs[j].shape; + const name = this.inputs[j].name.substring(0, 1).toUpperCase(); + const strides = ShapeUtil.computeStrides(dims); + const offsets: string[] = []; + for (let i = dims.length - 1; i >= 0; i--) { + const idx = this.output.indicesGet('outputIndices', i + this.output.shape.length - dims.length); + offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`); + } + implementation += `fn broadcastIndicesToOffset${name}(outputIndices: ${this.output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + } + `; + } + return implementation; + } +} + +export const createBroadcastHelper = (inputs: IndicesHelper[], output: IndicesHelper): BroadcastHelper => + new BroadcastHelperImpl(inputs, output); + /** * A ShaderHelper is a helper class for generating WGSL code. */ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts new file mode 100644 index 0000000000000..6c6e975a49c4d --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {ComputeContext} from '../types'; + +import {BinaryCustomExpression, BinaryFunctionCall, createOpProgramInfoLoader, fourAssignment, getBroadcastIndexComponent} from './binary-like-util'; +import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +const createOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], vectorize: boolean, + doBroadcast: boolean, funcCall: BinaryFunctionCall, typeOutput: number, additionalImplementation?: string) => { + const typeA = inputs[1].dataType; + const typeB = inputs[2].dataType; + const typeC = inputs[0].dataType; + const dimsA = inputs[1].dims; + const dimsB = inputs[2].dims; + const dimsC = inputs[0].dims; + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + let expressionScalar: BinaryCustomExpression; + let expressionVector: BinaryCustomExpression; + if (typeof funcCall === 'string') { + expressionScalar = expressionVector = (a, b, c) => `${funcCall}((${a}),(${b}),(${c}))`; + } else if (typeof funcCall === 'function') { + expressionScalar = expressionVector = funcCall; + } else { + expressionScalar = funcCall.scalar; + expressionVector = funcCall.vector; + } + + const output = outputVariable('outputData', typeOutput, dimsOutput, 4); + const a = inputVariable('aData', typeA, dimsA, 4); + const b = inputVariable('bData', typeB, dimsB, 4); + const c = inputVariable('cData', typeC, dimsC, 4); + + const broadcastImpl = doBroadcast ? 
createBroadcastHelper([a, b, c], output).broadcastIndicesToOffset() : ''; + + let assignment: string; + if (vectorize) { + if (doBroadcast) { + assignment = ` + let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; + let offsetA = broadcastIndicesToOffsetA(outputIndices); + let offsetB = broadcastIndicesToOffsetB(outputIndices); + ${ + output.setByOffset( + 'global_idx', + expressionVector( + a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u'), c.getByOffset('offsetC / 4u')))}`; + } else { + assignment = output.setByOffset( + 'global_idx', + expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); + } + } else { + if (!doBroadcast) { + throw new Error('no necessary to use scalar implementation for element-wise binary op implementation.'); + } + + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + return ` + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + ${getBroadcastIndexComponent('A', x)} + ${getBroadcastIndexComponent('B', x)} + ${getBroadcastIndexComponent('C', x)} + ${resStr}[${x}] = ${typeCast}(${expressionScalar(expressionA, expressionB, expressionC)}); + `; + }; + assignment = fourAssignment(singleAssignment, typeOutput); + } + + return ` + ${shaderHelper.declareVariables(c, a, b, output)} + + ${additionalImplementation ?? ''} + ${broadcastImpl} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +export const where = (context: ComputeContext): void => { + context.compute(createOpProgramInfoLoader( + context.inputs, 'Where', (a, b, c) => `select(${b}, ${a}, ${c})`, createOpProgramShader)); +}; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc new file mode 100644 index 0000000000000..047fd6fd7511b --- /dev/null +++ b/js/web/test/data/ops/where.jsonc @@ -0,0 +1,172 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] float32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0], + "dims": [8], + "type": "float32" + }, + { + "data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] int32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "int32" + }, + { + "data": [1, 3, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 3, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] uint32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" 
+ }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "uint32" + }, + { + "data": [1, 4294967295, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [4, 4294967295, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] bool T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [true, true, true, true, true, true, true, true], + "dims": [8], + "type": "float32" + }, + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3 3] T[3 3] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, true, true, true, true, false, false, false, false], + "dims": [3, 3], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where_broadcast.jsonc b/js/web/test/data/ops/where_broadcast.jsonc new file mode 100644 index 0000000000000..4310e979cc334 --- /dev/null +++ b/js/web/test/data/ops/where_broadcast.jsonc @@ -0,0 +1,56 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 6] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [ + true, + true, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + true, + true, + true, + true, + true, + true + ], + "dims": [3, 6], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index f4249b24101e5..f0fe49a37d25e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1378,7 +1378,10 @@ "tan.jsonc", "tile.jsonc", "transpose.jsonc", - "transpose_int32_uint32.jsonc" + "transpose_int32_uint32.jsonc", + "where.jsonc" + // Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed. 
+ //"where_broadcast.jsonc", //"xor.jsonc" ] }, diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6b548921cdc8c..223b586470766 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -229,6 +229,9 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); @@ -576,6 +579,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), }; diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc new file mode 100644 index 0000000000000..2f8f5e275aa98 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/where.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +JSEP_KERNEL_IMPL(Where, Where) +REG_ELEMENTWISE_VERSIONED_KERNEL(Where, 9, 15, Where); +REG_ELEMENTWISE_KERNEL(Where, 16, Where); +} // namespace js +} // namespace onnxruntime