[JS/Web] Added FusedConv. (microsoft#17766)
### Description
Added the FusedConv contrib op, and wired fused activations into the ConvTranspose WebGPU kernels.

### Motivation and Context
Improve performance by fusing an activation (Relu, Sigmoid, or Clip) into the preceding convolution, avoiding a separate elementwise kernel dispatch.
satyajandhyala authored Nov 1, 2023
1 parent 9e8ad39 commit a2e9ba7
Showing 21 changed files with 339 additions and 98 deletions.
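
Conceptually, the optimizer folds an activation node into the convolution that precedes it, so a single fused kernel runs instead of two. A minimal sketch of the rewrite, assuming an illustrative `Node` shape (this is not the actual optimizer API):

```ts
// Illustrative sketch only: Conv + Relu -> FusedConv carrying an `activation` attribute.
interface Node {
  opType: string;
  domain: string;
  attributes: Record<string, unknown>;
}

function fuseConvActivation(conv: Node, act: Node): Node {
  return {
    opType: 'FusedConv',      // com.microsoft contrib op registered by this commit
    domain: 'com.microsoft',
    // The activation name travels as an attribute, e.g. activation: "Relu";
    // Clip additionally carries its min/max bounds.
    attributes: {...conv.attributes, activation: act.opType},
  };
}
```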
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
@@ -40,6 +40,7 @@ Do not modify directly.*
| Expand | ai.onnx(8-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherElements | ai.onnx(11-12,13+) | |
| Gelu | com.microsoft(1+) | |
1 change: 1 addition & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -67,6 +67,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['FusedConv', [conv, parseConvAttributes]],
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['Gelu', [unaryOps.gelu]],
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts
@@ -41,12 +41,12 @@ export const activationFnSnippet =
if (!activation) {
return '';
}

// TODO: add implementations
return '';
};

export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => `
${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''}
${activation ? 'value = activation(value, coords);' : ''}
// TODO uncomment the following line when activation is supported above.
// ${activation ? 'value = activation(value, coords);' : ''}
`;
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -242,8 +242,9 @@ export const createConv2DMatMulProgramInfo =
${declareFunctions}
${
conv2dCommonSnippet(
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0],
elementsSize[1], elementsSize[2], t)}
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias,
attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1],
elementsSize[2], t)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
@@ -236,7 +236,9 @@ export const createConv2DTransposeMatMulProgramInfo =
const dimBOuter : i32 = ${dimBOuter};
const dimInner : i32 = ${dimInner};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)}
${
conv2dTransposeCommonSnippet(
isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)}
${
isVec4 ? makeMatMulPackedVec4Source(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
@@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view';
import {ShapeUtil} from '../../../util';
import {ProgramInfo} from '../../types';
import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils';
import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils';

import {typeSnippet} from './activation_util';

@@ -440,7 +440,7 @@ export const createMatmulProgramInfo =
const dimInner = aShape[aShape.length - 1];
const dimBOuter = bShape[bShape.length - 1];
const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0;
const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes);
const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4);

// TODO: fine tune size
const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1];
@@ -473,8 +473,8 @@ export const createMatmulProgramInfo =
const dimBOuter: i32 = ${dimBOuter};
const dimInner: i32 = ${dimInner};
${shaderHelper.declareVariables(...inputVariables, output)}
${declareFunctions}
${activationFunction}
${declareFunctions}
${
isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) :
makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)}
4 changes: 2 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -7,7 +7,7 @@ import {ProgramInfo} from '../types';

import {inputVariable, outputVariable, ShaderHelper} from './common';
import {calculateOutputShape, ConvAttributes} from './conv';
import {getActicationSnippet} from './fuse-utils';
import {getActivationSnippet} from './fuse-utils';

/**
* naive grouped conv implementation, supports 1d/2d conv
@@ -22,7 +22,7 @@ export const createGroupedConvProgramInfo =
const wShape = inputs[1].dims;
const outputChannelsPerGroup = wShape[0] / attributes.group;

const {activationFunction, applyActivation} = getActicationSnippet(attributes);
const {activationFunction, applyActivation} = getActivationSnippet(attributes);

const isChannelLast = attributes.format === 'NHWC';
const outputShape = calculateOutputShape(
37 changes: 19 additions & 18 deletions js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts
@@ -10,24 +10,25 @@ export interface InternalActivationAttributes {
readonly activationCacheKey: string;
}

export const getActicationSnippet =
(attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => {
switch (attributes.activation) {
case 'Relu':
return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'};
case 'Sigmoid':
return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'};
case 'Clip':
return {
activationFunction:
`const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`,
applyActivation: 'value = clamp(value, clip_min_, clip_max_);'
};
// TODO: adding other activations that can be fused.
default:
return {activationFunction: '', applyActivation: ''};
}
};
export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): {
activationFunction: string; applyActivation: string;
} => {
switch (attributes.activation) {
case 'Relu':
return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'};
case 'Sigmoid':
return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'};
case 'Clip':
return {
activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`,
applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' :
'value = clamp(value, clip_min_, clip_max_);'
};
// TODO: adding other activations that can be fused.
default:
return {activationFunction: '', applyActivation: ''};
}
};

export const parseInternalActivationAttributes =
(attributes: Record<string, unknown>|undefined): InternalActivationAttributes => {
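
For context, a hedged sketch of how the two returned strings are spliced into a generated WGSL shader; the surrounding shader scaffolding here is illustrative, not the real program builder:

```ts
// Sketch: embedding the Clip snippets with isVec4 = true.
const attrs = {activation: 'Clip', clipMin: 0, clipMax: 6, activationCacheKey: 'Clip;0;6'};
const {activationFunction, applyActivation} = getActivationSnippet(attrs, /* isVec4 */ true);

const wgsl = `
  ${activationFunction}  // declares the clip_min_ / clip_max_ constants
  fn process(v: vec4<f32>) -> vec4<f32> {
    var value = v;
    ${applyActivation}   // value = clamp(value, vec4(clip_min_), vec4(clip_max_));
    return value;
  }`;
```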
112 changes: 112 additions & 0 deletions js/web/test/data/ops/fused-conv.jsonc
@@ -0,0 +1,112 @@
[
{
"name": "conv without bias addition A",
"operator": "FusedConv",
"attributes": [
{ "name": "activation", "data": "Relu", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }
],
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [370, 470, 670, 770],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
},
{
"name": "T[1]",
"inputs": [
{
"data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 0, 390, 430],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "NHWC conv without bias addition A",
"operator": "Conv",
"attributes": [
{ "name": "activation", "data": "Relu", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }
],
"opset": { "domain": "com.ms.internal.nhwc", "version": 11 },
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 3, 3, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [370, 470, 670, 770],
"dims": [1, 2, 2, 1],
"type": "float32"
}
]
},
{
"name": "T[3]",
"inputs": [
{
"data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
"dims": [1, 3, 3, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 0, 390, 430],
"dims": [1, 2, 2, 1],
"type": "float32"
}
]
}
]
}
]
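
As a sanity check on the expected values above: T[0] is a plain 2×2 valid convolution (stride 1, no padding), and T[1] is the same inputs with Relu applied. A short script that recomputes them:

```ts
// Recompute the T[0] expected output for a 3x3 input and a 2x2 kernel.
const input = [10, 20, 30, 40, 50, 60, 70, 80, 90];  // dims [1, 1, 3, 3]
const kernel = [1, 2, 3, 4];                         // dims [1, 1, 2, 2]
const out: number[] = [];
for (let y = 0; y < 2; y++) {
  for (let x = 0; x < 2; x++) {
    out.push(
        input[y * 3 + x] * kernel[0] + input[y * 3 + x + 1] * kernel[1] +
        input[(y + 1) * 3 + x] * kernel[2] + input[(y + 1) * 3 + x + 1] * kernel[3]);
  }
}
console.log(out);  // [370, 470, 670, 770]; with T[1]'s inputs, Relu clamps -270 and -430 to 0.
```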
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/js/fused_conv.cc
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/js/operators/conv.h"
namespace onnxruntime {
namespace contrib {
namespace js {

ONNX_OPERATOR_KERNEL_EX(
FusedConv,
kMSDomain,
1,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
onnxruntime::js::Conv<false, true>);

} // namespace js
} // namespace contrib
} // namespace onnxruntime
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/js/js_contrib_kernels.cc
@@ -11,6 +11,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -23,7 +24,9 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>};
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>};

for (auto& function_table_entry : function_table) {
KernelCreateInfo info = function_table_entry();
31 changes: 27 additions & 4 deletions onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -4,7 +4,7 @@
#include "core/optimizer/conv_activation_fusion.h"

#include <string_view>

#include <string>
#include "core/common/inlined_containers.h"
#include "core/framework/tensorprotoutils.h"
#include "core/mlas/inc/mlas.h"
@@ -174,9 +174,29 @@ using NTO = NodesToOptimize;

class FuseConvActivationAction : public ReplaceWithNew {
private:
std::string OpType(const RuntimeState&) const override { return "FusedConv"; }
std::string OpType(const RuntimeState& runtime_state) const override {
const auto& domain = runtime_state.selected_nodes.Target().Domain();
const auto& op_type = runtime_state.selected_nodes.Target().OpType();
if (domain == kOnnxDomain) {
if (op_type == "Conv") {
return "FusedConv";
}
} else if (domain == kMSDomain) {
if (op_type == "NhwcConv") {
return "NhwcFusedConv";
}
} else if (domain == kMSInternalNHWCDomain) {
if (op_type == "Conv") {
return "Conv";
}
}
ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain);
}

std::string Domain(const RuntimeState&) const override { return kMSDomain; }
std::string Domain(const RuntimeState& runtime_state) const override {
auto domain = runtime_state.selected_nodes.Target().Domain();
return domain == kOnnxDomain ? kMSDomain : domain;
}

NodeAttributes ExtraAttributes(const RuntimeState& state) const override {
NodeAttributes extra_fused_conv_attributes;
@@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) {
const auto name = "ConvAct";
auto action = std::make_unique<actions::FuseConvActivationAction>();
#if !defined(ORT_MINIMAL_BUILD)
const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain);
const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain);
auto selector = std::make_unique<selectors::ConvActivationSelector>();
registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}},

registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}},
std::move(selector), std::move(action));
#else
registry.RegisterAction(name, std::move(action));
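
The OpType()/Domain() overrides above amount to a small lookup; a hedged TypeScript restatement of that logic (the domain strings are the conventional values behind kOnnxDomain, kMSDomain, and kMSInternalNHWCDomain — kOnnxDomain is the empty string):

```ts
// Sketch restating FuseConvActivationAction::OpType/Domain as one lookup.
function fusedTarget(domain: string, opType: string): {opType: string; domain: string} {
  if (domain === '' /* kOnnxDomain */ && opType === 'Conv') {
    return {opType: 'FusedConv', domain: 'com.microsoft'};
  }
  if (domain === 'com.microsoft' && opType === 'NhwcConv') {
    return {opType: 'NhwcFusedConv', domain: 'com.microsoft'};
  }
  if (domain === 'com.ms.internal.nhwc' && opType === 'Conv') {
    // The NHWC-internal Conv keeps its name and domain; the fused
    // activation is attached as an extra attribute instead.
    return {opType: 'Conv', domain: 'com.ms.internal.nhwc'};
  }
  throw new Error(`Unsupported operator: ${opType} and domain: ${domain}`);
}
```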
7 changes: 2 additions & 5 deletions onnxruntime/core/optimizer/conv_add_act_fusion.cc
@@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew {
void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) {
auto action = std::make_unique<actions::FuseConvAddActivationAction>();
auto selector = std::make_unique<selectors::ConvAddActivationSelector>();
registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}},
std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain);
registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}},
std::move(selector), std::move(action));
auto action_nhwc = std::make_unique<actions::FuseConvAddActivationAction>();
auto selector_nhwc = std::make_unique<selectors::ConvAddActivationSelector>();
registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}},
std::move(selector_nhwc), std::move(action_nhwc));
}

SelectorActionRegistry CreateSelectorActionRegistry() {
13 changes: 7 additions & 6 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -270,11 +270,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider,
onnxruntime::kDmlExecutionProvider};
const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider,
onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider,
onnxruntime::kAclExecutionProvider,
onnxruntime::kArmNNExecutionProvider};
const InlinedHashSet<std::string_view> cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider,
onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider,
onnxruntime::kAclExecutionProvider,
onnxruntime::kArmNNExecutionProvider,
onnxruntime::kJsExecutionProvider};

#ifdef MLAS_TARGET_AMD64_IX86
const bool avx2_precision_mode =
@@ -296,7 +297,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<MatMulIntegerToFloatFusion>(cpu_ep));
transformers.emplace_back(std::make_unique<DynamicQuantizeMatMulFusion>(cpu_ep));

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_cuda_rocm_acl_armnn_eps));
transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_cuda_rocm_acl_armnn_js_eps));

transformers.emplace_back(std::make_unique<GeluFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));