Skip to content

Commit

Permalink
[js/webgpu] Support where
Browse files Browse the repository at this point in the history
Supported type: float. int32_t, uint32_t, bool.
Case where_broadcast.jsonc is not enabled due to microsoft#17405.
  • Loading branch information
axinging committed Sep 21, 2023
1 parent 4f3f436 commit cdb5a52
Show file tree
Hide file tree
Showing 9 changed files with 458 additions and 1 deletion.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,4 @@ Do not modify directly.*
| Tile | ai.onnx(6-12,13+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |
| Where | ai.onnx(9-15,16+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import {where} from './ops/where';
import * as unaryOps from './ops/unary-op';
import {ComputeContext} from './types';

Expand Down Expand Up @@ -110,4 +111,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
['Tile', [tile]],
['Transpose', [transpose, parseTransposeAttributes]],
['Where', [where]],
]);
37 changes: 37 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,43 @@ export const outputVariable =
(name: string, type: number, shape: readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shape, false, components);

/**
* A helper class for generating WGSL code for manipulating broadcast indices for a shader's input.
*/
export interface BroadcastHelper {
/**
* WGSL code for getting offset from broadcast indices.
*
*/
broadcastIndicesToOffset(): string;
}

class BroadcastHelperImpl implements BroadcastHelper {
constructor(private inputs: IndicesHelper[], private output: IndicesHelper) {}

broadcastIndicesToOffset(): string {
let implementation = '';
for (let j = 0; j < this.inputs.length; j++) {
const dims = this.inputs[j].shape;
const name = this.inputs[j].name.substring(0, 1).toUpperCase();
const strides = ShapeUtil.computeStrides(dims);
const offsets: string[] = [];
for (let i = dims.length - 1; i >= 0; i--) {
const idx = this.output.indicesGet('outputIndices', i + this.output.shape.length - dims.length);
offsets.push(`${strides[i]}u * (${idx} % ${dims[i]}u)`);
}
implementation += `fn broadcastIndicesToOffset${name}(outputIndices: ${this.output.type.indices}) -> u32 {
return ${offsets.length > 0 ? offsets.join('+') : '0u'};
}
`;
}
return implementation;
}
}

export const createBroadcastHelper = (inputs: IndicesHelper[], output: IndicesHelper): BroadcastHelper =>
new BroadcastHelperImpl(inputs, output);

/**
* A ShaderHelper is a helper class for generating WGSL code.
*/
Expand Down
111 changes: 111 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {createBroadcastHelper, inputVariable, outputVariable, ShaderHelper} from './common';

const createWhereOpProgramShader =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
typeOutput: number) => {
const outputSize = ShapeUtil.size(dimsOutput);
const vecSize = Math.ceil(outputSize / 4);

const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);
const broadcastImpl = isBroadcast ? createBroadcastHelper([a, b, c], output).broadcastIndicesToOffset() : '';

let assignment: string;
const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
if (!isBroadcast) {
assignment = output.setByOffset(
'global_idx',
expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
} else {
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
const expressionA = `aData[indexA${x}][componentA${x}]`;
const expressionB = `bData[indexB${x}][componentB${x}]`;
// eslint-disable-next-line no-bitwise
const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = broadcastIndicesToOffsetA(outputIndices${x});
let offsetB${x} = broadcastIndicesToOffsetB(outputIndices${x});
let offsetC${x} = broadcastIndicesToOffsetC(outputIndices${x});
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let indexC${x} = offsetC${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
let componentB${x} = offsetB${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
if (typeOutput === DataType.bool) {
assignment = `
var data = vec4<u32>(0);
${singleAssignment('data', 0, 'u32')}
${singleAssignment('data', 1, 'u32')}
${singleAssignment('data', 2, 'u32')}
${singleAssignment('data', 3, 'u32')}
outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
} else {
assignment = `
${singleAssignment('outputData[global_idx]', 0)}
${singleAssignment('outputData[global_idx]', 1)}
${singleAssignment('outputData[global_idx]', 2)}
${singleAssignment('outputData[global_idx]', 3)}
`;
}
}

return `
${shaderHelper.declareVariables(c, a, b, output)}
${broadcastImpl}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
${assignment}
}`;
};

const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
const dimsA = inputs[1].dims;
const dimsB = inputs[2].dims;
const dimsC = inputs[0].dims;
const outputDataType = inputs[1].dataType;

const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
let outputShape = dimsA;
let outputSize = ShapeUtil.size(dimsA);
// TODO: deal with zero-sized tensors (eg. dims=[1,0])

if (isBroadcast) {
const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false);
if (!calculatedShape) {
throw new Error('Can\'t perform where op on the given tensors');
}
outputShape = calculatedShape;
outputSize = ShapeUtil.size(outputShape);
}

return {
...metadata,
getShaderSource: (shaderHelper) =>
createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}],
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (isBroadcast ? 1 : 4) /* vec size */)})
};
};

const createWhereOpProgramInfoLoader = (inputs: readonly TensorView[], name: string): ProgramInfoLoader => {
const inputTypes = [GpuDataType.default, GpuDataType.default, GpuDataType.default];
const metadata: ProgramMetadata = {name, inputTypes};
return {...metadata, get: () => createWhereOpProgramInfo(metadata, inputs)};
};

export const where = (context: ComputeContext): void => {
context.compute(createWhereOpProgramInfoLoader(context.inputs, 'Where'));
};
172 changes: 172 additions & 0 deletions js/web/test/data/ops/where.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
[
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] float32 T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0],
"dims": [8],
"type": "float32"
},
{
"data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0],
"dims": [8],
"type": "float32"
}
],
"outputs": [
{
"data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0],
"dims": [8],
"type": "float32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] int32 T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [4, 8, 7, 2, 4, 8, 7, 1],
"dims": [8],
"type": "int32"
},
{
"data": [1, 3, 9, 6, 1, 3, 9, 2],
"dims": [8],
"type": "int32"
}
],
"outputs": [
{
"data": [4, 3, 7, 6, 4, 3, 7, 2],
"dims": [8],
"type": "int32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] uint32 T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [4, 8, 7, 2, 4, 8, 7, 1],
"dims": [8],
"type": "uint32"
},
{
"data": [1, 4294967295, 9, 6, 1, 3, 9, 2],
"dims": [8],
"type": "uint32"
}
],
"outputs": [
{
"data": [4, 4294967295, 7, 6, 4, 3, 7, 2],
"dims": [8],
"type": "uint32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] bool T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [true, true, true, true, true, true, true, true],
"dims": [8],
"type": "float32"
},
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "float32"
}
],
"outputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "float32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3 3] T[3 3] T[1] float32 broadcast",
"inputs": [
{
"data": [true, true, true, true, true, false, false, false, false],
"dims": [3, 3],
"type": "bool"
},
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [3, 3],
"type": "float32"
},
{
"data": [-1.0],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 1, 2, 3, 4, -1, -1, -1, -1],
"dims": [3, 3],
"type": "float32"
}
]
}
]
}
]
Loading

0 comments on commit cdb5a52

Please sign in to comment.