From dfbb39d462b169e5c3a1984e95df9b018120933b Mon Sep 17 00:00:00 2001 From: guschmue Date: Mon, 19 Feb 2024 08:18:43 -0800 Subject: [PATCH 1/3] minor fixes to make tinyllama work --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 7 +++++++ js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b06c9fb496d15..41115aec57c62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -154,6 +154,13 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); + // 0 length tensors are valid for concat, remove them + for (let i = 0; i < context.inputs.length; i++) { + const size = ShapeUtil.size(context.inputs[i].dims); + if (size === 0) { + context.inputs.slice(i, 1); + } + } context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5c31e6dd86c00..d48bb909f7f8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 14d6f37927590..dfc629ed62450 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -107,7 +107,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split }`; return { name: 'Split', - shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']}, + shaderCache: {hint: `${attributes.cacheKey};${inputShape}`, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: outputsTensorInfo, From 79f44bee82714592aa3b047218096fcb28cccba3 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 20 Feb 2024 09:50:47 -0800 Subject: [PATCH 2/3] remove split.ts from PR --- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 701706ddcdaca..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -107,7 +107,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split }`; return { name: 'Split', - shaderCache: {hint: `${attributes.cacheKey};${inputShape}`, inputDependencies: ['rank']}, + shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']}, getShaderSource, getRunData: () => ({ outputs: outputsTensorInfo, From 68ca86ea36cd8d737923e21cdb04268dcaea2173 Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 23 Feb 2024 08:41:13 -0800 Subject: [PATCH 3/3] review feedback --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 41115aec57c62..b142a82e551a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -155,13 +155,8 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); // 0 length tensors are valid for concat, remove them - for (let i = 0; i < context.inputs.length; i++) { - const size = ShapeUtil.size(context.inputs[i].dims); - if (size === 0) { - context.inputs.slice(i, 1); - } - } - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes =>