diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index e09bafe0afb16..9dc6c160710e2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -33,9 +33,9 @@ const createBinaryOpProgramShader = const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; - const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4, vectorize && !doBroadcast); - const a = inputVariable('aData', typeA, inputAShapeOrRank, 4, vectorize && !doBroadcast); - const b = inputVariable('bData', typeB, inputBShapeOrRank, 4, vectorize && !doBroadcast); + const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); + const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); + const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); let assignment: string; if (vectorize) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 19093b9254bab..1d3fc78fe368a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -335,8 +335,8 @@ export const sumVector = (name: string, components: number) => { * vec4. */ const createIndicesHelper = - (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, components: 1|2|3|4, - vectorize = false): IndicesHelper => { + (name: string, tensorType: number, shapeOrRank: number|readonly number[], isInput: boolean, + components: 1|2|3|4): IndicesHelper => { const useUniform = typeof shapeOrRank === 'number'; const rank = useUniform ? shapeOrRank : shapeOrRank.length; const rankIdentity = [...new Array(rank).keys()]; @@ -568,7 +568,7 @@ const createIndicesHelper = const impl = () => { const impls = []; - if (!useUniform && !vectorize) { + if (!useUniform) { impls.push(`const ${shape} = ${type.indices}(${shapeOrRank.join(',')});`); impls.push(`const ${strides} = ${type.indices}(${ShapeUtil.computeStrides(shapeOrRank).join(',')});`); } @@ -630,8 +630,8 @@ const createIndicesHelper = * @returns an IndicesHelper for the input. */ export const inputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, vectorize = false): - IndicesHelper => createIndicesHelper(name, type, shapeOrRank, true, components, vectorize); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, true, components); /** * Create a IndicesHelper for an output. @@ -643,8 +643,8 @@ export const inputVariable = * @returns an IndicesHelper for the output. */ export const outputVariable = - (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1, vectorize = false): - IndicesHelper => createIndicesHelper(name, type, shapeOrRank, false, components, vectorize); + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, false, components); /** * A ShaderHelper is a helper class for generating WGSL code.