From bb43a0f1338b05e93fcbbe5c5cb53ebf017625ba Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 23 Feb 2024 15:45:30 -0800 Subject: [PATCH] [js/webgpu] minor fixes to make tinyllama work (#19564) --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 4 +++- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b06c9fb496d15..b142a82e551a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -154,7 +154,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + // 0 length tensors are valid for concat, remove them + const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5c31e6dd86c00..d48bb909f7f8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) {