[js/webgpu] Remove the limitation on axis in softmax (#22231)
In the current implementation, the softmax axis has to be the last axis, which is
an obvious limitation. This PR removes that limitation and fixes
issues #20710 and #22176.
Yang Gu authored Oct 1, 2024
1 parent d9de054 commit c75f4a0
Showing 2 changed files with 110 additions and 33 deletions.
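
The approach in the diff below is: normalize the axis, and if it is not the last dimension, transpose the input so the softmax axis becomes the innermost one, run the existing last-axis softmax program, then transpose the result back with the same permutation. The following is a minimal reference sketch of that idea over plain flat arrays; the helper names (`swapToLastPerm`, `transpose`, `softmaxLastAxis`, `softmaxAlongAxis`) are illustrative only and are not part of the ONNX Runtime code shown in the diff.

```ts
// Permutation that swaps `axis` with the last dimension. Because it swaps
// exactly two positions it is its own inverse, so the same `perm` can be
// reused for the transpose back (as the diff does).
function swapToLastPerm(rank: number, axis: number): number[] {
  const perm = Array.from({ length: rank }, (_, i) => i);
  perm[axis] = rank - 1;
  perm[rank - 1] = axis;
  return perm;
}

// Naive transpose of a flat, row-major array.
function transpose(
  data: Float32Array,
  dims: number[],
  perm: number[],
): { data: Float32Array; dims: number[] } {
  const outDims = perm.map((p) => dims[p]);
  const strides = (d: number[]) => d.map((_, i) => d.slice(i + 1).reduce((a, b) => a * b, 1));
  const inStrides = strides(dims);
  const outStrides = strides(outDims);
  const out = new Float32Array(data.length);
  for (let i = 0; i < data.length; i++) {
    // Decode the output index into coordinates, then map each output
    // coordinate back to its input dimension via `perm`.
    let rem = i;
    let srcOffset = 0;
    for (let d = 0; d < outDims.length; d++) {
      const coord = Math.floor(rem / outStrides[d]);
      rem -= coord * outStrides[d];
      srcOffset += coord * inStrides[perm[d]];
    }
    out[i] = data[srcOffset];
  }
  return { data: out, dims: outDims };
}

// Softmax over the last axis of a flat array viewed as [rows, cols].
function softmaxLastAxis(data: Float32Array, cols: number): Float32Array {
  const out = new Float32Array(data.length);
  for (let r = 0; r < data.length / cols; r++) {
    let max = -Infinity;
    for (let c = 0; c < cols; c++) max = Math.max(max, data[r * cols + c]);
    let sum = 0;
    for (let c = 0; c < cols; c++) {
      const e = Math.exp(data[r * cols + c] - max); // subtract max for numerical stability
      out[r * cols + c] = e;
      sum += e;
    }
    for (let c = 0; c < cols; c++) out[r * cols + c] /= sum;
  }
  return out;
}

// Softmax along any axis: transpose so the axis is last, apply the
// last-axis kernel, transpose back. Mirrors the control flow in the diff.
function softmaxAlongAxis(data: Float32Array, dims: number[], axis: number): Float32Array {
  const rank = dims.length;
  const a = axis < 0 ? axis + rank : axis; // normalize a negative axis
  if (a === rank - 1) return softmaxLastAxis(data, dims[rank - 1]);
  const perm = swapToLastPerm(rank, a);
  const t = transpose(data, dims, perm);
  const s = softmaxLastAxis(t.data, t.dims[rank - 1]);
  return transpose(s, t.dims, perm).data;
}
```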
76 changes: 51 additions & 25 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
@@ -9,7 +9,8 @@ import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';
import { ComputeContext } from '../types';
import { createTransposeProgramInfo } from './transpose';

import {
getMaxComponents,
@@ -30,19 +31,32 @@ export interface SoftmaxAttributes extends AttributeWithCacheKey {
readonly axis: number;
}

const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAttributes) => {
const input = context.inputs[0];
const inputShape = input.dims;
const outputSize = ShapeUtil.size(inputShape);
const WG = 64;
let axis = attributes.axis;
if (axis < 0) {
axis = shape.length + axis;
}
if (axis < shape.length - 1) {
throw new Error('softmax only supports last axis for now.');
const inputRank = inputShape.length;
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
const isTransposeRequired = axis < inputShape.length - 1;
let transposedInput: TensorView;
let perm: number[] = [];

if (isTransposeRequired) {
perm = Array.from({ length: inputRank }, (_, i) => i);
perm[axis] = inputRank - 1;
perm[inputRank - 1] = axis;

transposedInput = context.compute(createTransposeProgramInfo(input, perm), {
inputs: [input],
outputs: [-1],
})[0];
} else {
transposedInput = input;
}

const cols = shape[axis];
const transposedInputShape = transposedInput.dims;
const cols = transposedInputShape[inputRank - 1];
const rows = outputSize / cols;
const components = getMaxComponents(cols);
const packedCols = cols / components;
@@ -58,12 +72,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut

return name;
};
const x = inputVariable('x', input.dataType, input.dims, components);
const output = outputVariable('result', input.dataType, input.dims, components);
const x = inputVariable('x', transposedInput.dataType, transposedInput.dims, components);
const output = outputVariable('result', transposedInput.dataType, transposedInput.dims, components);
const valueType = x.type.value;
// 6.2.4 in wgsl spec
const threadMaxDecl =
tensorTypeToWsglStorageType(input.dataType) === 'f32'
tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32'
? `var threadMax = ${valueType}(-3.402823e+38f);`
: `var threadMax = ${valueType}(-65504.0h);`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -139,21 +153,33 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
setValue(row, col, row_stride, value);
}
}`;
return {
name: 'Softmax',
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
getRunData: () => ({
outputs: [{ dims: shape, dataType: input.dataType }],
dispatchGroup: { x: rows },
programUniforms: [{ type: DataType.int32, data: packedCols }],
}),
getShaderSource,
};
const result = context.compute(
{
name: 'Softmax',
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
getRunData: () => ({
outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }],
dispatchGroup: { x: rows },
programUniforms: [{ type: DataType.int32, data: packedCols }],
}),
getShaderSource,
},
{
inputs: [transposedInput],
outputs: [isTransposeRequired ? -1 : 0],
},
)[0];

if (isTransposeRequired) {
context.compute(createTransposeProgramInfo(result, perm), {
inputs: [result],
});
}
};

export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
validateInputs(context.inputs);
context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes));
createSoftmaxProgramInfo(context, attributes);
};

export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>
67 changes: 59 additions & 8 deletions js/web/test/data/ops/softmax.jsonc
@@ -20,14 +20,7 @@
"type": "float32"
}
]
}
]
},
{
"name": "Softmax with no attributes",
"operator": "Softmax",
"attributes": [],
"cases": [
},
{
"name": "T[2, 2, 2]",
"inputs": [
@@ -49,5 +42,63 @@
]
}
]
},
{
"name": "Softmax with attribute axis -1",
"operator": "Softmax",
"attributes": [{ "name": "axis", "data": -1, "type": "int" }],
"cases": [
{
"name": "T[2,2]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.0],
"dims": [2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0.2689414322376251, 0.7310585975646973, 0.2689414322376251, 0.7310585975646973],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "Softmax with attribute axis 1",
"operator": "Softmax",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[1, 2, 3, 4]",
"inputs": [
{
"data": [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
20.0, 21.0, 22.0, 23.0, 24.0
],
"dims": [1, 2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
0.9999938011169434, 0.9999938011169434
],
"dims": [1, 2, 3, 4],
"type": "float32"
}
]
}
]
}
]
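
For reference, the expected values in the two new test cases can be checked by hand; this snippet is only an illustrative verification, not part of the ONNX Runtime test harness.

```ts
// Axis -1 on dims [2, 2]: every row is [x, x + 1], so each row's softmax
// equals softmax([1, 2]).
const e = [1, 2].map((v) => Math.exp(v));
const s = e[0] + e[1];
console.log(e.map((v) => v / s)); // ≈ [0.26894142, 0.73105858]

// Axis 1 on dims [1, 2, 3, 4]: the axis-1 stride is 3 * 4 = 12, so every
// softmax slice pairs some x with x + 12 and reduces to softmax([0, 12]).
const e2 = [0, 12].map((v) => Math.exp(v));
const s2 = e2[0] + e2[1];
console.log(e2.map((v) => v / s2)); // ≈ [0.0000061442, 0.9999938558]
```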
