Skip to content

Commit

Permalink
[JS/WebGPU] support Concat.int32 operator
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed Aug 15, 2023
1 parent 24e0bd3 commit 412cb26
Show file tree
Hide file tree
Showing 5 changed files with 434 additions and 22 deletions.
9 changes: 5 additions & 4 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ import {ComputeContext, GpuData, GpuDataType, ProgramInfo, ProgramInfoLoader} fr
*
*/
const getProgramInfoUniqueKey =
(programInfo: ProgramInfo|ProgramInfoLoader, inputTensorShapes: ReadonlyArray<TensorView['dims']>,
(programInfo: ProgramInfo|ProgramInfoLoader, inputs: readonly TensorView[],
inputGpuDataTypes: readonly GpuDataType[]): string => {
const inputTensorShapesToString = inputTensorShapes.map(d => `${d.join(',')}`).join('_');
const inputDataTypesToString = inputs.map(i => `${i.dataType}`).join('_');
const inputTensorShapesToString = inputs.map(i => `${i.dims.join(',')}`).join('_');
const inputGpuDataTypesToString = inputGpuDataTypes.join('_');
let key = programInfo.name;
if (programInfo.cacheHint) {
key += '[' + programInfo.cacheHint + ']';
}
key += ':' + inputTensorShapesToString + ';' + inputGpuDataTypesToString;
key += ':' + inputDataTypesToString + ';' + inputTensorShapesToString + ';' + inputGpuDataTypesToString;
return key;
};

Expand Down Expand Up @@ -221,7 +222,7 @@ export class WebGpuBackend {
inputDatas[i] = gpuData;
}

const key = getProgramInfoUniqueKey(program, inputs.map(i => i.dims), inputDatas.map(i => i.type));
const key = getProgramInfoUniqueKey(program, inputs, inputDatas.map(i => i.type));
let artifact = this.programManager.getArtifact(key);
const programInfo = artifact ?
artifact.programInfo :
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ const createConcatProgramInfo =
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let indices = ${output.offsetToIndices('global_idx')};
var indices = ${output.offsetToIndices('global_idx')};
let inputIndex = calculateInputIndex(${indicesAxis});
if (inputIndex != 0u) {
Expand Down
Loading

0 comments on commit 412cb26

Please sign in to comment.