diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
index c4e5a94f225da..fbab44e211946 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -9,7 +9,8 @@
 import { DataType } from '../../../wasm-common';
 import { TensorView } from '../../tensor-view';
 import { ShapeUtil } from '../../util';
 import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
-import { ComputeContext, ProgramInfo } from '../types';
+import { ComputeContext } from '../types';
+import { createTransposeProgramInfo } from './transpose';
 import {
   getMaxComponents,
@@ -30,19 +31,32 @@ export interface SoftmaxAttributes extends AttributeWithCacheKey {
   readonly axis: number;
 }
 
-const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
-  const shape = input.dims;
-  const outputSize = ShapeUtil.size(shape);
+const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAttributes) => {
+  const input = context.inputs[0];
+  const inputShape = input.dims;
+  const outputSize = ShapeUtil.size(inputShape);
 
   const WG = 64;
-  let axis = attributes.axis;
-  if (axis < 0) {
-    axis = shape.length + axis;
-  }
-  if (axis < shape.length - 1) {
-    throw new Error('softmax only supports last axis for now.');
+  const inputRank = inputShape.length;
+  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
+  const isTransposeRequired = axis < inputShape.length - 1;
+  let transposedInput: TensorView;
+  let perm: number[] = [];
+
+  if (isTransposeRequired) {
+    perm = Array.from({ length: inputRank }, (_, i) => i);
+    perm[axis] = inputRank - 1;
+    perm[inputRank - 1] = axis;
+
+    transposedInput = context.compute(createTransposeProgramInfo(input, perm), {
+      inputs: [input],
+      outputs: [-1],
+    })[0];
+  } else {
+    transposedInput = input;
   }
-  const cols = shape[axis];
+  const transposedInputShape = transposedInput.dims;
+  const cols = transposedInputShape[inputRank - 1];
   const rows = outputSize / cols;
   const components = getMaxComponents(cols);
   const packedCols = cols / components;
@@ -58,12 +72,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
     return name;
   };
 
-  const x = inputVariable('x', input.dataType, input.dims, components);
-  const output = outputVariable('result', input.dataType, input.dims, components);
+  const x = inputVariable('x', transposedInput.dataType, transposedInput.dims, components);
+  const output = outputVariable('result', transposedInput.dataType, transposedInput.dims, components);
   const valueType = x.type.value;
   // 6.2.4 in wgsl spec
   const threadMaxDecl =
-    tensorTypeToWsglStorageType(input.dataType) === 'f32'
+    tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32'
      ? `var threadMax = ${valueType}(-3.402823e+38f);`
      : `var threadMax = ${valueType}(-65504.0h);`;
   const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -139,21 +153,33 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
       setValue(row, col, row_stride, value);
     }
   }`;
-  return {
-    name: 'Softmax',
-    shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
-    getRunData: () => ({
-      outputs: [{ dims: shape, dataType: input.dataType }],
-      dispatchGroup: { x: rows },
-      programUniforms: [{ type: DataType.int32, data: packedCols }],
-    }),
-    getShaderSource,
-  };
+  const result = context.compute(
+    {
+      name: 'Softmax',
+      shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
+      getRunData: () => ({
+        outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }],
+        dispatchGroup: { x: rows },
+        programUniforms: [{ type: DataType.int32, data: packedCols }],
+      }),
+      getShaderSource,
+    },
+    {
+      inputs: [transposedInput],
+      outputs: [isTransposeRequired ? -1 : 0],
+    },
+  )[0];
+
+  if (isTransposeRequired) {
+    context.compute(createTransposeProgramInfo(result, perm), {
+      inputs: [result],
+    });
+  }
 };
 
 export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
   validateInputs(context.inputs);
-  context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes));
+  createSoftmaxProgramInfo(context, attributes);
 };
 
 export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>
diff --git a/js/web/test/data/ops/softmax.jsonc b/js/web/test/data/ops/softmax.jsonc
index 98573fcd73ba2..7fbe2119a0953 100644
--- a/js/web/test/data/ops/softmax.jsonc
+++ b/js/web/test/data/ops/softmax.jsonc
@@ -20,14 +20,7 @@
             "type": "float32"
           }
         ]
-      }
-    ]
-  },
-  {
-    "name": "Softmax with no attributes",
-    "operator": "Softmax",
-    "attributes": [],
-    "cases": [
+      },
       {
         "name": "T[2, 2, 2]",
         "inputs": [
@@ -49,5 +42,63 @@
         ]
       }
     ]
+  },
+  {
+    "name": "Softmax with attribute axis -1",
+    "operator": "Softmax",
+    "attributes": [{ "name": "axis", "data": -1, "type": "int" }],
+    "cases": [
+      {
+        "name": "T[2,2]",
+        "inputs": [
+          {
+            "data": [1.0, 2.0, 3.0, 4.0],
+            "dims": [2, 2],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [0.2689414322376251, 0.7310585975646973, 0.2689414322376251, 0.7310585975646973],
+            "dims": [2, 2],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  },
+  {
+    "name": "Softmax with attribute axis 1",
+    "operator": "Softmax",
+    "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
+    "cases": [
+      {
+        "name": "T[1, 2, 3, 4]",
+        "inputs": [
+          {
+            "data": [
+              1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
+              20.0, 21.0, 22.0, 23.0, 24.0
+            ],
+            "dims": [1, 2, 3, 4],
+            "type": "float32"
+          }
+        ],
+        "outputs": [
+          {
+            "data": [
+              0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
+              0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
+              0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
+              0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
+              0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
+              0.9999938011169434, 0.9999938011169434
+            ],
+            "dims": [1, 2, 3, 4],
+            "type": "float32"
+          }
+        ]
+      }
+    ]
+  }
 ]
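For reviewers, a minimal CPU-side sketch of the technique this patch adopts: when the softmax axis is not the innermost dimension, swap that axis with the last one, run the existing last-axis kernel, and transpose back. Because the permutation is a single swap, it is its own inverse, which is why the kernel above reuses the same `perm` for the second transpose. The helper names below (`softmaxLastAxis`, `transpose`, `softmaxAnyAxis`) are hypothetical and exist only for this illustration; they are not part of the ONNX Runtime API.

```ts
// Softmax over the last axis of a flat row-major array viewed as [rows, cols].
function softmaxLastAxis(data: number[], cols: number): number[] {
  const out = new Array<number>(data.length);
  for (let r = 0; r < data.length / cols; r++) {
    const row = data.slice(r * cols, (r + 1) * cols);
    const max = Math.max(...row); // subtract the row max for numerical stability
    const exps = row.map((v) => Math.exp(v - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    exps.forEach((v, c) => (out[r * cols + c] = v / sum));
  }
  return out;
}

// Permute a dense row-major tensor according to `perm`.
function transpose(data: number[], dims: number[], perm: number[]): { data: number[]; dims: number[] } {
  const outDims = perm.map((p) => dims[p]);
  const strides = dims.map((_, i) => dims.slice(i + 1).reduce((a, b) => a * b, 1));
  const outStrides = outDims.map((_, i) => outDims.slice(i + 1).reduce((a, b) => a * b, 1));
  const out = new Array<number>(data.length);
  for (let i = 0; i < data.length; i++) {
    const idx = dims.map((d, k) => Math.floor(i / strides[k]) % d);
    // Output coordinate k equals input coordinate perm[k].
    const outOffset = perm.reduce((acc, p, k) => acc + idx[p] * outStrides[k], 0);
    out[outOffset] = data[i];
  }
  return { data: out, dims: outDims };
}

// Softmax over any axis: swap `axis` with the last dim, reduce, swap back.
function softmaxAnyAxis(data: number[], dims: number[], axis: number): number[] {
  const rank = dims.length;
  axis = (axis + rank) % rank; // normalize a negative axis, as ShapeUtil.normalizeAxis does
  if (axis === rank - 1) return softmaxLastAxis(data, dims[rank - 1]);
  const perm = Array.from({ length: rank }, (_, i) => i);
  perm[axis] = rank - 1;
  perm[rank - 1] = axis; // a single swap, so perm is its own inverse
  const t = transpose(data, dims, perm);
  const reduced = softmaxLastAxis(t.data, t.dims[rank - 1]);
  return transpose(reduced, t.dims, perm).data;
}

// Reproduces the "Softmax with attribute axis 1" test case above: values paired
// along axis 1 differ by 12, so each pair maps to [1 / (1 + e^12), e^12 / (1 + e^12)]
// ≈ [0.0000061441742, 0.9999938011169].
console.log(softmaxAnyAxis(Array.from({ length: 24 }, (_, i) => i + 1), [1, 2, 3, 4], 1));
```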