From b5ee4ac7609cb2d9115a3a2f6798da26dea99624 Mon Sep 17 00:00:00 2001 From: xhcao Date: Sat, 9 Nov 2024 03:02:36 +0800 Subject: [PATCH] [js/webgpu] support GridSample operator (#22652) ### Description ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + .../lib/wasm/jsep/webgpu/ops/grid-sample.ts | 279 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 16 +- .../providers/js/js_execution_provider.cc | 6 + .../providers/js/operators/grid_sample.cc | 30 ++ .../core/providers/js/operators/grid_sample.h | 32 ++ 7 files changed, 358 insertions(+), 8 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts create mode 100644 onnxruntime/core/providers/js/operators/grid_sample.cc create mode 100644 onnxruntime/core/providers/js/operators/grid_sample.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 5f329b5ff8b39..0ad4c6db8cd97 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -56,6 +56,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| GridSample | ai.onnx(16-19); com.ms.internal.nhwc(16-19) | | | GroupQueryAttention | com.microsoft(1+) | | | HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19-20,21+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 09c786daa3fcd..669b8fcb5a7cf 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -19,6 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather'; import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized'; import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements'; import { gemm, parseGemmAttributes } from 
'./ops/gemm'; +import { gridSample, parseGridSampleAttributes } from './ops/grid-sample'; import { groupQueryAttention } from './ops/group-query-attention'; import { instanceNorm } from './ops/instance-norm'; import { layerNorm } from './ops/layer-norm'; @@ -104,6 +105,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['GridSample', [gridSample, parseGridSampleAttributes]], ['GroupQueryAttention', [groupQueryAttention]], ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts b/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts new file mode 100644 index 0000000000000..50c71472434ad --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/grid-sample.ts @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+import { DataType } from '../../../wasm-common';
+import { TensorView } from '../../tensor-view';
+import { ShapeUtil } from '../../util';
+import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
+import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
+
+import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
+
+// Positions of the batch / channel / height / width axes inside a 4-D shape.
+// They start out describing NCHW. createGridSampleProgramInfo reassigns them from
+// attributes.format, and the WGSL generators in this file interpolate them into shader text.
+let [idxN, idxC, idxH, idxW] = [0, 1, 2, 3]; // NCHW
+type Mode = 'bilinear' | 'nearest' | 'bicubic';
+type PaddingMode = 'zeros' | 'border' | 'reflection';
+type Format = 'NHWC' | 'NCHW';
+
+/**
+ * Attributes of the GridSample operator as consumed by this file.
+ *
+ * NOTE(review): the exported name is misspelled ("Sampe"); it is left unchanged here because
+ * the other definitions in this file use this exact name in their signatures.
+ */
+export interface GridSampeAttributes extends AttributeWithCacheKey {
+  /** 0 selects the "pixel centre" mapping in gs_denormalize, any other value the "corner" mapping. */
+  alignCorners: number;
+  mode: Mode;
+  paddingMode: PaddingMode;
+  format: Format;
+}
+
+/**
+ * Checks the operator inputs before a program is built.
+ *
+ * @param inputs inputs[0] is the data tensor, inputs[1] is the sampling grid.
+ * @throws Error when there are not exactly two inputs, the data tensor is not 4-D, the ranks
+ *     differ, the last grid dimension is not (rank - 2), or the batch sizes differ.
+ */
+const validateInputs = (inputs: readonly TensorView[]): void => {
+  // Guard first: every check below dereferences inputs[1].
+  if (inputs.length !== 2) {
+    throw new Error('GridSample requires 2 inputs.');
+  }
+  const [data, grid] = inputs;
+  const rank = data.dims.length;
+
+  if (rank !== 4) {
+    throw new Error('only 4-D tensor is supported.');
+  }
+  if (rank !== grid.dims.length) {
+    throw new Error('input dimensions must be equal to grid dimensions');
+  }
+
+  // The innermost grid dimension carries one coordinate per spatial axis.
+  if (rank - 2 !== grid.dims[grid.dims.length - 1]) {
+    throw new Error(`last dimension of grid must be equal to ${rank - 2}`);
+  }
+
+  if (data.dims[0] !== grid.dims[0]) {
+    throw new Error('grid batch size must match input batch size');
+  }
+};
+
+// WGSL helper: returns the four cubic interpolation weights for the offset x,
+// computed with cubic_alpha = -0.75.
+const gsGetCubicCoeffs = `
+  fn gs_get_cubic_coeffs(x: f32) -> vec4<f32> {
+    let cubic_alpha = -0.75f;
+    let x_abs = abs(x);
+    var coeffs: vec4<f32>;
+    coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha);
+    coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1);
+    coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1);
+    coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha);
+    return coeffs;
+  }
+`;
+
+// WGSL helper: bicubic interpolation over a 4x4 neighbourhood `p`.
+// The weights are f32, so every pixel is widened with f32(...) before the multiply
+// (WGSL has no implicit f16 -> f32 conversion) and the result is cast back to dataType.
+const gsBicubicInterpolate = (dataType: string): string => `
+  fn gs_bicubic_interpolate(p: mat4x4<${dataType}>, x: f32, y: f32) -> ${dataType} {
+    var v: vec4<f32>;
+    var coeffs = gs_get_cubic_coeffs(x);
+    for (var i = 0; i < 4; i++) {
+      v[i] = coeffs[0] * f32(p[i][0]) + coeffs[1] * f32(p[i][1]) + coeffs[2] * f32(p[i][2]) + coeffs[3] * f32(p[i][3]);
+    }
+    coeffs = gs_get_cubic_coeffs(y);
+    let pixel = ${dataType}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
+    return pixel;
+  }
+`;
+
+// WGSL helper: maps a normalized grid coordinate in [-1, 1] to a pixel coordinate.
+// Which of the two formulas is emitted is decided here, from attributes.alignCorners.
+const gsDenormalize = (attributes: GridSampeAttributes): string => `
+  fn gs_denormalize(n: f32, length: i32) -> f32 {
+    ${
+      attributes.alignCorners === 0
+        ? `
+    // alignCorners: false => [-1, 1] to [-0.5, length - 0.5]
+    return ((n + 1.0) * f32(length) - 1.0) / 2.0;
+    `
+        : `
+    // alignCorners: true => [-1, 1] to [0, length - 1]
+    return (n + 1.0) / 2.0 * (f32(length - 1));
+    `
+    }
+  }
+`;
+
+// WGSL helper, emitted only for 'reflection' padding (other modes contribute no function).
+// gs_reflect folds a coordinate lying outside [x_min, x_max] back into that interval:
+// it measures the overshoot, splits it into whole ranges (n) plus a remainder (r), and
+// uses the parity of n to decide from which end the remainder is measured.
+const gsReflect = (attributes: GridSampeAttributes): string => `
+  ${
+    attributes.paddingMode === 'reflection'
+      ? `
+      fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 {
+        var dx = 0.0;
+        var fx = f32(x);
+        let range = x_max - x_min;
+        if (fx < x_min) {
+          dx = x_min - fx;
+          let n = u32(dx / range);
+          let r = dx - f32(n) * range;
+          if (n % 2 == 0) {
+            fx = x_min + r;
+          } else {
+            fx = x_max - r;
+          }
+        } else if (fx > x_max) {
+          dx = fx - x_max;
+          let n = u32(dx / range);
+          let r = dx - f32(n) * range;
+          if (n % 2 == 0) {
+            fx = x_max - r;
+          } else {
+            fx = x_min + r;
+          }
+        }
+        return u32(fx);
+      }`
+      : ''
+  }
+`;
+
+/**
+ * Generates the WGSL function pixel_at_grid, which reads one input element at integer
+ * position (r, c) for the given batch and channel, applying attributes.paddingMode:
+ *  - 'zeros': positions outside [0, H) x [0, W) yield 0 (no input element is read);
+ *  - 'border': r and c are clamped into the valid range;
+ *  - 'reflection': r and c are folded by gs_reflect using the `border` limits.
+ *
+ * The axis positions idxN/idxC/idxH/idxW are baked into the generated text, so they must
+ * already describe the layout of `input` when this function is called.
+ *
+ * @throws Error for an unknown padding mode.
+ */
+const pixelAtGrid = (input: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string =>
+  `
+  fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4<f32>) -> ${dataType} {
+    var indices = vec4<u32>(0);
+    indices[${idxN}] = batch;
+    indices[${idxC}] = channel;` +
+  (() => {
+    switch (attributes.paddingMode) {
+      case 'zeros':
+        // Return early for out-of-range taps; falling through would read element (0, 0).
+        return `
+    if (r < 0 || r >= H || c < 0 || c >= W) {
+      return ${dataType}(0);
+    }
+    indices[${idxH}] = u32(r);
+    indices[${idxW}] = u32(c);
+        `;
+      case 'border':
+        return `
+    indices[${idxH}] = u32(clamp(r, 0, H - 1));
+    indices[${idxW}] = u32(clamp(c, 0, W - 1));
+        `;
+      case 'reflection':
+        return `
+    indices[${idxH}] = gs_reflect(r, border[1], border[3]);
+    indices[${idxW}] = gs_reflect(c, border[0], border[2]);
+        `;
+      default:
+        throw new Error(`padding mode ${attributes.paddingMode} is not supported`);
+    }
+  })() +
+  `
+    return ${input.getByIndices('indices')};
+  }
+`;
+
+/**
+ * Generates the WGSL statements that compute `result` for the current output element from the
+ * denormalized coordinates `x`, `y` and store it at `global_idx`:
+ *  - 'nearest': one tap at the rounded coordinate;
+ *  - 'bilinear': four taps around (x, y), weighted by the distance to each;
+ *  - 'bicubic': a 4x4 neighbourhood passed to gs_bicubic_interpolate.
+ *
+ * The emitted code expects `x`, `y`, `H_in`, `W_in`, `indices` and `border` to be in scope.
+ *
+ * @throws Error for an unknown interpolation mode.
+ */
+const computePixel = (output: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string =>
+  (() => {
+    switch (attributes.mode) {
+      case 'nearest':
+        return `
+          let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${idxN}], indices[${idxC}], border);
+        `;
+      case 'bilinear':
+        return `
+          let x1 = i32(floor(x));
+          let y1 = i32(floor(y));
+          let x2 = x1 + 1;
+          let y2 = y1 + 1;
+
+          let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
+          let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
+          let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
+          let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
+
+          let dx2 = ${dataType}(f32(x2) - x);
+          let dx1 = ${dataType}(x - f32(x1));
+          let dy2 = ${dataType}(f32(y2) - y);
+          let dy1 = ${dataType}(y - f32(y1));
+          let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
+        `;
+      case 'bicubic':
+        return `
+          let x0 = i32(floor(x)) - 1;
+          let y0 = i32(floor(y)) - 1;
+          var p: mat4x4<${dataType}>;
+          for (var h = 0; h < 4; h++) {
+            for (var w = 0; w < 4; w++) {
+              p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
+            }
+          }
+
+          let dx = x - f32(x0 + 1);
+          let dy = y - f32(y0 + 1);
+          let result = gs_bicubic_interpolate(p, dx, dy);
+        `;
+      default:
+        throw new Error(`mode ${attributes.mode} is not supported`);
+    }
+  })() + `${output.setByOffset('global_idx', 'result')}`;
+
+/**
+ * Builds the ProgramInfo for one GridSample invocation.
+ *
+ * @param inputs inputs[0] is the data tensor, inputs[1] the grid (already validated by the caller).
+ * @param attributes parsed operator attributes; `format` selects NCHW or NHWC indexing.
+ * @returns the program description: one output with the data tensor's type, whose spatial
+ *     dimensions come from the grid, dispatched with one invocation per output element.
+ */
+const createGridSampleProgramInfo = (inputs: readonly TensorView[], attributes: GridSampeAttributes): ProgramInfo => {
+  const isNhwc = attributes.format === 'NHWC';
+  // [N, C, H, W] axis positions for this call's layout.
+  const axisOrder: [number, number, number, number] = isNhwc ? [0, 3, 1, 2] : [0, 1, 2, 3];
+
+  const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length);
+  // discard last dimension for using vec2 to access grid data
+  const gridShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2]];
+  const grid = inputVariable('grid', inputs[1].dataType, gridShape.length, 2);
+  const outputShape = isNhwc
+    ? [inputs[0].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[0].dims[3]]
+    : [inputs[0].dims[0], inputs[0].dims[1], inputs[1].dims[1], inputs[1].dims[2]];
+  const output = outputVariable('output', inputs[0].dataType, outputShape.length);
+  const dataType = x.type.value;
+  const outputSize = ShapeUtil.size(outputShape);
+
+  const programUniforms: ProgramUniform[] = [
+    { type: DataType.uint32, data: outputSize },
+    ...createTensorShapeVariables(inputs[0].dims, gridShape, outputShape),
+  ];
+
+  const getShaderSource = (shaderHelper: ShaderHelper) => {
+    // The axis indices are module-level state shared by every GridSample node. Set them for
+    // BOTH layouts right before the shader text is generated; otherwise an NHWC node would
+    // leave them switched and a later NCHW shader would be generated with NHWC axes.
+    [idxN, idxC, idxH, idxW] = axisOrder;
+
+    return `
+  ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(x, grid, output)}
+  ${gsGetCubicCoeffs}
+  ${gsBicubicInterpolate(dataType)}
+  ${gsDenormalize(attributes)}
+  ${gsReflect(attributes)}
+  ${pixelAtGrid(x, dataType, attributes)}
+
+  ${shaderHelper.mainStart()}
+    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+      let H_in = i32(uniforms.x_shape[${idxH}]);
+      let W_in = i32(uniforms.x_shape[${idxW}]);
+
+      ${
+        attributes.alignCorners === 0
+          ? `
+      let x_min = -0.5;
+      let x_max = f32(W_in) - 0.5;
+      let y_min = -0.5;
+      let y_max = f32(H_in) - 0.5;
+      `
+          : `
+      let x_min = 0.0;
+      let x_max = f32(W_in) - 1.0;
+      let y_min = 0.0;
+      let y_max = f32(H_in) - 1.0;
+      `
+      };
+      let border = vec4<f32>(x_min, y_min, x_max, y_max);
+
+      let indices = ${output.offsetToIndices('global_idx')};
+      var grid_indices = vec3<u32>(indices[${idxN}], indices[${idxH}], indices[${idxW}]);
+      let nxy = ${grid.getByIndices('grid_indices')};
+      var x = gs_denormalize(f32(nxy[0]), W_in);
+      var y = gs_denormalize(f32(nxy[1]), H_in);
+
+      ${computePixel(output, dataType, attributes)}
+  }`;
+  };
+
+  return {
+    name: 'GridSample',
+    shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies: ['type', 'type'] },
+    getRunData: (inputs) => ({
+      outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
+      dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
+      programUniforms,
+    }),
+    getShaderSource,
+  };
+};
+
+/**
+ * Operator entry point: validates the inputs, then builds and runs the GridSample program.
+ *
+ * @throws Error when validateInputs rejects the inputs or an attribute value is unsupported.
+ */
+export const gridSample = (context: ComputeContext, attributes: GridSampeAttributes): void => {
+  validateInputs(context.inputs);
+  context.compute(createGridSampleProgramInfo(context.inputs, attributes));
+};
+
+/**
+ * Converts the raw attribute record (keys align_corners, mode, padding_mode, format) into
+ * GridSampeAttributes with a cache key. The values are cast, not validated, here; unsupported
+ * mode / padding_mode values are rejected later, when the shader text is generated.
+ */
+export const parseGridSampleAttributes = (attributes: Record<string, unknown>): GridSampeAttributes =>
+  createAttributeWithCacheKey({
+    alignCorners: attributes.align_corners as number,
+    mode: attributes.mode as Mode,
+    paddingMode: attributes.padding_mode as PaddingMode,
+    format: attributes.format as Format,
+  });
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 1f379e0de6165..0e42a1dbcbf08 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -570,14 +570,14 @@
       "test_greater_equal_expanded",
       "test_greater_equal",
       "test_greater",
-      // // "test_gridsample_aligncorners_true",
-      // // "test_gridsample_bicubic",
-      // // "test_gridsample_bilinear",
-      // // "test_gridsample_border_padding",
-      // // "test_gridsample_nearest",
-      // // "test_gridsample_reflection_padding",
-      // // "test_gridsample_zeros_padding",
-      // // "test_gridsample",
+      "test_gridsample_aligncorners_true",
+      "test_gridsample_bicubic",
+      "test_gridsample_bilinear",
+      "test_gridsample_border_padding",
+      "test_gridsample_nearest",
+      "test_gridsample_reflection_padding",
+      "test_gridsample_zeros_padding",
+      "test_gridsample",
       // // "test_gru_batchwise",
       // // "test_gru_defaults",
       // // "test_gru_seq_length",
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index ca0cfedb7e2f0..19a3029fca28f 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -400,6 +400,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear);
 
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample);
+
 std::unique_ptr<KernelRegistry> RegisterKernels() {
   auto kernel_registry = std::make_unique<KernelRegistry>();
 
@@ -728,6 +731,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample)>,
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/js/operators/grid_sample.cc b/onnxruntime/core/providers/js/operators/grid_sample.cc
new file mode 100644
index 0000000000000..84eb7df6c5bbe
--- /dev/null
+++ b/onnxruntime/core/providers/js/operators/grid_sample.cc
@@ -0,0 +1,30 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "grid_sample.h"
+
+namespace onnxruntime {
+namespace js {
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    GridSample,
+    kMSInternalNHWCDomain,
+    16, 19,
+    kJsExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", JsepSupportedDataTypes())
+        .TypeConstraint("T2", JsepSupportedFloatTypes()),
+    GridSample<true>);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    GridSample,
+    kOnnxDomain,
+    16, 19,
+    kJsExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", JsepSupportedDataTypes())
+        .TypeConstraint("T2", JsepSupportedFloatTypes()),
+    GridSample<false>);
+
+}  // namespace js
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/js/operators/grid_sample.h b/onnxruntime/core/providers/js/operators/grid_sample.h
new file mode 100644
index 0000000000000..352decf33dc20
--- /dev/null
+++ b/onnxruntime/core/providers/js/operators/grid_sample.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/js/js_kernel.h"
+
+namespace onnxruntime {
+namespace js {
+
+// JS execution provider kernel for GridSample. The constructor reads the operator attributes
+// and hands them to the JavaScript side as { align_corners, mode, padding_mode, format }.
+//
+// is_channels_last: when true the kernel always reports format "NHWC"; when false the format
+// follows the optional "channels_last" attribute (default 0, i.e. "NCHW").
+template <bool is_channels_last>
+class GridSample : public JsKernel {
+ public:
+  GridSample(const OpKernelInfo& info) : JsKernel(info) {
+    int64_t align_corners = info.GetAttrOrDefault<int64_t>("align_corners", 0);
+    // The JavaScript implementation accepts "bilinear", "nearest" and "bicubic" only, so the
+    // fallback must be one of those; "linear" would be rejected when the attribute is absent.
+    std::string mode = info.GetAttrOrDefault<std::string>("mode", "bilinear");
+    std::string padding_mode = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
+    int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
+
+    // $1..$4 are bound, in order, to the four trailing arguments of the macro.
+    JSEP_INIT_KERNEL_ATTRIBUTE(GridSample, ({
+                                 "align_corners" : $1,
+                                 "mode" : UTF8ToString($2),
+                                 "padding_mode" : UTF8ToString($3),
+                                 "format" : $4 ? "NHWC" : "NCHW"
+                               }),
+                               static_cast<int32_t>(align_corners), mode.c_str(),
+                               padding_mode.c_str(), static_cast<int32_t>(channels_last));
+  }
+};
+
+}  // namespace js
+}  // namespace onnxruntime