Skip to content

Commit

Permalink
Replaced caloffset* function with broadcastedIndicesToOffset.
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Nov 3, 2023
1 parent 536d4fd commit 47c2be5
Showing 1 changed file with 4 additions and 23 deletions.
27 changes: 4 additions & 23 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,6 @@ const createBinaryOpProgramShader =
const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4, vectorize);
const a = inputVariable('aData', typeA, inputAShapeOrRank, 4, vectorize);
const b = inputVariable('bData', typeB, inputBShapeOrRank, 4, vectorize);
if (doBroadcast) {
const calcOffsetImpl = (indicesHelper: IndicesHelper) => `
fn calcOffset${indicesHelper.name}(outputIndices: ${output.type.indices}) -> u32{
var retval : u32 = 0;
${
indicesHelper.rank > 1 ? `for (var i = ${indicesHelper.rank} - 1; i >= 0; i = i - 1) {
retval = retval + uniforms.${indicesHelper.name}_strides[i] * (outputIndices[i] % uniforms.${
indicesHelper.name}_shape[i]);
}` :
`retval = uniforms.${indicesHelper.name}_strides * (outputIndices[${
dimsOutput.length} - 1] % uniforms.${indicesHelper.name}_shape);`}
return retval;
}`;

broadcastImpl = `
${calcOffsetImpl(a)}
${calcOffsetImpl(b)}
`;
}

let assignment: string;
if (vectorize) {
Expand All @@ -71,8 +52,8 @@ const createBinaryOpProgramShader =
} else {
assignment = `
let outputIndices = ${output.offsetToIndices('global_idx * 4u')};
let offsetA = calcOffset${a.name}(outputIndices);
let offsetB = calcOffset${b.name}(outputIndices);
let offsetA = ${a.broadcastedIndicesToOffset('outputIndices', output)};
let offsetB = ${b.broadcastedIndicesToOffset('outputIndices', output)};
${
output.setByOffset(
'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))}
Expand All @@ -92,8 +73,8 @@ const createBinaryOpProgramShader =
const expressionB = `bData[indexB${x}][componentB${x}]`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = calcOffset${a.name}(outputIndices${x});
let offsetB${x} = calcOffset${b.name}(outputIndices${x});
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let componentA${x} = offsetA${x} % 4u;
Expand Down

0 comments on commit 47c2be5

Please sign in to comment.