Skip to content

Commit

Permalink
[js/webgpu] Add scatterND (#22755)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
axinging authored Nov 13, 2024
1 parent bc2b1b5 commit ff57ac4
Show file tree
Hide file tree
Showing 9 changed files with 793 additions and 3 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ Do not modify directly.*
| Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
| RotaryEmbedding | com.microsoft(1+) | |
| ScatterND | ai.onnx(11-12,13-15,16-17,18+) | |
| Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix |
| Sigmoid | ai.onnx(6-12,13+) | |
| SimplifiedLayerNormalization | ai.onnx(1+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import { pad } from './ops/pad';
import * as pool from './ops/pool';
import { dequantizeLinear, parseDequantizeLinearAttributes } from './ops/quantize-linear';
import { range } from './ops/range';
import { scatterND, parseScatterNDAttributes } from './ops/scatter-nd';
import {
reduceL1,
reduceL2,
Expand Down Expand Up @@ -140,6 +141,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Relu', [unaryOps.relu]],
['Resize', [resize, parseResizeAttributes]],
['RotaryEmbedding', [rotaryEmbedding]],
['ScatterND', [scatterND, parseScatterNDAttributes]],
['Sigmoid', [unaryOps.sigmoid]],
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],
Expand Down
19 changes: 16 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ export interface IndicesHelper {
/**
* whether the helper is for an input, an output or an internal variable.
*/
readonly usage: 'input' | 'output' | 'internal';
readonly usage: 'input' | 'output' | 'atomicOutput' | 'internal';

/**
* the rank of the input or output.
Expand Down Expand Up @@ -733,6 +733,20 @@ export const outputVariable = (
components: 1 | 2 | 3 | 4 = 1,
): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'output', components);

/**
* Create a IndicesHelper for an atomic output.
*
* @param name - the name of the output.
* @param type - the tensor type of the output.
* @param shapeOrRank - the tensor shape or the rank of the output.
* @returns an IndicesHelper for the output.
*/
export const atomicOutputVariable = (
name: string,
type: number,
shapeOrRank: number | readonly number[],
): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'atomicOutput', 1);

/**
* Create a IndicesHelper for an internal variable.
*
Expand Down Expand Up @@ -905,9 +919,8 @@ class ShaderHelperImpl implements ShaderHelper {
}
this.variables.push(variable);
this.appendVariableUniforms(variable);

const access = variable.usage === 'input' ? 'read' : 'read_write';
const storageType = variable.type.storage;
const storageType = variable.usage === 'atomicOutput' ? `atomic<i32>` : variable.type.storage;
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
}

Expand Down
177 changes: 177 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';

import {
atomicOutputVariable,
createTensorShapeVariables,
inputVariable,
outputVariable,
ShaderHelper,
} from './common';

export interface ScatterNDAttributes extends AttributeWithCacheKey {
reduction: string;
}

type ReductionType = 'i32' | 'u32' | 'f32';

const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: ReductionType) => {
if (reduction !== 'none' && type !== 'i32' && type !== 'u32' && type !== 'f32') {
throw new Error(`Input ${type} is not supported with reduction ${reduction}.`);
}

const floatStart = `{
var oldValue = 0;
loop {
let newValueF32 =`;
const floatEnd = `;
let newValue = bitcast<i32>(newValueF32);
let res = atomicCompareExchangeWeak(&${ptr}, oldValue, newValue);
if res.exchanged {
break;
}
oldValue = res.old_value;
}
}`;

switch (reduction) {
case 'none':
return `${ptr}=${v};`;
case 'add':
if (type === 'i32' || type === 'u32') {
return `atomicAdd(&${ptr}, bitcast<${type}>(${v}));`;
} else {
// atomicAdd only supports uint/int type. For float, we use
// atomicCompareExchangeWeak to simulate.
return `
${floatStart}bitcast<${type}>(oldValue) + (${v})${floatEnd}`;
}
case 'max':
if (type === 'i32' || type === 'u32') {
return `atomicMax(&${ptr}, bitcast<${type}>(${v}));`;
} else {
// atomicMax only supports uint/int type. For float, we use
// atomicCompareExchangeWeak to simulate.
return `
${floatStart}max(bitcast<f32>(oldValue), (${v}))${floatEnd}`;
}
case 'min':
if (type === 'i32' || type === 'u32') {
return `atomicMin(&${ptr}, bitcast<${type}>(${v}));`;
} else {
// atomicMin only supports uint/int type. For float, we use
// atomicCompareExchangeWeak to simulate.
return `${floatStart}min(bitcast<${type}>(oldValue), (${v}))${floatEnd}`;
}
case 'mul':
// atomicMul is not supported, we use atomicCompareExchangeWeak to simulate.
return `${floatStart}(bitcast<${type}>(oldValue) * (${v}))${floatEnd}`;

default:
throw new Error(`Reduction ${reduction} is not supported.`);
}
};

const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => {
const inputShape = inputs[0].dims;
const indicesShape = inputs[1].dims;
const outputShape = inputShape;
// TODO: support bool with components 4.
const components = 1;
const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components);
const lastIndexDimension = indicesShape[indicesShape.length - 1];
const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension);

const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: lastIndexDimension },
{ type: DataType.uint32, data: numUpdatesElements },
...createTensorShapeVariables(inputs[1].dims, inputs[2].dims, outputShape),
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
const indices = inputVariable('indices', inputs[1].dataType, inputs[1].dims.length);
const updates = inputVariable('updates', inputs[2].dataType, inputs[2].dims.length, components);
const output =
attributes.reduction !== 'none' && attributes.reduction !== ''
? atomicOutputVariable('output', inputs[0].dataType, outputShape.length)
: outputVariable('output', inputs[0].dataType, outputShape.length, components);

return `
${shaderHelper
.registerUniform('output_size', 'u32')
.registerUniform('last_index_dimension', 'u32')
.registerUniform('num_updates_elements', 'u32')
.declareVariables(indices, updates, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
var data_offset = 0u;
let indices_start = uniforms.last_index_dimension * global_idx;
let indices_end = indices_start + uniforms.last_index_dimension;
for (var i = indices_start; i < indices_end; i++) {
var index = i32(indices[i].x);
${
inputs[0].dims.length === 1
? `
let element_count_dim = uniforms.output_strides;
let dim_value = uniforms.output_shape;`
: `
let element_count_dim = uniforms.output_strides[i - indices_start];
let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];`
}
if (index >= 0) {
if (index >= i32(dim_value)) {
index = i32(dim_value - 1);
}
} else {
if (index < -i32(dim_value)) {
index = 0;
} else {
index += i32(dim_value);
}
}
data_offset += u32((u32(index) * element_count_dim));
}
for (var i = 0u; i < uniforms.num_updates_elements; i++) {
let value = updates[uniforms.num_updates_elements * global_idx + i];
${atomicReductionSnippet(
attributes.reduction,
'output[data_offset + i]',
'value',
output.type.value as ReductionType,
)}
}
}`;
};
return {
name: 'ScatterND',
shaderCache: {
hint: `${attributes.cacheKey}_${attributes.reduction}`,
inputDependencies: ['rank', 'rank'],
},
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
getShaderSource,
};
};

export const parseScatterNDAttributes = (attributes: Record<string, unknown>): ScatterNDAttributes =>
createAttributeWithCacheKey({ reduction: attributes.reduction as string });

export const scatterND = (context: ComputeContext, attributes: ScatterNDAttributes): void => {
context.compute(createScatterNDProgramInfo(context.inputs, attributes), {
inputs: [context.inputs[1], context.inputs[2]],
outputs: [],
});
};
Loading

0 comments on commit ff57ac4

Please sign in to comment.