Skip to content

Commit

Permalink
Revert "Pass vectorize flag to suppress generating unnecessary shapes…
Browse files Browse the repository at this point in the history
… and strides variable in the shader code."

This reverts commit 33d2283.
  • Loading branch information
satyajandhyala committed Nov 4, 2023
1 parent 72c5e1a commit a44e6ac
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ const createBinaryOpProgramShader =
const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA;
const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB;
const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput;
const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4, vectorize && !doBroadcast);
const a = inputVariable('aData', typeA, inputAShapeOrRank, 4, vectorize && !doBroadcast);
const b = inputVariable('bData', typeB, inputBShapeOrRank, 4, vectorize && !doBroadcast);
const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4);
const a = inputVariable('aData', typeA, inputAShapeOrRank, 4);
const b = inputVariable('bData', typeB, inputBShapeOrRank, 4);

let assignment: string;
if (vectorize) {
Expand Down
14 changes: 7 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ export const sumVector = (name: string, components: number) => {
* vec4.
*/
const createIndicesHelper =
(name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4,
vectorize = false): IndicesHelper => {
(name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean,
components: 1|2|3|4): IndicesHelper => {
const useUniform = typeof shapeOrRank === 'number';
const rank = useUniform ? shapeOrRank : shapeOrRank.length;
const rankIdentity = [...new Array(rank).keys()];
Expand Down Expand Up @@ -568,7 +568,7 @@ const createIndicesHelper =

const impl = () => {
const impls = [];
if (!useUniform && !vectorize) {
if (!useUniform) {
impls.push(`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`);
impls.push(`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`);
}
Expand Down Expand Up @@ -630,8 +630,8 @@ const createIndicesHelper =
* @returns an IndicesHelper for the input.
*/
export const inputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, vectorize = false):
IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, vectorize);
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, true, components);

/**
* Create a IndicesHelper for an output.
Expand All @@ -643,8 +643,8 @@ export const inputVariable =
* @returns an IndicesHelper for the output.
*/
export const outputVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, vectorize = false):
IndicesHelper => createIndicesHelper(name, type, shapeOrRank, false, components, vectorize);
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, false, components);

/**
* A ShaderHelper is a helper class for generating WGSL code.
Expand Down

0 comments on commit a44e6ac

Please sign in to comment.