Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Optimize Gather op #17625

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ const createIndicesHelper =

const getByIndicesImplementation = rank < 2 ? '' : `
fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
return ${name}[i2o_${name}(indices)];
return ${getByOffset(`i2o_${name}(indices)`)};
}`;

const getImplementation = rank < 2 ? '' : (() => {
Expand Down
91 changes: 41 additions & 50 deletions js/web/lib/wasm/jsep/webgpu/ops/gather.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';

import {ShaderHelper} from './common';
import {inputVariable, outputVariable, ShaderHelper} from './common';

export interface GatherAttributes extends AttributeWithCacheKey {
axis: number;
Expand All @@ -30,63 +29,55 @@ const createGatherProgramInfo =
const outputShape = inputShape.slice(0);
outputShape.splice(axis, 1, ...indicesShape);

const inputDataType = inputs[0].dataType;
const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1);
const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1;
const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1;
const blockSize = elementSize * block;
const M = ShapeUtil.sizeToDimension(inputShape, axis);
const N = ShapeUtil.size(indicesShape);
const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize;
const gatheredBatchElements = N * block * elementSize;
const axisDimLimit = inputShape[axis];
const outputSize = ShapeUtil.size(outputShape);

const data = inputVariable('data', inputs[0].dataType, inputs[0].dims);
const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims);
const output = outputVariable('output', inputs[0].dataType, outputShape);
// Builds the shader statements that translate the current `outputIndices` into `dataIndices`,
// i.e. the coordinates of the element of `data` to gather. The returned string is spliced into
// the shader body and expects a variable named `outputIndices` to already be in scope there.
//
// Output coordinates are laid out as: [data dims before axis..., indices dims..., data dims after axis...].
// When a rank is not greater than 1 the corresponding indices variable is addressed as a plain
// scalar rather than with `[i]`, which is why every assignment below has two spellings.
// NOTE(review): the generated code wraps negative indices but does not check `idx` against the
// axis size afterwards -- confirm out-of-range indices are rejected elsewhere.
const calcDataIndices = (): string => {
const indicesRank = indicesShape.length;
let calcStr = `var indicesIndices = ${indices.type.indices}(0);`;
// The coordinates into the indices tensor are the slice of the output coordinates that starts at `axis`.
for (let i = 0; i < indicesRank; i++) {
calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${
outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`;
}
// Read the gather index and wrap a negative value by adding the size of the gathered axis.
calcStr += `
var idx = ${indices.getByIndices('indicesIndices')};
if (idx < 0) {
idx = idx + ${axisDimLimit};
}
var dataIndices = ${data.type.indices}(0);
`;
// `i` walks the dimensions of `data`; `j` walks the matching position in the output coordinates.
for (let i = 0, j = 0; i < inputRank; i++) {
if (i === axis) {
// The gathered axis takes the looked-up index; skip over the indices dims in the output.
calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`;
j += indicesRank;
} else {
// Every other axis is copied straight from the output coordinates.
calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${
outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`;
j++;
}
}
return calcStr;
};

const inputSize = ShapeUtil.size(inputShape) * elementSize;
const outputSize = ShapeUtil.size(outputShape) * elementSize;

const totalGathers = M * N;
// int64 indices are treated as little-endian i32, on the assumption that they fall within i32 limits.
// That assumption is safe because it is not possible to allocate a buffer larger than 2 GB for the input tensor.
// Input data is treated as u32, or as two u32 values for 8-byte tensors.
const getShaderSource = (shaderHelper: ShaderHelper) => `
const N: u32 = ${N};
const elementSize: u32 = ${elementSize};
const indicesElementSize: u32 = ${indicesElementSize};

@group(0) @binding(0) var<storage, read> input : array<u32>;
@group(0) @binding(1) var<storage, read> inputIndices : array<i32>;
@group(0) @binding(2) var<storage, read_write> output: array<u32>;

${shaderHelper.mainStart()}
let batch: u32 = global_idx / N;
let i: u32 = global_idx % N;

let srcOffsetBatch: u32 = batch * ${dataBatchElements};
let dstOffsetBatch: u32 = batch * ${gatheredBatchElements};
var idx = inputIndices[i * indicesElementSize];
if (idx < 0) {
idx = idx + ${axisDimLimit};
}

let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize};
let dstOffset = dstOffsetBatch + i * ${blockSize};
if (srcOffset >= ${inputSize}) {
return;
}
if (dstOffset >= ${outputSize}) {
return;
}
for (var j: u32 = 0; j < ${blockSize}; j++) {
output[dstOffset + j] = input[srcOffset + j];
}
}`;
${shaderHelper.declareVariables(data, indices, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};
${calcDataIndices()};
let value = ${data.getByIndices('dataIndices')};
${output.setByOffset('global_idx', 'value')};
}`;
return {
...metadata,
outputs: [
{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default},
],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)})
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};

Expand Down