diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index d998013352d77..3dc4e957e0fee 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; @@ -44,34 +45,51 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => const inputShape = inputs[0].dims; const shape = Array.from(inputs[1].getBigInt64Array(), Number); const outputShape: number[] = calculateOutputShape(inputShape, shape); - const outputSize = ShapeUtil.size(outputShape); - const dataType = inputs[0].dataType; + const components = dataType === DataType.bool ? 4 : 1; + const outputSize = ShapeUtil.size(outputShape) / components; + const enableInputShapeUniform = enableShapesUniforms(inputShape.length); - const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; - const input = inputVariable('input', dataType, inputShapeOrRank); const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); - const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; - const output = outputVariable('output', dataType, outputShapeOrRank); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const inputShape = ${input.indices(...inputShape)}; - ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} - let outputIndices = ${output.offsetToIndices('global_idx')}; - var inputIndices: ${input.type.indices}; - for (var i = 0; i < ${inputShape.length}; i++) { - if (${input.indicesGet('inputShape', 'i')} == 1) { - ${input.indicesSet('inputIndices', 'i', 0)} - } else { - ${ - input.indicesSet( - 'inputIndices', 'i', output.indicesGet('outputIndices', `i + ${outputShape.length - inputShape.length}`))} - } + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; + const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; + const input = inputVariable('input', dataType, inputShapeOrRank, components); + const output = outputVariable('output', dataType, outputShapeOrRank, components); + let assignment: string; + if (dataType === DataType.bool) { + const singleAssignment = (resStr: string, x: number, typeCast = '') => ` + let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)}; + let offset${x} = ${input.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let index${x} = offset${x} / 4u; + let component${x} = offset${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${input.getByOffset(`index${x}`)}[component${x}]); + `; + assignment = ` + let outputOffset = global_idx * ${components}; + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + ${output.setByOffset('global_idx', 'data')} + }`; + } else { + assignment = ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)}; + ${output.setByOffset('global_idx', input.getByOffset('inputOffset'))} + }`; } - ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))} - }`; + return ` + ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')} + ${assignment}`; + }; + const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; if (enableInputShapeUniform) { programUniforms.push(...createTensorShapeVariables(inputShape)); @@ -81,7 +99,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => } return { name: 'Expand', - shaderCache: {hint: `${outputShape}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, + shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, getShaderSource, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5d6d6debadb9a..53ca094abfd62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -29,7 +30,8 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath outputShape.splice(axis, 1, ...indicesShape); const axisDimLimit = inputShape[axis]; - const outputSize = ShapeUtil.size(outputShape); + const components = inputs[0].dataType === DataType.bool ? 4 : 1; + const outputSize = ShapeUtil.size(outputShape) / components; const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; @@ -38,10 +40,6 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; - const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank); - const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); - const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank); - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; if (enableInputShapesUniforms) { @@ -58,46 +56,75 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); - const calcDataIndices = (): string => { - const indicesRank = indicesShape.length; - let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; - for (let i = 0; i < indicesRank; i++) { - calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ - outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`; - } - calcStr += ` - var idx = ${indices.getByIndices('indicesIndices')}; - if (idx < 0) { - idx = idx + uniforms.axisDimLimit; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); + const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); + const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + + const calcDataIndices = (x: number|string): string => { + const indicesRank = indicesShape.length; + let calcStr = `var indicesIndices${x} = ${indices.type.indices}(0);`; + for (let i = 0; i < indicesRank; i++) { + calcStr += `${indicesRank > 1 ? `indicesIndices${x}[${i}]` : `indicesIndices${x}`} = ${ + outputShape.length > 1 ? `outputIndices${x}[uniforms.axis + ${i}]` : `outputIndices${x}`};`; + } + calcStr += ` + var idx${x} = ${indices.getByIndices(`indicesIndices${x}`)}; + if (idx${x} < 0) { + idx${x} = idx${x} + uniforms.axisDimLimit; + } + var dataIndices${x} = ${data.type.indices}(0); + `; + for (let i = 0, j = 0; i < inputRank; i++) { + if (i === axis) { + calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = u32(idx${x});`; + j += indicesRank; + } else { + calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = ${ + outputShape.length > 1 ? `outputIndices${x}[${j}]` : `outputIndices${x}`};`; + j++; } - var dataIndices = ${data.type.indices}(0); - `; - for (let i = 0, j = 0; i < inputRank; i++) { - if (i === axis) { - calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; - j += indicesRank; - } else { - calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ - outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`; - j++; } + return calcStr; + }; + let assignment: string; + if (inputs[0].dataType === DataType.bool) { + const singleAssignment = (resStr: string, x: number, typeCast = '') => ` + let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)}; + ${calcDataIndices(x)}; + let offset${x} = ${data.indicesToOffset(`dataIndices${x}`)}; + let index${x} = offset${x} / 4u; + let component${x} = offset${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${data.getByOffset(`index${x}`)}[component${x}]); + `; + assignment = ` + let outputOffset = global_idx * ${components}; + var value = vec4(0); + ${singleAssignment('value', 0, 'u32')} + ${singleAssignment('value', 1, 'u32')} + ${singleAssignment('value', 2, 'u32')} + ${singleAssignment('value', 3, 'u32')} + ${output.setByOffset('global_idx', 'value')} + `; + } else { + assignment = ` + let outputIndices = ${output.offsetToIndices('global_idx')}; + ${calcDataIndices('')}; + let value = ${data.getByIndices('dataIndices')}; + ${output.setByOffset('global_idx', 'value')}; + `; } - return calcStr; - }; - - const getShaderSource = (shaderHelper: ShaderHelper) => ` + return ` ${ - shaderHelper.registerUniform('outputSize', 'u32') - .registerUniform('axisDimLimit', 'i32') - .registerUniform('axis', 'u32') - .declareVariables(data, indices, output)} + shaderHelper.registerUniform('outputSize', 'u32') + .registerUniform('axisDimLimit', 'i32') + .registerUniform('axis', 'u32') + .declareVariables(data, indices, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} - let outputIndices = ${output.offsetToIndices('global_idx')}; - ${calcDataIndices()}; - let value = ${data.getByIndices('dataIndices')}; - ${output.setByOffset('global_idx', 'value')}; + ${assignment} }`; + }; return { name: 'Gather', shaderCache: {hint: attributes.cacheKey, inputDependencies}, diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc index 35888e2fc3709..22bc04d558d98 100644 --- a/js/web/test/data/ops/expand.jsonc +++ b/js/web/test/data/ops/expand.jsonc @@ -112,6 +112,79 @@ "type": "float32" } ] + }, + { + "name": "Expand 5 - shape < input.size()", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 1, 2, 6], + "type": "float32" + }, + { + "data": [2, 1, 6], + "dims": [3], + "type": "int64" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 2, 6], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Expand - bool", + "operator": "Expand", + "attributes": [], + "cases": [ + { + "name": "Expand - last dim is divisible by 4", + "inputs": [ + { + "data": [true, false, false, true], + "dims": [4], + "type": "bool" + }, + { + "data": [2, 4], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [true, false, false, true, true, false, false, true], + "dims": [2, 4], + "type": "bool" + } + ] + }, + { + "name": "Expand - last dim is not divisible by 4", + "inputs": [ + { + "data": [true, false, false, true, true, true, false, false, false, true, true, true], + "dims": [2, 6], + "type": "bool" + }, + { + "data": [2, 1], + "dims": [2], + "type": "int64" + } + ], + "outputs": [ + { + "data": [true, false, false, true, true, true, false, false, false, true, true, true], + "dims": [2, 6], + "type": "bool" + } + ] } ] } diff --git a/js/web/test/data/ops/gather.jsonc b/js/web/test/data/ops/gather.jsonc index 3b1b0e3821832..0be077d237b88 100644 --- a/js/web/test/data/ops/gather.jsonc +++ b/js/web/test/data/ops/gather.jsonc @@ -93,5 +93,34 @@ ] } ] + }, + { + "name": "Gather - bool", + "operator": "Gather", + "attributes": [], + "cases": [ + { + "name": "data[2,4] indices[1]", + "inputs": [ + { + "data": [true, false, false, true, false, false, true, true], + "dims": [2, 4], + "type": "bool" + }, + { + "data": [1], + "dims": [1], + "type": "int32" + } + ], + "outputs": [ + { + "data": [false, false, true, true], + "dims": [1, 4], + "type": "bool" + } + ] + } + ] } ] diff --git a/onnxruntime/core/providers/js/js_data_types.cc b/onnxruntime/core/providers/js/js_data_types.cc index 341d2cc19506f..cc56f55f26994 100644 --- a/onnxruntime/core/providers/js/js_data_types.cc +++ b/onnxruntime/core/providers/js/js_data_types.cc @@ -29,4 +29,4 @@ const std::vector& JsepSupportedFloatTypes() { } } // namespace js -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/expand.cc b/onnxruntime/core/providers/js/operators/expand.cc index 61d6511a3711a..76be1fd8797be 100644 --- a/onnxruntime/core/providers/js/operators/expand.cc +++ b/onnxruntime/core/providers/js/operators/expand.cc @@ -13,7 +13,11 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) .InputMemoryType(OrtMemTypeCPU, 1), Expand); @@ -23,7 +27,11 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) .InputMemoryType(OrtMemTypeCPU, 1), Expand); } // namespace js diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc index e9c6f5c79294f..485cd3da9b91b 100644 --- a/onnxruntime/core/providers/js/operators/gather.cc +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -15,7 +15,11 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -26,7 +30,11 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -36,7 +44,11 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", JsepSupportedDataTypes()) + .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather);