Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Support where #17544

Merged
merged 13 commits into from
Oct 3, 2023
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,4 @@ Do not modify directly.*
| Tile | ai.onnx(6-12,13+) | |
| Transpose | ai.onnx(1-12,13+) | need perf optimization |
| Unsqueeze | ai.onnx(1-10,11-12,13+) | |
| Where | ai.onnx(9-15,16+) | |
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {parseSplitAttributes, split} from './ops/split';
import {tile} from './ops/tile';
import {parseTransposeAttributes, transpose} from './ops/transpose';
import * as unaryOps from './ops/unary-op';
import {where} from './ops/where';
import {ComputeContext} from './types';

export type RunFunction = (context: ComputeContext, attribute?: unknown) => void;
Expand Down Expand Up @@ -112,4 +113,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
['Tile', [tile]],
['Transpose', [transpose, parseTransposeAttributes]],
['Where', [where]],
]);
35 changes: 35 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,16 @@ export interface IndicesHelper {
*/
readonly indicesToOffset: (varIndices: string) => string;

/**
* WGSL code of an `u32` expression for getting original offset from broadcasted indices.
*
* @param varIndices - a `type.indices` expression representing the output indices.
* @param output - output IndicesHelper.
*
* @returns an `u32` expression
*/
readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string;

/**
* WGSL code of generating an indices literal
*
Expand Down Expand Up @@ -262,6 +272,7 @@ const createIndicesHelper =
const implementationUsed = {
offsetToIndices: false,
indicesToOffset: false,
broadcastedIndicesToOffset: false,
set: false,
setByIndices: false,
get: false,
Expand Down Expand Up @@ -310,6 +321,26 @@ const createIndicesHelper =
return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
};

// Generated WGSL helper functions, keyed by helper name (one per distinct output/input pair).
const broadcastedIndicesToOffsetImplementation: {[key: string]: string} = {};

/**
 * Returns a WGSL call expression that maps indices of `output` to an element offset in this
 * variable, following broadcasting rules.
 *
 * The WGSL helper function is generated the first time a given output is seen and stored in
 * `broadcastedIndicesToOffsetImplementation`; later calls for the same output reuse it.
 *
 * @param varIndices - WGSL expression holding the output indices.
 * @param output - IndicesHelper of the (broadcasted) output.
 * @returns a WGSL `u32` expression calling the generated helper.
 */
const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => {
  implementationUsed.broadcastedIndicesToOffset = true;
  const implKey = `${output.name}broadcastedIndicesTo${name}Offset`;
  if (implKey in broadcastedIndicesToOffsetImplementation) {
    return `${implKey}(${varIndices})`;
  }
  const offsets: string[] = [];
  for (let i = shape.length - 1; i >= 0; i--) {
    // Dimensions are right-aligned: dim `i` of this variable pairs with dim
    // `i + output.shape.length - shape.length` of the output.
    const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length);
    // The modulo makes a dimension of size 1 always contribute offset 0.
    offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`);
  }
  broadcastedIndicesToOffsetImplementation[implKey] =
      `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 {
         return ${offsets.length > 0 ? offsets.join('+') : '0u'};
       }`;

  return `${implKey}(${varIndices})`;
};

const indices = (...init: ReadonlyArray<number|string>) =>
rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`;

Expand Down Expand Up @@ -462,6 +493,9 @@ const createIndicesHelper =
if (implementationUsed.indicesToOffset) {
impls.push(indicesToOffsetImplementation);
}
if (implementationUsed.broadcastedIndicesToOffset) {
Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl));
}
if (implementationUsed.set) {
impls.push(setImplementation);
}
Expand All @@ -482,6 +516,7 @@ const createIndicesHelper =
type,
offsetToIndices,
indicesToOffset,
broadcastedIndicesToOffset,
indices,
indicesGet,
indicesSet,
Expand Down
110 changes: 110 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';

/**
 * Builds the WGSL shader source for the `Where` operator.
 *
 * Every variable is declared with 4 components, so one invocation produces four output elements.
 * `inputs[0]` is the condition (`cData`), `inputs[1]` is the value used where the condition is true
 * (`aData`) and `inputs[2]` the value used where it is false (`bData`).
 *
 * @param shaderHelper - helper used to emit declarations, the entry point and the bounds guard.
 * @param inputs - the three operator inputs, in the order [condition, a, b].
 * @param dimsOutput - shape of the output tensor.
 * @param isBroadcast - whether the inputs need per-element broadcasting to `dimsOutput`.
 * @param typeOutput - data type of the output tensor.
 * @returns the WGSL source code.
 */
const createWhereOpProgramShader =
    (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
     typeOutput: number) => {
      const outputSize = ShapeUtil.size(dimsOutput);
      const vecSize = Math.ceil(outputSize / 4);

      const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
      const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
      const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
      const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);

      let assignment: string;
      // WGSL `select(f, t, cond)` yields `t` where `cond` is true, hence the swapped a/b order.
      const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
      if (!isBroadcast) {
        // All shapes are equal: operate on whole 4-element vectors at the same offset.
        assignment = output.setByOffset(
            'global_idx',
            expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
      } else {
        // Emits the WGSL computing output lane `x` (0..3) of the current vector into `resStr[x]`.
        const singleAssignment = (resStr: string, x: number, typeCast = '') => {
          // NOTE(review): aData/bData are indexed per component here even when they hold bool data —
          // confirm the variable helper's storage layout supports that for bool tensors.
          const expressionA = `aData[indexA${x}][componentA${x}]`;
          const expressionB = `bData[indexB${x}][componentB${x}]`;
          // The condition packs four values into one u32. The byte to test is determined by the
          // condition's own (broadcasted) offset, not by the output lane.
          const expressionC = `bool(cData[indexC${x}] & (0xffu << (componentC${x} * 8u)))`;
          return `
            let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
            let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
            let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
            let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
            let indexA${x} = offsetA${x} / 4u;
            let indexB${x} = offsetB${x} / 4u;
            let indexC${x} = offsetC${x} / 4u;
            let componentA${x} = offsetA${x} % 4u;
            let componentB${x} = offsetB${x} % 4u;
            let componentC${x} = offsetC${x} % 4u;
            ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
          `;
        };
        if (typeOutput === DataType.bool) {
          // Collect the four lanes as u32 values, then pack them into one u32 (one byte per lane)
          // via a dot product with the byte weights.
          assignment = `
            var data = vec4<u32>(0);
            ${singleAssignment('data', 0, 'u32')}
            ${singleAssignment('data', 1, 'u32')}
            ${singleAssignment('data', 2, 'u32')}
            ${singleAssignment('data', 3, 'u32')}
            outputData[global_idx] = dot(vec4<u32>(0x1, 0x100, 0x10000, 0x1000000), vec4<u32>(data));`;
        } else {
          assignment = `
            ${singleAssignment('outputData[global_idx]', 0)}
            ${singleAssignment('outputData[global_idx]', 1)}
            ${singleAssignment('outputData[global_idx]', 2)}
            ${singleAssignment('outputData[global_idx]', 3)}
          `;
        }
      }

      return `
        ${shaderHelper.declareVariables(c, a, b, output)}
        ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
        ${assignment}
      }`;
    };

/**
 * Creates the program info (shader source generator, output description and dispatch size) for
 * the `Where` operator.
 *
 * @param metadata - program metadata that is spread into the returned program info.
 * @param inputs - the three operator inputs, in the order [condition, a, b].
 * @returns the program info for this set of inputs.
 * @throws Error if the three input shapes cannot be broadcast to a common shape.
 */
const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
  const dimsA = inputs[1].dims;
  const dimsB = inputs[2].dims;
  const dimsC = inputs[0].dims;
  const outputDataType = inputs[1].dataType;

  const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
  let outputShape = dimsA;
  let outputSize = ShapeUtil.size(dimsA);
  // TODO: deal with zero-sized tensors (eg. dims=[1,0])

  if (isBroadcast) {
    // Broadcast A with B first, then the result with C. Either step may fail, and both failures
    // must surface as the same descriptive error instead of passing `undefined` into the second call.
    const shapeAB = BroadcastUtil.calcShape(dimsA, dimsB, false);
    const calculatedShape = shapeAB ? BroadcastUtil.calcShape(shapeAB, dimsC, false) : undefined;
    if (!calculatedShape) {
      throw new Error('Can\'t perform where op on the given tensors');
    }
    outputShape = calculatedShape;
    outputSize = ShapeUtil.size(outputShape);
  }

  return {
    ...metadata,
    getShaderSource: (shaderHelper) =>
        createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
    outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}],
    dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)})
  };
};

/**
 * Wraps the `Where` program metadata together with a lazy `get` callback that builds the
 * program info for the given inputs.
 *
 * @param inputs - the operator inputs forwarded to the program info factory.
 * @param name - the program name stored in the metadata.
 * @returns a loader exposing the metadata fields plus `get`.
 */
const createWhereOpProgramInfoLoader = (inputs: readonly TensorView[], name: string): ProgramInfoLoader => {
  const programMetadata: ProgramMetadata = {
    name,
    // One entry per operator input.
    inputTypes: [GpuDataType.default, GpuDataType.default, GpuDataType.default],
  };
  const buildProgramInfo = () => createWhereOpProgramInfo(programMetadata, inputs);
  return {...programMetadata, get: buildProgramInfo};
};

/**
 * Entry point of the `Where` operator: builds the program loader from the context's inputs and
 * hands it to `context.compute`.
 *
 * @param context - the compute context providing `inputs` and `compute`.
 */
export const where = (context: ComputeContext): void => {
  const programLoader = createWhereOpProgramInfoLoader(context.inputs, 'Where');
  context.compute(programLoader);
};
172 changes: 172 additions & 0 deletions js/web/test/data/ops/where.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
[
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] float32 T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0],
"dims": [8],
"type": "float32"
},
{
"data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0],
"dims": [8],
"type": "float32"
}
],
"outputs": [
{
"data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0],
"dims": [8],
"type": "float32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] int32 T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [4, 8, 7, 2, 4, 8, 7, 1],
"dims": [8],
"type": "int32"
},
{
"data": [1, 3, 9, 6, 1, 3, 9, 2],
"dims": [8],
"type": "int32"
}
],
"outputs": [
{
"data": [4, 3, 7, 6, 4, 3, 7, 2],
"dims": [8],
"type": "int32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] uint32 T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [4, 8, 7, 2, 4, 8, 7, 1],
"dims": [8],
"type": "uint32"
},
{
"data": [1, 4294967295, 9, 6, 1, 3, 9, 2],
"dims": [8],
"type": "uint32"
}
],
"outputs": [
{
"data": [4, 4294967295, 7, 6, 4, 3, 7, 2],
"dims": [8],
"type": "uint32"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3] T[3] T[3] bool T[3] ",
"inputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
},
{
"data": [true, true, true, true, true, true, true, true],
"dims": [8],
"type": "bool"
},
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
}
],
"outputs": [
{
"data": [true, false, true, false, true, false, true, false],
"dims": [8],
"type": "bool"
}
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[3 3] T[3 3] T[1] float32 broadcast",
"inputs": [
{
"data": [true, true, true, true, true, false, false, false, false],
"dims": [3, 3],
"type": "bool"
},
{
"data": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"dims": [3, 3],
"type": "float32"
},
{
"data": [-1.0],
"dims": [1],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 1, 2, 3, 4, -1, -1, -1, -1],
"dims": [3, 3],
"type": "float32"
}
]
}
]
}
]
Loading