Skip to content

Commit

Permalink
broadcast to broadcasted
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Sep 27, 2023
1 parent 830582a commit c7c31c8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
24 changes: 12 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ export interface IndicesHelper {
*
* @returns an `u32` expression
*/
readonly broadcastIndicesToOffset: (varIndices: string, output: IndicesHelper) => string;
readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string;

/**
* WGSL code of generating an indices literal
Expand Down Expand Up @@ -272,7 +272,7 @@ const createIndicesHelper =
const implementationUsed = {
offsetToIndices: false,
indicesToOffset: false,
broadcastIndicesToOffset: false,
broadcastedIndicesToOffset: false,
set: false,
setByIndices: false,
get: false,
Expand Down Expand Up @@ -304,7 +304,7 @@ const createIndicesHelper =
return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
};

let offsets: string[] = [];
const offsets: string[] = [];
if (rank >= 2) {
for (let i = rank - 1; i >= 0; i--) {
offsets.push(`${strides[i]}u * (indices[${i}])`);
Expand All @@ -321,19 +321,19 @@ const createIndicesHelper =
return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
};

let broadcastIndicesToOffsetImplementation = '';
const broadcastIndicesToOffset = (varIndices: string, output: IndicesHelper) => {
let broadcastedIndicesToOffsetImplementation = '';
const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => {
const offsets = [];
for (let i = shape.length - 1; i >= 0; i--) {
const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length);
offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`);
}
broadcastIndicesToOffsetImplementation =
`fn ${output.name}broadcastIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 {
broadcastedIndicesToOffsetImplementation =
`fn ${output.name}broadcastedIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 {
return ${offsets.length > 0 ? offsets.join('+') : '0u'};
}`;
implementationUsed.broadcastIndicesToOffset = true;
return `${output.name}broadcastIndicesTo${name}Offset(${varIndices})`;
implementationUsed.broadcastedIndicesToOffset = true;
return `${output.name}broadcastedIndicesTo${name}Offset(${varIndices})`;
};

const indices = (...init: ReadonlyArray<number|string>) =>
Expand Down Expand Up @@ -488,8 +488,8 @@ const createIndicesHelper =
if (implementationUsed.indicesToOffset) {
impls.push(indicesToOffsetImplementation);
}
if (implementationUsed.broadcastIndicesToOffset) {
impls.push(broadcastIndicesToOffsetImplementation);
if (implementationUsed.broadcastedIndicesToOffset) {
impls.push(broadcastedIndicesToOffsetImplementation);
}
if (implementationUsed.set) {
impls.push(setImplementation);
Expand All @@ -511,7 +511,7 @@ const createIndicesHelper =
type,
offsetToIndices,
indicesToOffset,
broadcastIndicesToOffset,
broadcastedIndicesToOffset,
indices,
indicesGet,
indicesSet,
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ const createWhereOpProgramShader =
const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
return `
let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offsetA${x} = ${a.broadcastIndicesToOffset(`outputIndices${x}`, output)};
let offsetB${x} = ${b.broadcastIndicesToOffset(`outputIndices${x}`, output)};
let offsetC${x} = ${c.broadcastIndicesToOffset(`outputIndices${x}`, output)};
let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
let indexA${x} = offsetA${x} / 4u;
let indexB${x} = offsetB${x} / 4u;
let indexC${x} = offsetC${x} / 4u;
Expand Down

0 comments on commit c7c31c8

Please sign in to comment.