Skip to content

Commit

Permalink
[js/webgpu] Optimize softmax by vector (#18153)
Browse files Browse the repository at this point in the history
### Description
This PR enables `softmax` to output the maximum supported number of
components, instead of a scalar, for each thread.

Softmax with input[0] shape [12,4096,4096] improves from 55.11 ms to 47.86 ms.
  • Loading branch information
qjia7 authored Oct 30, 2023
1 parent 90d1f53 commit 785e2b1
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';

import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';

const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
Expand All @@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut

const cols = shape[axis];
const rows = outputSize / cols;
const components = getMaxComponents(cols);
const packedCols = cols / components;
const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`;

// Builds a WGSL expression that reduces a vecN value named `name` to the
// maximum of its components. A scalar (components === 1, or any other
// unrecognized count) is passed through unchanged.
const maxVector = (name: string, components: number) => {
  switch (components) {
    case 2:
      return `max(${name}.x, ${name}.y)`;
    case 3:
      return `max(max(${name}.x, ${name}.y), ${name}.z)`;
    case 4:
      return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`;
    default:
      return name;
  }
};

// 6.2.4 in wgsl spec
const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
const threadMaxDecl =
dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`;
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
var<workgroup> threadShared : array<${dataType}, ${WG}>;
var<workgroup> rowMaxShared : ${valueType};
var<workgroup> rowSumShared : ${valueType};
var<workgroup> threadShared : array<${valueType}, ${WG}>;
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${dataType}>;
@group(0) @binding(0) var<storage, read> x : array<${valueType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${valueType}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
let index = row * row_stride + col;
return x[index];
}
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
let index = row * row_stride + col;
result[index] = value;
}
Expand All @@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
let lindex = i32(local_id.x);
const wg = ${WG};
let row = gindex / wg;
let cols = ${cols};
let row_stride : i32 = ${cols};
let cols = ${packedCols};
let row_stride : i32 = ${packedCols};
// find the rows max
${threadMaxDecl}
Expand All @@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowMaxShared = threadShared[0];
rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
}
workgroupBarrier();
// find the rows sum
var threadSum: ${dataType} = 0.0;
var threadSum = ${valueType}(0.0);
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
Expand All @@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
workgroupBarrier();
}
if (lindex == 0) {
rowSumShared = threadShared[0];
rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
}
workgroupBarrier();
Expand Down

0 comments on commit 785e2b1

Please sign in to comment.