Skip to content

Commit

Permalink
[JS/WebGPU] Preserve zero size input tensor dims. (#19737)
Browse files Browse the repository at this point in the history
### Description
For Concat operation, the zero-size input tensor shape need to be
preserved and, unlike non-zero tensors, the dims are not constrained to
match other input tensors' dims.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
satyajandhyala authored Mar 8, 2024
1 parent 6c3bed6 commit 24b72d2
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 77 deletions.
146 changes: 69 additions & 77 deletions js/web/lib/wasm/jsep/webgpu/ops/concat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
if (!inputs || inputs.length < 1) {
throw new Error('too few inputs');
}

const inputType = inputs[0].dataType;
const inputDimensionality = inputs[0].dims.length;

for (const input of inputs) {
const referenceIndex = 0;
const referenceInput = inputs[referenceIndex];
const inputType = referenceInput.dataType;
const inputRank = referenceInput.dims.length;
inputs.forEach((input, i) => {
if (i === referenceIndex) {
return;
}
// make sure types of all inputs match
if (input.dataType !== inputType) {
throw new Error('input tensors should be one type');
}

// make sure the dimensionality of all inputs are the same
if (input.dims.length !== inputDimensionality) {
if (input.dims.length !== inputRank) {
throw new Error('input tensors should have the same shape');
}
}
input.dims.forEach((dim, i) => {
if (i !== axis && dim !== referenceInput.dims[i]) {
throw new Error('non concat dimensions must match');
}
});
});
};

const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
Expand Down Expand Up @@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
return codeLines.join('\n');
};

const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
const inputShape = inputs[0].dims.slice();
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
throw new Error('axis specified for concat doesn\'t match input dimensionality');
}
const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
// ensure all of the non-concatenated axes match each other
// calculate the shape of the output tensor while we do that
const outputShape = inputShape.slice(0);
for (let i = 1; i < inputs.length; i++) {
const dataNShape = inputs[i].dims.slice();
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
// add to the placeholder for computing output shape
if (axisIndex === adjustedAxis) {
outputShape[adjustedAxis] += dataNShape[axisIndex];
const createConcatProgramInfo =
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
// ensure all non-cancatenated axes match each other
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
throw new Error('non concat dimensions must match');
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
}
}

const outputSize = ShapeUtil.size(outputShape);

const sizeInConcatAxis = new Array<number>(inputs.length);
const inputVars = new Array<IndicesHelper>(inputs.length);
const dataType = inputs[0].dataType;

let previousSum = 0;
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
const inputRanks = [];
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
for (let i = 0; i < inputs.length; ++i) {
previousSum += inputs[i].dims[adjustedAxis];
sizeInConcatAxis[i] = previousSum;
inputRanks.push(inputs[i].dims.length);
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
inputDependencies.push('rank');
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
}
for (let i = 0; i < inputs.length; ++i) {
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
}
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(outputShape));

const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
const output = outputVariable('output', dataType, outputShape.length);
const indicesAxis = output.indicesGet('indices', adjustedAxis);
const sizeInConcatAxisStr =
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
const getShaderSource = (shaderHelper: ShaderHelper) => `
${(() => {
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
shaderHelper.registerUniform('outputSize', 'u32');
for (let i = 0; i < inputs.length; i++) {
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
}
return shaderHelper.declareVariables(...inputVars, output);
})()}
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
Expand All @@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
${assignOutputData(inputVars, output)}
}`;

return {
name: 'Concat',
shaderCache: {hint: `${axis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};
return {
name: 'Concat',
shaderCache: {hint: `${adjustedAxis}`, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms,
}),
getShaderSource,
};
};

export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
validateInputs(context.inputs);
const inputs = context.inputs;
const inputShape = inputs[0].dims;
const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
validateInputs(inputs, adjustedAxis);
const outputShape = inputShape.slice();
outputShape[adjustedAxis] =
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
// 0 length tensors are valid for concat, remove them
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
context.compute(
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
};

export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
Expand Down
80 changes: 80 additions & 0 deletions js/web/test/data/ops/concat_zero-sized.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -557,5 +557,85 @@
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 0,
"type": "int"
}
],
"cases": [
{
"name": "Some but not all input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [1],
"dims": [1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Concat 2D axis=1; Preserve dims",
"operator": "Concat",
"attributes": [
{
"name": "axis",
"data": 1,
"type": "int"
}
],
"cases": [
{
"name": "All input tensors are zero-sized",
"inputs": [
{
"data": [],
"dims": [0, 0],
"type": "float32"
},
{
"data": [],
"dims": [0, 1],
"type": "float32"
},
{
"data": [],
"dims": [0, 2],
"type": "float32"
},
{
"data": [],
"dims": [0, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [],
"dims": [0, 6],
"type": "float32"
}
]
}
]
}
]

0 comments on commit 24b72d2

Please sign in to comment.