From c7c31c887d007092b2a4d09d64790c005553105b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 27 Sep 2023 09:10:50 +0800 Subject: [PATCH] broadcast to broadcasted --- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 24 +++++++++++------------ js/web/lib/wasm/jsep/webgpu/ops/where.ts | 6 +++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 49e9217a2d41f..729ee78f282b2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -110,7 +110,7 @@ export interface IndicesHelper { * * @returns an `u32` expression */ - readonly broadcastIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; + readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; /** * WGSL code of generating an indices literal @@ -272,7 +272,7 @@ const createIndicesHelper = const implementationUsed = { offsetToIndices: false, indicesToOffset: false, - broadcastIndicesToOffset: false, + broadcastedIndicesToOffset: false, set: false, setByIndices: false, get: false, @@ -304,7 +304,7 @@ const createIndicesHelper = return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; }; - let offsets: string[] = []; + const offsets: string[] = []; if (rank >= 2) { for (let i = rank - 1; i >= 0; i--) { offsets.push(`${strides[i]}u * (indices[${i}])`); @@ -321,19 +321,19 @@ const createIndicesHelper = return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; }; - let broadcastIndicesToOffsetImplementation = ''; - const broadcastIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + let broadcastedIndicesToOffsetImplementation = ''; + const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { const offsets = []; for (let i = shape.length - 1; i >= 0; i--) { const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); } - broadcastIndicesToOffsetImplementation = - `fn ${output.name}broadcastIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 { + broadcastedIndicesToOffsetImplementation = + `fn ${output.name}broadcastedIndicesTo${name}Offset(outputIndices: ${output.type.indices}) -> u32 { return ${offsets.length > 0 ? offsets.join('+') : '0u'}; }`; - implementationUsed.broadcastIndicesToOffset = true; - return `${output.name}broadcastIndicesTo${name}Offset(${varIndices})`; + implementationUsed.broadcastedIndicesToOffset = true; + return `${output.name}broadcastedIndicesTo${name}Offset(${varIndices})`; }; const indices = (...init: ReadonlyArray) => @@ -488,8 +488,8 @@ const createIndicesHelper = if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); } - if (implementationUsed.broadcastIndicesToOffset) { - impls.push(broadcastIndicesToOffsetImplementation); + if (implementationUsed.broadcastedIndicesToOffset) { + impls.push(broadcastedIndicesToOffsetImplementation); } if (implementationUsed.set) { impls.push(setImplementation); @@ -511,7 +511,7 @@ const createIndicesHelper = type, offsetToIndices, indicesToOffset, - broadcastIndicesToOffset, + broadcastedIndicesToOffset, indices, indicesGet, indicesSet, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 5543254e903c7..4c595bb90b4bc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -33,9 +33,9 @@ const createWhereOpProgramShader = const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; return ` let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; - let offsetA${x} = ${a.broadcastIndicesToOffset(`outputIndices${x}`, output)}; - let offsetB${x} = ${b.broadcastIndicesToOffset(`outputIndices${x}`, output)}; - let offsetC${x} = ${c.broadcastIndicesToOffset(`outputIndices${x}`, output)}; + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; let indexA${x} = offsetA${x} / 4u; let indexB${x} = offsetB${x} / 4u; let indexC${x} = offsetC${x} / 4u;