diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 687ee054096cc..2ef9637bcda5e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -76,7 +76,6 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); let outputShape = dimsA; let outputSize = ShapeUtil.size(dimsA); - const vecSize = Math.ceil(outputSize / 4); // TODO: deal with zero-sized tensors (eg. dims=[1,0]) if (isBroadcast) { @@ -88,6 +87,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => outputSize = ShapeUtil.size(outputShape); } + const vecSize = Math.ceil(outputSize / 4); + return { name: 'Where', shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},