From d0bac8216d324e78e6f3a89453bac207522b93eb Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 11 Jan 2024 12:13:24 -0800 Subject: [PATCH] [js/webgpu] fix bcast in where (#19009) --- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 687ee054096cc..2ef9637bcda5e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -76,7 +76,6 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); let outputShape = dimsA; let outputSize = ShapeUtil.size(dimsA); - const vecSize = Math.ceil(outputSize / 4); // TODO: deal with zero-sized tensors (eg. dims=[1,0]) if (isBroadcast) { @@ -88,6 +87,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => outputSize = ShapeUtil.size(outputShape); } + const vecSize = Math.ceil(outputSize / 4); + return { name: 'Where', shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},