Merge remote-tracking branch 'upstream/main' into snnn-patch-12
snnn committed Jan 24, 2024
2 parents edcbcb5 + a33b5bd commit 8151506
Showing 10 changed files with 101 additions and 70 deletions.
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -25,7 +25,7 @@ import * as pool from './ops/pool';
import {range} from './ops/range';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm';
import {skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
@@ -116,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],
['Slice', [slice, parseSliceAttributes]],
['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]],
['SkipLayerNormalization', [skipLayerNorm]],
['Split', [split, parseSplitAttributes]],
['Sqrt', [unaryOps.sqrt]],
['Softmax', [softmax, parseSoftmaxAttributes]],
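Each entry in WEBGPU_OP_RESOLVE_RULES pairs an ONNX op type with its WebGPU kernel and, optionally, an attribute parser whose parsed result feeds the shader cache key. SkipLayerNormalization loses its parser entry here because the kernel now reads its attributes directly (see the skip-layer-norm.ts changes below). A minimal sketch of the two entry shapes, with stubbed kernels and simplified stand-ins for the real onnxruntime-web types:

// Simplified stand-ins for illustration; the real OperatorImplementation and
// ComputeContext types carry considerably more detail.
type OpImpl = (context: unknown, attributes?: unknown) => void;
type AttrParser = (attrs: Record<string, unknown>) => unknown;

const softmax: OpImpl = () => { /* kernel stub */ };
const parseSoftmaxAttributes: AttrParser = (attrs) => ({axis: attrs.axis as number});
const skipLayerNorm: OpImpl = () => { /* kernel stub */ };

const rules = new Map<string, [OpImpl] | [OpImpl, AttrParser]>([
  ['Softmax', [softmax, parseSoftmaxAttributes]],  // attributes parsed up front
  ['SkipLayerNormalization', [skipLayerNorm]],     // no parser after this commit
]);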
123 changes: 67 additions & 56 deletions js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -4,10 +4,10 @@
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';

import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common';
import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common';

export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
epsilon: number;
@@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo =
const hasInputSkipBiasSumOutput = outputCount > 3;

const components = getMaxComponents(hiddenSize);
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
];
if (hasBetaInput) {
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
}
if (hasBiasInput) {
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
if (hasMeanOutput) {
variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim));
}
if (hasInvStdDevOutput) {
variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim));
}
if (hasInputSkipBiasSumOutput) {
variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const getShaderSource = (shaderHelper: ShaderHelper) => `
const hiddenSize: f32 = ${hiddenSize};
const hiddenSizeVectorized: u32 = ${hiddenSize / components};
const epsilon: f32 = ${attributes.epsilon};

${shaderHelper.declareVariables(...variables)}
const programUniforms: ProgramUniform[] = [
{type: 'uint32', data: outputSize},
{type: 'uint32', data: components},
{type: 'uint32', data: hiddenSize},
{type: 'float32', data: attributes.epsilon},
];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const uniformsArray: UniformsArrayType = [
{name: 'output_size', type: 'u32'},
{name: 'components', type: 'u32'},
{name: 'hidden_size', type: 'u32'},
{name: 'epsilon', type: 'f32'},
];
const variables = [
inputVariable('x', inputs[0].dataType, inputs[0].dims, components),
inputVariable('skip', inputs[1].dataType, inputs[1].dims, components),
inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components),
];
if (hasBetaInput) {
variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components));
}
if (hasBiasInput) {
variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components));
}
variables.push(outputVariable('output', inputs[0].dataType, outputShape, components));
if (hasMeanOutput) {
variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim));
}
if (hasInvStdDevOutput) {
variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim));
}
if (hasInputSkipBiasSumOutput) {
variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components));
}
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
return `
${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)}
let offset = global_idx * hiddenSizeVectorized;
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')}
let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;
let offset = global_idx * hidden_size_vectorized;
var sum = ${fillVector('f32', components)};
var squareSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
let skipValue = skip[offset + i];
let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'};
let inputValue = x[offset + i];
let value = inputValue + skipValue + biasValue;
${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''}
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
let skip_value = skip[offset + i];
let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'};
let input_value = x[offset + i];
let value = input_value + skip_value + bias_value;
${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''}
output[offset + i] = value;
let f32Value = ${castToF32(dataType, components, 'value')};
sum += f32Value;
squareSum += f32Value * f32Value;
let f32_value = ${castToF32(dataType, components, 'value')};
sum += f32_value;
squareSum += f32_value * f32_value;
}
let mean = ${sumVector('sum', components)} / hiddenSize;
let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon);
${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''}
for (var i: u32 = 0; i < hiddenSizeVectorized; i++) {
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i]
+ ${hasBetaInput ? 'beta[i]' : '0.0'};
let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size);
let inv_std_dev = inverseSqrt(${
sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon);
${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''}
${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''}
for (var i: u32 = 0; i < hidden_size_vectorized; i++) {
output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${
hasBetaInput ? 'beta[i]' : '0.0'};
}
}`;
};
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
if (outputCount > 1) {
outputs.push({dims: meanInvStdDevDim, dataType: DataType.float});
@@ -150,12 +164,14 @@
if (outputCount > 3) {
outputs.push({dims: inputShape, dataType: inputs[0].dataType});
}

return {
name: 'SkipLayerNormalization',
shaderCache: {hint: attributes.cacheKey},
shaderCache: {
hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`,
inputDependencies: inputs.map((_input, _index) => 'type')
},
getShaderSource,
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}),
getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}),
};
};

@@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm
context.compute(
createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs});
};

export const parseSkipLayerNormAttributes = (attributes: Record<string, unknown>): SkipLayerNormAttributes => {
const epsilon = attributes.epsilon as number;
return createAttributeWithCacheKey({epsilon});
};
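Taken together, the changes in this file replace compile-time constants with program uniforms: hiddenSize and epsilon are no longer string-interpolated into the WGSL source but uploaded per dispatch, so one compiled shader can serve different shapes, and the attribute-derived cache key (and with it parseSkipLayerNormAttributes) becomes unnecessary. A rough sketch of the caching consequence; the types and example values are simplified for illustration:

// Simplified illustration, not the real ProgramUniform/shaderCache interfaces.
interface ProgramUniform { type: 'uint32'|'float32'; data: number; }

// Before: attribute values were baked into the WGSL text, so the cache hint had
// to include them, and a model with a different epsilon forced a recompile.
const cacheHintBefore = (attributesCacheKey: string): string => attributesCacheKey;

// After: the WGSL references uniforms.hidden_size and uniforms.epsilon, so the
// hint needs only what still changes the generated source: the vector width and
// which optional outputs are wired up.
const cacheHintAfter = (components: number, hasMean: boolean, hasInvStdDev: boolean, hasSum: boolean): string =>
    `${components};${hasMean};${hasInvStdDev};${hasSum}`;

// The concrete values now travel in a uniform buffer with every dispatch
// (example numbers, e.g. a BERT-base hidden size of 768):
const programUniforms: ProgramUniform[] = [
  {type: 'uint32', data: 98304},  // output_size
  {type: 'uint32', data: 4},      // components
  {type: 'uint32', data: 768},    // hidden_size
  {type: 'float32', data: 1e-5},  // epsilon
];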
@@ -51,6 +51,12 @@ class DmlOperatorPadding : public DmlOperator, public PaddingHelper
{
mode = DML_PADDING_MODE_REFLECTION;
}
#if DML_TARGET_VERSION >= 0x6300
else if (modeString == AttrValue::Wrap)
{
mode = DML_PADDING_MODE_WRAP;
}
#endif
else
{
ML_INVALID_ARGUMENT("Unknown Pad mode attribute.");
@@ -116,5 +122,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel<DmlOperatorPadding, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel<DmlOperatorPadding, 11>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel<DmlOperatorPadding, 13>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel<DmlOperatorPadding, 18>);
DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel<DmlOperatorPadding, 19>);

} // namespace Dml
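The new branch above maps the "wrap" mode that opset 19 added to the ONNX Pad operator onto DML_PADDING_MODE_WRAP, and the #if guard keeps the file buildable against DirectML headers older than DML_TARGET_VERSION 0x6300. Wrap padding fills the padded region from the opposite edge of the tensor, i.e. circular indexing. A one-dimensional sketch, for illustration only:

// Illustration of "wrap" semantics, not the DML implementation: padded
// positions read from the opposite edge, so indices wrap around.
const wrapIndex = (i: number, n: number): number => ((i % n) + n) % n;

const padWrap1d = (data: number[], before: number, after: number): number[] => {
  const out: number[] = [];
  for (let i = -before; i < data.length + after; i++) {
    out.push(data[wrapIndex(i, data.length)]);
  }
  return out;
};

// padWrap1d([1, 2, 3], 2, 2) -> [2, 3, 1, 2, 3, 1, 2]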
@@ -358,6 +358,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad7);
DML_OP_EXTERN_CREATION_FUNCTION(Pad11);
DML_OP_EXTERN_CREATION_FUNCTION(Pad13);
DML_OP_EXTERN_CREATION_FUNCTION(Pad18);
DML_OP_EXTERN_CREATION_FUNCTION(Pad19);
DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth);
DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace);
DML_OP_EXTERN_CREATION_FUNCTION(Sqrt);
@@ -747,6 +748,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)},

#if DML_TARGET_VERSION >= 0x6300
{REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)},
#endif

{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
@@ -149,5 +149,6 @@ namespace AttrValue
static constexpr const char* NearestNeighbor = "NN";
static constexpr const char* NotSet = "NOTSET";
static constexpr const char* Reflect = "reflect";
static constexpr const char* Wrap = "wrap";

} // namespace AttrValue
@@ -1589,6 +1589,7 @@ using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper<PaddingHelper, 7>;
using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper<PaddingHelper, 11>;
using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper<PaddingHelper, 13>;
using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper<PaddingHelper, 18>;
using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper<PaddingHelper, 19>;

using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper;
using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
@@ -413,6 +413,7 @@ namespace OperatorHelper
namespace OnnxOperatorSet19
{
static const int sc_sinceVer_AveragePool = 19;
static const int sc_sinceVer_Pad = 19;
static const int sc_sinceVer_Cast = 19;
static const int sc_sinceVer_CastLike = 19;
static const int sc_sinceVer_Constant = 19;
12 changes: 8 additions & 4 deletions onnxruntime/python/onnxruntime_inference_collection.py
@@ -466,7 +466,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi

session_options = self._sess_options if self._sess_options else C.get_default_session_options()

self._register_ep_custom_ops(session_options, providers, provider_options)
self._register_ep_custom_ops(session_options, providers, provider_options, available_providers)

if self._model_path:
sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
@@ -510,11 +510,15 @@ def _reset_session(self, providers, provider_options):
self._sess_options = self._sess_options_initial
self._create_inference_session(providers, provider_options)

def _register_ep_custom_ops(self, session_options, providers, provider_options):
def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers):
for i in range(len(providers)):
if providers[i] == "TensorrtExecutionProvider":
if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider":
C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i])
elif isinstance(providers[i], tuple) and providers[i][0] == "TensorrtExecutionProvider":
elif (
isinstance(providers[i], tuple)
and providers[i][0] in available_providers
and providers[i][0] == "TensorrtExecutionProvider"
):
C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])


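The effect of the Python change: requesting TensorrtExecutionProvider on a build where it is not actually available no longer reaches C.register_tensorrt_plugins_as_custom_ops, because the provider name must now also appear in available_providers. A TypeScript rendering of the new guard, for illustration only (a provider entry is either a bare name or a name-plus-options pair):

// Sketch of the guard logic above, transliterated from the Python.
type ProviderEntry = string | [name: string, options: Record<string, unknown>];

const shouldRegisterTrtPlugins = (entry: ProviderEntry, availableProviders: string[]): boolean => {
  const name = typeof entry === 'string' ? entry : entry[0];
  return name === 'TensorrtExecutionProvider' && availableProviders.includes(name);
};

// With a CPU-only build, registration is now skipped instead of attempted:
// shouldRegisterTrtPlugins('TensorrtExecutionProvider', ['CPUExecutionProvider']) === false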
@@ -1023,7 +1023,7 @@ stages:

- template: nuget/templates/test_win.yml
parameters:
AgentPool : 'onnxruntime-Win2022-GPU-T4'
AgentPool : 'onnxruntime-Win2022-GPU-A10'
NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
@@ -1034,7 +1034,7 @@

- template: nuget/templates/test_win.yml
parameters:
AgentPool : 'onnxruntime-Win2022-GPU-T4'
AgentPool : 'onnxruntime-Win2022-GPU-A10'
NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows'
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
@@ -1046,7 +1046,7 @@

- template: nuget/templates/test_linux.yml
parameters:
AgentPool : Onnxruntime-Linux-GPU
AgentPool : Onnxruntime-Linux-GPU-A10
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
@@ -1055,7 +1055,7 @@

- template: nuget/templates/test_linux.yml
parameters:
AgentPool : Onnxruntime-Linux-GPU
AgentPool : Onnxruntime-Linux-GPU-A10
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
MoreSuffix: '_Linux'
@@ -151,7 +151,7 @@ stages:
# Testing
- template: nuget/templates/test_win.yml
parameters:
AgentPool : 'onnxruntime-Win2022-GPU-T4'
AgentPool : 'onnxruntime-Win2022-GPU-A10'
NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
@@ -162,7 +162,7 @@

- template: nuget/templates/test_win.yml
parameters:
AgentPool : 'onnxruntime-Win2022-GPU-T4'
AgentPool : 'onnxruntime-Win2022-GPU-A10'
NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows'
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
@@ -174,7 +174,7 @@

- template: nuget/templates/test_linux.yml
parameters:
AgentPool : Onnxruntime-Linux-GPU
AgentPool : Onnxruntime-Linux-GPU-A10
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu'
@@ -184,7 +184,7 @@

- template: nuget/templates/test_linux.yml
parameters:
AgentPool : Onnxruntime-Linux-GPU
AgentPool : Onnxruntime-Linux-GPU-A10
ArtifactSuffix: 'GPU'
StageSuffix: 'GPU'
MoreSuffix: '_Linux'
