Skip to content

Commit

Permalink
use uniforms for HardSigmoid attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Jan 26, 2024
1 parent f15dae9 commit d08ddca
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 43 deletions.
11 changes: 3 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
import {getActivationSnippet} from '../fuse-utils';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';

import {biasSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
Expand Down Expand Up @@ -193,10 +193,7 @@ export const createConv2DMatMulProgramInfo =
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
{type: 'int32', data: attributes.dilations}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
}
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
Expand All @@ -212,9 +209,7 @@ export const createConv2DMatMulProgramInfo =
{name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
{name: 'dilation', type: 'i32', length: 2}
];
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
appendActivationUniforms(attributes, uniforms);

// TODO: support component 2, 3.
const components = isVec4 ? 4 : 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {getActivationSnippet} from '../fuse-utils';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';

import {biasSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
Expand Down Expand Up @@ -201,10 +201,7 @@ export const createConv2DTransposeMatMulProgramInfo =
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
}
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));

Expand Down Expand Up @@ -237,9 +234,7 @@ export const createConv2DTransposeMatMulProgramInfo =
{name: 'filter_dims', type: 'i32', length: filterDims.length},
{name: 'pads', type: 'i32', length: pads.length}
];
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
appendActivationUniforms(attributes, uniforms);
return `
${utilFunctions('uniforms.result_strides')}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';

import {typeSnippet} from './activation_util';

Expand Down Expand Up @@ -449,11 +449,7 @@ export const createMatmulProgramInfo =
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
const programUniforms: ProgramUniform[] =
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
if (activationAttributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: activationAttributes.clipMax!},
{type: 'float32', data: activationAttributes.clipMin!});
}
appendActivationUniformsData(activationAttributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
...createTensorShapeVariables(bShapeTemp));
Expand Down Expand Up @@ -481,9 +477,7 @@ export const createMatmulProgramInfo =
}
const uniforms: UniformsArrayType =
[{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}];
if (activationAttributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
appendActivationUniforms(activationAttributes, uniforms);
const applyActivation = getActivationSnippet(activationAttributes, output.type.value);
const declareFunctions = matMulReadWriteFnSource(
components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims],
Expand Down
11 changes: 3 additions & 8 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../

import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
import {calculateOutputShape, ConvAttributes} from './conv';
import {getActivationSnippet} from './fuse-utils';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils';

/**
* naive grouped conv implementation, supports 1d/2d conv
Expand All @@ -32,10 +32,7 @@ export const createGroupedConvProgramInfo =
{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
];
if (attributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
}
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
...createTensorShapeVariables(outputShape));
Expand All @@ -61,9 +58,7 @@ export const createGroupedConvProgramInfo =
{name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2},
{name: 'output_channels_per_group', type: 'u32'}
];
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
appendActivationUniforms(attributes, uniforms);
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)}
Expand Down
30 changes: 29 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@
// Licensed under the MIT License.

import {MAX_CLIP, MIN_CLIP} from '../../util';
import {ProgramUniform} from '../types';

import {UniformsArrayType} from './common';

export interface InternalActivationAttributes {
readonly activation: string;
readonly clipMin?: number;
readonly clipMax?: number;
readonly alpha?: number;
readonly beta?: number;
}

export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => {
Expand All @@ -17,16 +22,39 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v
return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`;
case 'Clip':
return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`;
case 'HardSigmoid':
return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${
valueType}(uniforms.beta)));`;
// TODO: adding other activations that can be fused.
default:
return '';
}
};

export const appendActivationUniformsData =
(attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
if (attributes.activation === 'Clip') {
programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
} else if (attributes.activation === 'HardSigmoid') {
programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!});
}
};

export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => {
if (attributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
} else if (attributes.activation === 'HardSigmoid') {
uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'});
}
};

export const parseInternalActivationAttributes =
(attributes: Record<string, unknown>|undefined): InternalActivationAttributes => {
const activation = attributes?.activation as string || '';

if (activation === 'HardSigmoid') {
const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5];
return {activation, alpha, beta};
}
if (activation === 'Clip') {
const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP];
return {activation, clipMax, clipMin};
Expand Down
12 changes: 3 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common';
import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils';

export const createNaiveMatmulProgramInfo =
(inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
Expand All @@ -32,11 +32,7 @@ export const createNaiveMatmulProgramInfo =
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N},
{type: 'uint32', data: K}
];
if (activationAttributes.activation === 'Clip') {
programUniforms.push(
{type: 'float32', data: activationAttributes.clipMax!},
{type: 'float32', data: activationAttributes.clipMin!});
}
appendActivationUniformsData(activationAttributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
...createTensorShapeVariables(bShape));
Expand Down Expand Up @@ -69,9 +65,7 @@ export const createNaiveMatmulProgramInfo =
{name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'},
{name: 'K', type: 'u32'}
];
if (activationAttributes.activation === 'Clip') {
uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'});
}
appendActivationUniforms(activationAttributes, uniforms);

const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
const rank = variable.rank;
Expand Down

0 comments on commit d08ddca

Please sign in to comment.