Skip to content

Commit

Permalink
[js/webgpu] Support uniform for softmax
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 8, 2023
1 parent 3dece27 commit d282082
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 28 deletions.
57 changes: 30 additions & 27 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';

import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
Expand All @@ -23,7 +23,6 @@ export interface SoftmaxAttributes extends AttributeWithCacheKey {
}

const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const dataType = tensorTypeToWsglStorageType(input.dataType);
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const WG = 64;
Expand All @@ -39,7 +38,6 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
const rows = outputSize / cols;
const components = getMaxComponents(cols);
const packedCols = cols / components;
const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`;

const maxVector = (name: string, components: number) => {
if (components === 4) {
Expand All @@ -52,36 +50,41 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut

return name;
};

const x = inputVariable('x', input.dataType, input.dims, components);
const output = outputVariable('result', input.dataType, input.dims, components);
const enableInputShapesUniforms = enableShapesUniforms(input.dims.length);
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: packedCols}];
if (enableInputShapesUniforms) {
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(input.dims));
}
const inputDependencies: ProgramInputTensorInfoDependency[] = [enableInputShapesUniforms ? 'rank' : 'dims'];
const dataType = x.type.value;
// 6.2.4 in wgsl spec
const threadMaxDecl =
dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`;
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${valueType};
var<workgroup> rowSumShared : ${valueType};
var<workgroup> threadShared : array<${valueType}, ${WG}>;
@group(0) @binding(0) var<storage, read> x : array<${valueType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${valueType}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
const threadMaxDecl = tensorTypeToWsglStorageType(input.dataType) === 'f32' ?
`var threadMax = ${dataType}(-3.402823e+38f);` :
`var threadMax = ${dataType}(-65504.0h);`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
var<workgroup> threadShared : array<${dataType}, ${WG}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
let index = row * row_stride + col;
return x[index];
}
fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
let index = row * row_stride + col;
result[index] = value;
}
@compute @workgroup_size(${WG}, 1, 1)
fn main(@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(global_invocation_id) global_id : vec3u) {
${shaderHelper.registerUniform('packedCols', 'i32').declareVariables(x, output)}
${shaderHelper.mainStart()}
let gindex = i32(global_id.x);
let lindex = i32(local_id.x);
const wg = ${WG};
let row = gindex / wg;
let cols = ${packedCols};
let row_stride : i32 = ${packedCols};
let cols = uniforms.packedCols;
let row_stride : i32 = uniforms.packedCols;
// find the rows max
${threadMaxDecl}
Expand All @@ -103,12 +106,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
rowMaxShared = ${dataType}(${maxVector('threadShared[0]', components)});
}
workgroupBarrier();
// find the rows sum
var threadSum = ${valueType}(0.0);
var threadSum = ${dataType}(0.0);
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
Expand All @@ -123,7 +126,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
rowSumShared = ${dataType}(${sumVector('threadShared[0]', components)});
}
workgroupBarrier();
Expand All @@ -135,12 +138,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
}`;
return {
name: 'Softmax',
getRunData: () => ({outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}}),
shaderCache: {hint: `${components}`, inputDependencies},
getRunData: () => ({outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}, programUniforms}),
getShaderSource,
};
};


export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
validateInputs(context.inputs);
context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes));
Expand Down
29 changes: 28 additions & 1 deletion js/web/test/data/ops/softmax.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"attributes": [],
"cases": [
{
"name": "T[2,4]",
"name": "T[2,2]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.0],
Expand All @@ -22,5 +22,32 @@
]
}
]
},
{
"name": "Softmax with no attributes",
"operator": "Softmax",
"attributes": [],
"cases": [
{
"name": "T[2, 2, 2]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
"dims": [2, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.2689414322376251, 0.7310585975646973, 0.2689414322376251, 0.7310585975646973, 0.2689414322376251,
0.7310585975646973, 0.2689414322376251, 0.7310585975646973
],
"dims": [2, 2, 2],
"type": "float32"
}
]
}
]
}
]

0 comments on commit d282082

Please sign in to comment.