From e2a2efee3eba9063f8a9ba482b12de350b399255 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Wed, 13 Sep 2023 04:39:27 +0400 Subject: [PATCH] Gather fix --- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 34fbd681637fe..b3f63485366c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -32,18 +32,12 @@ const createGatherProgramInfo = const inputDataType = inputs[0].dataType; const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); - let elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; + const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; - // for f16 when block size is odd, we'll use single f16 - // when it's odd just one u32 let gatherType = DataType.uint32; if (inputDataType === DataType.float16) { - if (block % 2 === 0) { - elementSize = 2; - } else { - gatherType = DataType.float16; - } + gatherType = DataType.float16; } const blockSize = elementSize * block; const components = getMaxComponents(blockSize);