From ff57ac4f3ddcc535aef30d8e638360a9f9600cec Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 14 Nov 2024 01:13:00 +0800 Subject: [PATCH] [js/webgpu] Add scatterND (#22755) ### Description ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + js/web/lib/wasm/jsep/webgpu/ops/common.ts | 19 +- js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts | 177 +++++++ js/web/test/data/ops/scatternd.jsonc | 472 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + .../providers/js/js_execution_provider.cc | 8 + .../core/providers/js/operators/scatter_nd.cc | 51 ++ .../core/providers/js/operators/scatter_nd.h | 65 +++ 9 files changed, 793 insertions(+), 3 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts create mode 100644 js/web/test/data/ops/scatternd.jsonc create mode 100644 onnxruntime/core/providers/js/operators/scatter_nd.cc create mode 100644 onnxruntime/core/providers/js/operators/scatter_nd.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 0ad4c6db8cd97..f63cf17aa4df3 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -94,6 +94,7 @@ Do not modify directly.* | Reshape | ai.onnx(5-12,13,14-18,19-20,21+) | no GPU kernel | | Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | RotaryEmbedding | com.microsoft(1+) | | +| ScatterND | ai.onnx(11-12,13-15,16-17,18+) | | | Shape | ai.onnx(1-12,13-14,15-18,19-20,21+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | SimplifiedLayerNormalization | ai.onnx(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 669b8fcb5a7cf..28af5d461abe0 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -30,6 +30,7 @@ import { pad } from './ops/pad'; import * as pool from './ops/pool'; import { dequantizeLinear, parseDequantizeLinearAttributes } from './ops/quantize-linear'; import { range } from './ops/range'; +import { scatterND, parseScatterNDAttributes } from './ops/scatter-nd'; import { reduceL1, reduceL2, @@ -140,6 +141,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Relu', [unaryOps.relu]], ['Resize', [resize, parseResizeAttributes]], ['RotaryEmbedding', [rotaryEmbedding]], + ['ScatterND', [scatterND, parseScatterNDAttributes]], ['Sigmoid', [unaryOps.sigmoid]], ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c985f98e524a4..0b9173403cd7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -195,7 +195,7 @@ export interface IndicesHelper { /** * whether the helper is for an input, an output or an internal variable. */ - readonly usage: 'input' | 'output' | 'internal'; + readonly usage: 'input' | 'output' | 'atomicOutput' | 'internal'; /** * the rank of the input or output. @@ -733,6 +733,20 @@ export const outputVariable = ( components: 1 | 2 | 3 | 4 = 1, ): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'output', components); +/** + * Create a IndicesHelper for an atomic output. + * + * @param name - the name of the output. + * @param type - the tensor type of the output. + * @param shapeOrRank - the tensor shape or the rank of the output. + * @returns an IndicesHelper for the output. + */ +export const atomicOutputVariable = ( + name: string, + type: number, + shapeOrRank: number | readonly number[], +): IndicesHelper => createIndicesHelper(name, type, shapeOrRank, 'atomicOutput', 1); + /** * Create a IndicesHelper for an internal variable. * @@ -905,9 +919,8 @@ class ShaderHelperImpl implements ShaderHelper { } this.variables.push(variable); this.appendVariableUniforms(variable); - const access = variable.usage === 'input' ? 'read' : 'read_write'; - const storageType = variable.type.storage; + const storageType = variable.usage === 'atomicOutput' ? `atomic` : variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts new file mode 100644 index 0000000000000..8c24232d63c0c --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/scatter-nd.ts @@ -0,0 +1,177 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { DataType } from '../../../wasm-common'; +import { TensorView } from '../../tensor-view'; +import { ShapeUtil } from '../../util'; +import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; + +import { + atomicOutputVariable, + createTensorShapeVariables, + inputVariable, + outputVariable, + ShaderHelper, +} from './common'; + +export interface ScatterNDAttributes extends AttributeWithCacheKey { + reduction: string; +} + +type ReductionType = 'i32' | 'u32' | 'f32'; + +const atomicReductionSnippet = (reduction: string, ptr: string, v: string, type: ReductionType) => { + if (reduction !== 'none' && type !== 'i32' && type !== 'u32' && type !== 'f32') { + throw new Error(`Input ${type} is not supported with reduction ${reduction}.`); + } + + const floatStart = `{ + var oldValue = 0; + loop { + let newValueF32 =`; + const floatEnd = `; + let newValue = bitcast(newValueF32); + let res = atomicCompareExchangeWeak(&${ptr}, oldValue, newValue); + if res.exchanged { + break; + } + oldValue = res.old_value; + } + }`; + + switch (reduction) { + case 'none': + return `${ptr}=${v};`; + case 'add': + if (type === 'i32' || type === 'u32') { + return `atomicAdd(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicAdd only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return ` + ${floatStart}bitcast<${type}>(oldValue) + (${v})${floatEnd}`; + } + case 'max': + if (type === 'i32' || type === 'u32') { + return `atomicMax(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicMax only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return ` + ${floatStart}max(bitcast(oldValue), (${v}))${floatEnd}`; + } + case 'min': + if (type === 'i32' || type === 'u32') { + return `atomicMin(&${ptr}, bitcast<${type}>(${v}));`; + } else { + // atomicMin only supports uint/int type. For float, we use + // atomicCompareExchangeWeak to simulate. + return `${floatStart}min(bitcast<${type}>(oldValue), (${v}))${floatEnd}`; + } + case 'mul': + // atomicMul is not supported, we use atomicCompareExchangeWeak to simulate. + return `${floatStart}(bitcast<${type}>(oldValue) * (${v}))${floatEnd}`; + + default: + throw new Error(`Reduction ${reduction} is not supported.`); + } +}; + +const createScatterNDProgramInfo = (inputs: readonly TensorView[], attributes: ScatterNDAttributes): ProgramInfo => { + const inputShape = inputs[0].dims; + const indicesShape = inputs[1].dims; + const outputShape = inputShape; + // TODO: support bool with components 4. + const components = 1; + const outputSize = Math.ceil(ShapeUtil.size(indicesShape) / components); + const lastIndexDimension = indicesShape[indicesShape.length - 1]; + const numUpdatesElements = ShapeUtil.sizeFromDimension(inputShape, lastIndexDimension); + + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: outputSize }, + { type: DataType.uint32, data: lastIndexDimension }, + { type: DataType.uint32, data: numUpdatesElements }, + ...createTensorShapeVariables(inputs[1].dims, inputs[2].dims, outputShape), + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const indices = inputVariable('indices', inputs[1].dataType, inputs[1].dims.length); + const updates = inputVariable('updates', inputs[2].dataType, inputs[2].dims.length, components); + const output = + attributes.reduction !== 'none' && attributes.reduction !== '' + ? atomicOutputVariable('output', inputs[0].dataType, outputShape.length) + : outputVariable('output', inputs[0].dataType, outputShape.length, components); + + return ` + ${shaderHelper + .registerUniform('output_size', 'u32') + .registerUniform('last_index_dimension', 'u32') + .registerUniform('num_updates_elements', 'u32') + .declareVariables(indices, updates, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var data_offset = 0u; + let indices_start = uniforms.last_index_dimension * global_idx; + let indices_end = indices_start + uniforms.last_index_dimension; + for (var i = indices_start; i < indices_end; i++) { + var index = i32(indices[i].x); + ${ + inputs[0].dims.length === 1 + ? ` + let element_count_dim = uniforms.output_strides; + let dim_value = uniforms.output_shape;` + : ` + let element_count_dim = uniforms.output_strides[i - indices_start]; + let dim_value = uniforms.output_shape[i - indices_start + uniforms.last_index_dimension];` + } + if (index >= 0) { + if (index >= i32(dim_value)) { + index = i32(dim_value - 1); + } + } else { + if (index < -i32(dim_value)) { + index = 0; + } else { + index += i32(dim_value); + } + } + data_offset += u32((u32(index) * element_count_dim)); + } + + for (var i = 0u; i < uniforms.num_updates_elements; i++) { + let value = updates[uniforms.num_updates_elements * global_idx + i]; + ${atomicReductionSnippet( + attributes.reduction, + 'output[data_offset + i]', + 'value', + output.type.value as ReductionType, + )} + } + + }`; + }; + return { + name: 'ScatterND', + shaderCache: { + hint: `${attributes.cacheKey}_${attributes.reduction}`, + inputDependencies: ['rank', 'rank'], + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType: inputs[0].dataType }], + dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) }, + programUniforms, + }), + getShaderSource, + }; +}; + +export const parseScatterNDAttributes = (attributes: Record): ScatterNDAttributes => + createAttributeWithCacheKey({ reduction: attributes.reduction as string }); + +export const scatterND = (context: ComputeContext, attributes: ScatterNDAttributes): void => { + context.compute(createScatterNDProgramInfo(context.inputs, attributes), { + inputs: [context.inputs[1], context.inputs[2]], + outputs: [], + }); +}; diff --git a/js/web/test/data/ops/scatternd.jsonc b/js/web/test/data/ops/scatternd.jsonc new file mode 100644 index 0000000000000..5135bb9e4d3a5 --- /dev/null +++ b/js/web/test/data/ops/scatternd.jsonc @@ -0,0 +1,472 @@ +[ + { + "name": "ScatterND int32", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 13 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9, 10, 11, 12], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 11, 3, 10, 9, 6, 7, 12], + "dims": [8], + "type": "int32" + } + ] + }, + { + "name": "int32", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [4, 4, 4], + "type": "int32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "int64" + }, + { + "data": [ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, + 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131 + ], + "dims": [2, 4, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, + 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, + 131, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [4, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND float32", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 13 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.1, 11.3, 3.1, 10.2, 9.1, 6.1, 7.8, 12.5], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND add int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "add", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9, 10, 11, 12], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 13, 3, 14, 14, 6, 7, 20], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND add float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "add", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 13.5, 3.0999999046325684, 14.699999809265137, 14.40000057220459, 6.099999904632568, + 7.800000190734863, 21.399999618530273 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND mul int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 2486, 31, 4590, 4823, 61, 78, 11125], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND min int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "min", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND max int32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "max", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "int32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "int32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [11, 113, 31, 102, 91, 61, 78, 125], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "ScatterND mul float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 24.860000610351562, 3.0999999046325684, 45.89999771118164, 48.230003356933594, + 6.099999904632568, 7.800000190734863, 111.24999237060547 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND min float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "min", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 2.200000047683716, 3.0999999046325684, 4.5, 5.300000190734863, 6.099999904632568, + 7.800000190734863, 8.899999618530273 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND max float32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "max", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "float32", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.100000023841858, 11.300000190734863, 3.0999999046325684, 10.199999809265137, 9.100000381469727, + 6.099999904632568, 7.800000190734863, 12.5 + ], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ScatterND float16", + "operator": "ScatterND", + "attributes": [], + "opset": { "domain": "", "version": 11 }, + "cases": [ + { + "name": "float16", + "inputs": [ + { + "data": [1.1, 2.2, 3.1, 4.5, 5.3, 6.1, 7.8, 8.9], + "dims": [8], + "type": "float16" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [9.1, 10.2, 11.3, 12.5], + "dims": [1, 4], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.1, 11.3, 3.1, 10.2, 9.1, 6.1, 7.8, 12.5], + "dims": [8], + "type": "float16" + } + ] + } + ] + }, + { + "name": "ScatterND mul uint32", + "operator": "ScatterND", + "attributes": [{ "name": "reduction", "data": "mul", "type": "string" }], + "opset": { "domain": "", "version": 16 }, + "cases": [ + { + "name": "uint32", + "inputs": [ + { + "data": [11, 22, 31, 45, 53, 61, 78, 89], + "dims": [8], + "type": "uint32" + }, + { + "data": [4, 3, 1, 7], + "dims": [1, 4, 1], + "type": "int64" + }, + { + "data": [91, 102, 113, 125], + "dims": [1, 4], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [11, 2486, 31, 4590, 4823, 61, 78, 11125], + "dims": [8], + "type": "uint32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index be2e0002ca748..45fb771ee13bb 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1396,6 +1396,7 @@ "pow-big-number.jsonc", "reshape.jsonc", "rotary-embedding.jsonc", + "scatternd.jsonc", "simplified-layer-norm.jsonc", "skip-layer-norm.jsonc", "skip-simplified-layer-norm.jsonc", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 19a3029fca28f..c0d62bf47a0dd 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -402,6 +402,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 15, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ScatterND); std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -734,6 +738,10 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/scatter_nd.cc b/onnxruntime/core/providers/js/operators/scatter_nd.cc new file mode 100644 index 0000000000000..e9edb7f58fe5e --- /dev/null +++ b/onnxruntime/core/providers/js/operators/scatter_nd.cc @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "core/providers/js/js_data_types.h" +#include "scatter_nd.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + ScatterND, + kOnnxDomain, + 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 16, + 17, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 13, + 15, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ScatterND, + kOnnxDomain, + 11, + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedDataTypes()), + ScatterND); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/scatter_nd.h b/onnxruntime/core/providers/js/operators/scatter_nd.h new file mode 100644 index 0000000000000..8c81c62d71fe7 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/scatter_nd.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" +#include "core/framework/data_transfer_manager.h" + +namespace onnxruntime { +namespace js { + +enum class ScatterNDReduction : int { + None = 0, + Add = 1, + Mul = 2, + Min = 3, + Max = 4, +}; + +class ScatterND : public JsKernel { + public: + ScatterND(const OpKernelInfo& info) : JsKernel(info) { + std::string reduction = info.GetAttrOrDefault("reduction", "none"); + if (reduction == "add") { + reduction_ = ScatterNDReduction::Add; + } else if (reduction == "mul") { + reduction_ = ScatterNDReduction::Mul; + } else if (reduction == "min") { + reduction_ = ScatterNDReduction::Min; + } else if (reduction == "max") { + reduction_ = ScatterNDReduction::Max; + } else if (reduction == "none") { + LOGS_DEFAULT(WARNING) << "ScatterND with reduction=='none' only guarantees " + << "to be correct if indices are not duplicated."; + } else { + ORT_THROW("Reduction '", reduction, "' is not supported on webgpu when opset <= 13."); + } + + JSEP_INIT_KERNEL_ATTRIBUTE(ScatterND, ({ + "reduction" : UTF8ToString($1), + }), + reduction.c_str()); + } + + Status Compute(OpKernelContext* context) const override { + const Tensor* X = context->Input(0); + if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + const TensorShape& X_shape = X->Shape(); + + Tensor* Y = context->Output(0, X_shape); + const void* source = X->DataRaw(); + void* target = Y->MutableDataRaw(); + // If source and target pointers are not equal (non-inplace operation), we need to copy the data. + if (target != source) { + ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*X, *Y)); + } + return ComputeInternal(context); + } + + private: + ScatterNDReduction reduction_{ScatterNDReduction::None}; +}; + +} // namespace js +} // namespace onnxruntime