From a2e9ba72d5a5f61e1324ffc2a80d748d01be9120 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Wed, 1 Nov 2023 15:34:51 -0700 Subject: [PATCH] [JS/Web]Added FusedConv. (#17766) ### Description Added FusedConv and FusedConvTranspose ### Motivation and Context Improve performance --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + .../webgpu/ops/3rd-party/activation_util.ts | 4 +- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 5 +- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 4 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 6 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 4 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 37 +++--- js/web/test/data/ops/fused-conv.jsonc | 112 ++++++++++++++++++ onnxruntime/contrib_ops/js/fused_conv.cc | 20 ++++ .../contrib_ops/js/js_contrib_kernels.cc | 5 +- .../core/optimizer/conv_activation_fusion.cc | 31 ++++- .../core/optimizer/conv_add_act_fusion.cc | 7 +- .../core/optimizer/graph_transformer_utils.cc | 13 +- .../selector_action_transformer.cc | 20 ++-- .../selector_action_transformer.h | 17 ++- .../core/providers/js/operators/conv.cc | 2 + .../core/providers/js/operators/conv.h | 78 ++++++++---- .../providers/js/operators/conv_transpose.cc | 2 + .../providers/js/operators/conv_transpose.h | 55 ++++++--- .../test/optimizer/graph_transform_test.cc | 13 +- 21 files changed, 339 insertions(+), 98 deletions(-) create mode 100644 js/web/test/data/ops/fused-conv.jsonc create mode 100644 onnxruntime/contrib_ops/js/fused_conv.cc diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 44003021293b0..5b94a4a510934 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -40,6 +40,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | | Gelu | com.microsoft(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 40309c1849bcc..a4d51e68b6a25 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -67,6 +67,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], ['Gelu', [unaryOps.gelu]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 22b91d680a9b4..6481a6b21d723 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -41,12 +41,12 @@ export const activationFnSnippet = if (!activation) { return ''; } - // TODO: add implementations return ''; }; export const biasActivationSnippet = (hasBias: boolean, activation?: Activation): string => ` ${hasBias ? 'value = value + getBiasByOutputCoords(coords);' : ''} - ${activation ? 'value = activation(value, coords);' : ''} + // TODO uncomment the following line when activation is supported above. + // ${activation ? 
'value = activation(value, coords);' : ''} `; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 01ddca520deed..fbb936a045b9c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -242,8 +242,9 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2], t)} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 840360223c75a..a95d3830f34eb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -236,7 +236,9 @@ export const createConv2DTransposeMatMulProgramInfo = const dimBOuter : i32 = ${dimBOuter}; const dimInner : i32 = ${dimInner}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + conv2dTransposeCommonSnippet( + isChannelsLast, hasBias, attributes.activation.toLowerCase() as Activation, false, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 1032869412462..0a0f29db6a494 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo} from '../../types'; import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; -import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -440,7 +440,7 @@ export const createMatmulProgramInfo = const dimInner = aShape[aShape.length - 1]; const dimBOuter = bShape[bShape.length - 1]; const isVec4 = dimInner % 4 === 0 && dimBOuter % 4 === 0; - const {activationFunction, applyActivation} = getActicationSnippet(activationAttributes); + const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, isVec4); // TODO: fine tune size const elementsPerThread = dimAOuter <= 8 ? [4, 1, 1] : [4, 4, 1]; @@ -473,8 +473,8 @@ export const createMatmulProgramInfo = const dimBOuter: i32 = ${dimBOuter}; const dimInner: i32 = ${dimInner}; ${shaderHelper.declareVariables(...inputVariables, output)} - ${declareFunctions} ${activationFunction} + ${declareFunctions} ${ isVec4 ? 
makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7abf022928ade..8bfa722dd0909 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo} from '../types'; import {inputVariable, outputVariable, ShaderHelper} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActicationSnippet} from './fuse-utils'; +import {getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -22,7 +22,7 @@ export const createGroupedConvProgramInfo = const wShape = inputs[1].dims; const outputChannelsPerGroup = wShape[0] / attributes.group; - const {activationFunction, applyActivation} = getActicationSnippet(attributes); + const {activationFunction, applyActivation} = getActivationSnippet(attributes); const isChannelLast = attributes.format === 'NHWC'; const outputShape = calculateOutputShape( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 92105859a8c0e..956ef18eb5cfb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -10,24 +10,25 @@ export interface InternalActivationAttributes { readonly activationCacheKey: string; } -export const getActicationSnippet = - (attributes: InternalActivationAttributes): {activationFunction: string; applyActivation: string} => { - switch (attributes.activation) { - case 'Relu': - return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; - case 'Sigmoid': - return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; - case 'Clip': - return { - activationFunction: - `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, - applyActivation: 'value = clamp(value, clip_min_, clip_max_);' - }; - // TODO: adding other activations that can be fused. - default: - return {activationFunction: '', applyActivation: ''}; - } - }; +export const getActivationSnippet = (attributes: InternalActivationAttributes, isVec4 = false): { + activationFunction: string; applyActivation: string; +} => { + switch (attributes.activation) { + case 'Relu': + return {activationFunction: '', applyActivation: 'value = max(value, 0.0);'}; + case 'Sigmoid': + return {activationFunction: '', applyActivation: 'value = (1.0 / (1.0 + exp(-value)));'}; + case 'Clip': + return { + activationFunction: `const clip_min_=f32(${attributes.clipMin!});const clip_max_=f32(${attributes.clipMax!});`, + applyActivation: isVec4 ? 'value = clamp(value, vec4(clip_min_), vec4(clip_max_));' : + 'value = clamp(value, clip_min_, clip_max_);' + }; + // TODO: adding other activations that can be fused. 
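+      // A fused Tanh, for example, would likely reduce to
+      //   {activationFunction: '', applyActivation: 'value = tanh(value);'}
+      // (a sketch only -- tanh() is a WGSL builtin), while LeakyRelu would also
+      // need an alpha constant emitted via activationFunction, as Clip does above.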
+ default: + return {activationFunction: '', applyActivation: ''}; + } +}; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 0000000000000..812e9d7c2def0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 0000000000000..76402f0681976 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
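+//
+// FusedConv is the com.microsoft contrib op produced by the ConvActivationFusion
+// graph transformer (a Conv and its following activation collapsed into one node).
+// The JS EP serves it with the same templated Conv kernel used for plain Conv
+// (see core/providers/js/operators/conv.h), so this file only adds the registration.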
+
+#include "core/providers/js/operators/conv.h"
+namespace onnxruntime {
+namespace contrib {
+namespace js {
+
+ONNX_OPERATOR_KERNEL_EX(
+    FusedConv,
+    kMSDomain,
+    1,
+    kJsExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    onnxruntime::js::Conv<false, true>);
+
+}  // namespace js
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
index 4641b006a7785..24d327576ecd9 100644
--- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc
@@ -11,6 +11,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
 
 template <>
 KernelCreateInfo BuildKernelCreateInfo<void>() {
@@ -23,7 +24,9 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>};
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
+                                                            FusedConv)>};
 
   for (auto& function_table_entry : function_table) {
     KernelCreateInfo info = function_table_entry();
diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc
index c090ab2a6cc9b..d27603e4ab3a1 100644
--- a/onnxruntime/core/optimizer/conv_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -4,7 +4,7 @@
 #include "core/optimizer/conv_activation_fusion.h"
 
 #include <string_view>
-
+#include <string>
 #include "core/common/inlined_containers.h"
 #include "core/framework/tensorprotoutils.h"
 #include "core/mlas/inc/mlas.h"
@@ -174,9 +174,29 @@ using NTO = NodesToOptimize;
 
 class FuseConvActivationAction : public ReplaceWithNew {
  private:
-  std::string OpType(const RuntimeState&) const override { return "FusedConv"; }
+  std::string OpType(const RuntimeState& runtime_state) const override {
+    const auto& domain = runtime_state.selected_nodes.Target().Domain();
+    const auto& op_type = runtime_state.selected_nodes.Target().OpType();
+    if (domain == kOnnxDomain) {
+      if (op_type == "Conv") {
+        return "FusedConv";
+      }
+    } else if (domain == kMSDomain) {
+      if (op_type == "NhwcConv") {
+        return "NhwcFusedConv";
+      }
+    } else if (domain == kMSInternalNHWCDomain) {
+      if (op_type == "Conv") {
+        return "Conv";
+      }
+    }
+    ORT_THROW("Unsupported operator: ", op_type, " and domain: ", domain);
+  }
 
-  std::string Domain(const RuntimeState&) const override { return kMSDomain; }
+  std::string Domain(const RuntimeState& runtime_state) const override {
+    auto domain = runtime_state.selected_nodes.Target().Domain();
+    return domain == kOnnxDomain ?
kMSDomain : domain; + } NodeAttributes ExtraAttributes(const RuntimeState& state) const override { NodeAttributes extra_fused_conv_attributes; @@ -260,8 +280,11 @@ void RegisterConvActivationFusionRules(SelectorActionRegistry& registry) { const auto name = "ConvAct"; auto action = std::make_unique(); #if !defined(ORT_MINIMAL_BUILD) + const std::string msInternalNHWCDomainConv = SelectorActionRegistry::OpVersionsMapKey("Conv", kMSInternalNHWCDomain); + const std::string msDomainConv = SelectorActionRegistry::OpVersionsMapKey("NhwcConv", kMSDomain); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}}, + + registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}, {msInternalNHWCDomainConv, {11}}, {msDomainConv, {1}}}, std::move(selector), std::move(action)); #else registry.RegisterAction(name, std::move(action)); diff --git a/onnxruntime/core/optimizer/conv_add_act_fusion.cc b/onnxruntime/core/optimizer/conv_add_act_fusion.cc index 7c8bfeaec5f0f..6f90eaf07ef4d 100644 --- a/onnxruntime/core/optimizer/conv_add_act_fusion.cc +++ b/onnxruntime/core/optimizer/conv_add_act_fusion.cc @@ -287,12 +287,9 @@ class FuseConvAddActivationAction : public ReplaceWithNew { void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) { auto action = std::make_unique(); auto selector = std::make_unique(); - registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}}, + std::string msDomainNhwcFusedConv = SelectorActionRegistry::OpVersionsMapKey("NhwcFusedConv", kMSDomain); + registry.RegisterSelectorAndAction("ConvAddAct", {{"Conv", {1, 11}}, {msDomainNhwcFusedConv, {1, 11}}}, std::move(selector), std::move(action)); - auto action_nhwc = std::make_unique(); - auto selector_nhwc = std::make_unique(); - registry.RegisterSelectorAndAction("NhwcFusedConvAct", {{"NhwcFusedConv", {1, 11}}}, - std::move(selector_nhwc), std::move(action_nhwc)); } SelectorActionRegistry CreateSelectorActionRegistry() { diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 5a441b1d1701e..86b126f2c7c31 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -270,11 +270,12 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = @@ -296,7 +297,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git 
a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc index e182b6c695d2f..546d52b6f1682 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.cc @@ -3,9 +3,10 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" -#include #include +#include #include +#include #include #include "core/graph/op_identifier_utils.h" @@ -56,9 +57,9 @@ const SelectorActionRegistry::Entry* SelectorActionRegistry::LookUp(const std::s } #if !defined(ORT_MINIMAL_BUILD) -auto SelectorActionRegistry::LookUpByOpType(const std::string& op_type) const +auto SelectorActionRegistry::LookUpByOpTypeAndDomain(const std::string& op_type, const std::string& domain) const -> std::vector> { - const auto [range_begin, range_end] = op_type_to_entry_.equal_range(op_type); + const auto [range_begin, range_end] = op_type_to_entry_.equal_range(OpVersionsMapKey(op_type, domain)); std::vector> result{}; result.reserve(std::distance(range_begin, range_end)); std::transform(range_begin, range_end, std::back_inserter(result), @@ -93,20 +94,15 @@ static Status MatchAndProcess( Status status = Status::OK(); do { - // TODO: for now this just needs to support ONNX and Micrsoft Domain ops. - // If we ever had a transformer that was going to target non-ONNX ops, - // we'd need to rework a few things to include the op domain in the matches - if (node.Domain() != kOnnxDomain && node.Domain() != kMSDomain) { - break; - } - std::optional node_selection_opt{}; const SelectorActionRegistry::Entry* selector_action_entry_ptr = nullptr; - const auto selector_action_entries = selector_action_registry.LookUpByOpType(node.OpType()); + const auto selector_action_entries = + selector_action_registry.LookUpByOpTypeAndDomain(node.OpType(), node.Domain()); + std::string key = SelectorActionRegistry::OpVersionsMapKey(node.OpType(), node.Domain()); for (const auto& entry : selector_action_entries) { // check the supported versions if specified - const auto& versions = entry->ops_and_versions.find(node.OpType())->second; + const auto& versions = entry->ops_and_versions.find(key)->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node.SinceVersion()) == versions.cend()) { continue; diff --git a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h index 7eb162cc693f1..5caa949ebbe93 100644 --- a/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h +++ b/onnxruntime/core/optimizer/selectors_actions/selector_action_transformer.h @@ -38,8 +38,20 @@ struct NodeSelector { // class to manage a set of selector and associated actions class SelectorActionRegistry { public: + // The key is a string representing the op, optionally specifying the domain using ':' as the + // separator with domain as the first part and operator as the second part, ":" or "". + // For ops in kOnnxDomain, the domain should be left unspecified (""). + // For ops in other domains, the domain should be specified (":"). + // Ex: "Conv", "com.microsoft:Conv", "com.ms.internal.nhwc:Conv" using OpVersionsMap = std::unordered_map>; + // Helper function to create a key to OpVersionsMap using domain and op_type. 
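+  // For example:
+  //   OpVersionsMapKey("Conv")                        -> "Conv"
+  //   OpVersionsMapKey("NhwcConv", kMSDomain)         -> "com.microsoft:NhwcConv"
+  //   OpVersionsMapKey("Conv", kMSInternalNHWCDomain) -> "com.ms.internal.nhwc:Conv"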
+ static std::string OpVersionsMapKey(std::string_view op_type, std::string_view domain = kOnnxDomain) { + return (domain == kOnnxDomain) + ? std::string{op_type} + : std::string{domain} + ":" + std::string{op_type}; + } + struct Entry { Entry(const std::string& name_in, #if !defined(ORT_MINIMAL_BUILD) @@ -95,14 +107,15 @@ class SelectorActionRegistry { #if !defined(ORT_MINIMAL_BUILD) // return registered Entry or nullptr if not found - auto LookUpByOpType(const std::string& op_type) const -> std::vector>; + auto LookUpByOpTypeAndDomain(const std::string& op_type, + const std::string& domain) const -> std::vector>; #endif // !defined(ORT_MINIMAL_BUILD) private: std::unordered_map name_to_entry_; #if !defined(ORT_MINIMAL_BUILD) - // auxiliary mapping to enable lookup by op type + // auxiliary mapping to enable lookup by op type or "domain:op type" std::unordered_multimap op_type_to_entry_; #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index 2e07124dcd901..68336c996a863 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -16,6 +16,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, @@ -23,6 +24,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Conv, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index fdf3e5b6c6b66..3a01a4aa46be4 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -3,23 +3,42 @@ #pragma once +#include +#include + #include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" namespace onnxruntime { namespace js { -template -class Conv : public JsKernel { +class ConvBase : public JsKernel { public: - Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { + ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info), + conv_attrs_(info), + w_is_const_(false) { + std::vector activation_params; TensorShapeVector kernel_shape; + const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size(); + std::vector local_pads(pads_vec_size, 0); + for (size_t i = 0; i < conv_attrs_.pads.size() && i < pads_vec_size; ++i) { + local_pads[i] = gsl::narrow_cast(conv_attrs_.pads[i]); + } + if (conv_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - + if (is_fused_conv) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_attrs_.activation)); + ORT_ENFORCE(info.GetAttrs("activation_params", activation_params).IsOK()); + } else { + conv_attrs_.activation = info.GetAttrOrDefault("activation", ""); + activation_params = info.GetAttrsOrDefault("activation_params", activation_params); + } + const auto* activation_params_ptr = activation_params.size() > 0 ? activation_params.data() : nullptr; int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); - + auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? 
kernel_shape[0] : 0; + auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_attrs_.dilations.size() == 1 || (conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || @@ -30,44 +49,52 @@ class Conv : public JsKernel { "dilations" : [$2], "group" : $3, "kernel_shape" : [$4], - "pads" : [ $5, $6 ], + "pads" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : [], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9]) + "w_is_const" : () JS_ARROW(!!HEAP8[$9]), + "activation" : UTF8ToString($10), + "activation_params" : $11 ? Array.from(HEAPF32.subarray($12, $12 + $11)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), + static_cast(kernel_shape_0), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ - "format" : $13 ? "NHWC" : "NCHW", + "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, "dilations" : [ $2, $3 ], "group" : $4, "kernel_shape" : [ $5, $6 ], - "pads" : [ $7, $8, $9, $10 ], - "strides" : [ $11, $12 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$14]) + "pads" : $7 ? Array.from(HEAP32.subarray($8, $8 + $7)) : [], + "strides" : [ $9, $10 ], + "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "activation" : UTF8ToString($13), + "activation_params" : $14 ? Array.from(HEAPF32.subarray($15, $15 + $14)) : [] }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), static_cast(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0), static_cast(conv_attrs_.group), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0), - static_cast(conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0), - static_cast(conv_attrs_.pads.size() > 0 ? conv_attrs_.pads[0] : 0), - static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), - static_cast(conv_attrs_.pads.size() > 2 ? conv_attrs_.pads[2] : 0), - static_cast(conv_attrs_.pads.size() > 3 ? conv_attrs_.pads[3] : 0), + static_cast(kernel_shape_0), + static_cast(kernel_shape_1), + static_cast(local_pads.size()), + reinterpret_cast(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2, static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? 
conv_attrs_.strides[1] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str(), + activation_params.size(), + reinterpret_cast(activation_params_ptr) >> 2); } } @@ -94,5 +121,12 @@ class Conv : public JsKernel { // Tensor w_transposed_; }; +template +class Conv : public ConvBase { + public: + explicit Conv(const OpKernelInfo& info) : ConvBase(info, is_channels_last, is_fused_conv) { + } +}; + } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 2228343e1e6e3..f7f0ab22b7006 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -15,6 +15,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_KERNEL_EX( ConvTranspose, kOnnxDomain, @@ -22,6 +23,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( ConvTranspose, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 18ef73268005d..5d30dc851e00f 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -4,26 +4,45 @@ #pragma once #include +#include #include "core/common/gsl.h" #include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { TensorShapeVector kernel_shape; + if (is_fused_convtranspose) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_transpose_attrs_.activation)); + } else { + conv_transpose_attrs_.activation = info.GetAttrOrDefault("activation", ""); + } + if (conv_transpose_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } - int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); + std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), + conv_transpose_attrs_.output_shape.end()); + std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), + conv_transpose_attrs_.output_padding.end()); + const auto* local_output_padding_ptr = + local_output_padding.size() > 0 ? local_output_padding.data() : nullptr; + const auto* local_output_shape_ptr = + local_output_shape.size() > 0 ? local_output_shape.data() : nullptr; // currently only support Conv 1D/2D. TODO: support Conv3D and other if (conv_transpose_attrs_.dilations.size() == 1 || (conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() == 1) || conv_transpose_attrs_.strides.size() == 1) { + auto dilations = conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0; + auto kernel_shape_0 = conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0; + auto pads_0 = conv_transpose_attrs_.pads.size() > 0 ? conv_transpose_attrs_.pads[0] : 0; + auto pads_1 = conv_transpose_attrs_.pads.size() > 1 ? 
conv_transpose_attrs_.pads[1] : 0; + auto strides = conv_transpose_attrs_.strides.size() > 0 ? conv_transpose_attrs_.strides[0] : 0; JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $8 ? "NHWC" : "NCHW", "autoPad" : $1, @@ -34,21 +53,23 @@ class ConvTranspose : public JsKernel { "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), "outputPadding" : $10 ? Array.from(HEAP32.subarray($11, $11 + $10)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [] + "outputShape" : $12 ? Array.from(HEAP32.subarray($13, $13 + $12)) : [], + "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), - static_cast(conv_transpose_attrs_.dilations.size() > 0 ? conv_transpose_attrs_.dilations[0] : 0), + static_cast(dilations), static_cast(conv_transpose_attrs_.group), - static_cast(conv_transpose_attrs_.kernel_shape_specified && kernel_shape.size() > 0) ? kernel_shape[0] : 0, - static_cast(conv_transpose_attrs_.pads.size()), - static_cast(conv_transpose_attrs_.pads.size() > 1) ? conv_transpose_attrs_.pads[1] : 0, - static_cast(conv_transpose_attrs_.strides.size() > 0) ? conv_transpose_attrs_.strides[0] : 0, + static_cast(kernel_shape_0), + static_cast(pads_0), + static_cast(pads_1), + static_cast(strides), static_cast(channels_last), reinterpret_cast(&w_is_const_), - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_padding.size() > 0 ? conv_transpose_attrs_.output_padding.data() : nullptr) >> 2, - gsl::narrow_cast(conv_transpose_attrs_.output_shape.size()), - reinterpret_cast(conv_transpose_attrs_.output_shape.size() > 0 ? conv_transpose_attrs_.output_shape.data() : nullptr) >> 2); + gsl::narrow_cast(local_output_padding.size()), + reinterpret_cast(local_output_padding_ptr) >> 2, + gsl::narrow_cast(local_output_shape.size()), + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } else { constexpr size_t pads_vec_size = 4; constexpr size_t strides_vec_size = 2; @@ -59,8 +80,6 @@ class ConvTranspose : public JsKernel { std::vector local_strides(strides_vec_size, 0); std::vector local_dilations(dialations_vec_size, 0); std::vector local_kernel_shape; - std::vector local_output_shape(conv_transpose_attrs_.output_shape.begin(), conv_transpose_attrs_.output_shape.end()); - std::vector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); if (conv_transpose_attrs_.kernel_shape_specified) { for (size_t i = 0; i < kernel_shape.size() && i < kernel_shape_vec_size; ++i) { local_kernel_shape.push_back(gsl::narrow_cast(kernel_shape[i])); @@ -91,7 +110,8 @@ class ConvTranspose : public JsKernel { "strides" : Array.from(HEAP32.subarray($6, $6 + /* strides_vec_size */ 2)), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), "outputPadding" : ($9 > 0) ? Array.from(HEAP32.subarray($10, $10 + $9)) : [], - "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [] + "outputShape" : ($11 > 0) ? Array.from(HEAP32.subarray($12, $12 + $11)) : [], + "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), reinterpret_cast(local_dilations.data()) >> 2, @@ -102,9 +122,10 @@ class ConvTranspose : public JsKernel { static_cast(channels_last), reinterpret_cast(&w_is_const_), gsl::narrow_cast(local_output_padding.size()), - reinterpret_cast(local_output_padding.size() > 0 ? 
local_output_padding.data() : nullptr) >> 2, + reinterpret_cast(local_output_padding_ptr) >> 2, gsl::narrow_cast(local_output_shape.size()), - reinterpret_cast(local_output_shape.size() > 0 ? local_output_shape.data() : nullptr) >> 2); + reinterpret_cast(local_output_shape_ptr) >> 2, + conv_transpose_attrs_.activation.c_str()); } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 46b95a127b75c..a6aa4b946f397 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1438,7 +1438,7 @@ TEST_F(GraphTransformationTests, NotWhereFusion) { ASSERT_TRUE(op_to_count["Not"] == 1); // can't remove Not if it is graph output/ has consumer that's not where } -#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS) +#if (defined(USE_CUDA) || defined(USE_JSEP)) && !defined(DISABLE_CONTRIB_OPS) // Conv->Add->Relu will be transformed to FusedConv TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx"; @@ -1618,6 +1618,10 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCudaExecutionProvider); } +#elif defined(USE_JSEP) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kJsExecutionProvider); + } #endif std::map op_to_count_before_fusion = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count_before_fusion[model.second] >= 1); @@ -1632,6 +1636,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { std::set cuda_rocm_supported = {"Relu"}; if (cuda_rocm_supported.find(model.second) == cuda_rocm_supported.end()) { ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); + } else { + ASSERT_EQ(op_to_count_after_fusion[model.second], 0); + } +#elif defined(USE_JSEP) + std::set js_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu"}; + if (js_supported.find(model.second) == js_supported.end()) { + ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); } else { ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); }
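
Note: the expected outputs in js/web/test/data/ops/fused-conv.jsonc above can be
reproduced offline. The following standalone TypeScript sketch (an illustrative
helper, not part of this patch) runs a valid-padding, single-channel 2D
convolution followed by Relu over test case T[1] and prints the expected output
[0, 0, 390, 430]:

    // Reference conv2d + fused Relu: NCHW layout, one channel, stride 1, no padding.
    function fusedConvRelu(input: number[], h: number, w: number,
                           kernel: number[], kh: number, kw: number): number[] {
      const outH = h - kh + 1;  // "valid" output height
      const outW = w - kw + 1;  // "valid" output width
      const out: number[] = [];
      for (let y = 0; y < outH; y++) {
        for (let x = 0; x < outW; x++) {
          let acc = 0;
          for (let ky = 0; ky < kh; ky++) {
            for (let kx = 0; kx < kw; kx++) {
              acc += input[(y + ky) * w + (x + kx)] * kernel[ky * kw + kx];
            }
          }
          out.push(Math.max(acc, 0));  // the fused Relu activation
        }
      }
      return out;
    }

    // Test case T[1]: expects [0, 0, 390, 430].
    console.log(fusedConvRelu([10, 20, -30, -40, -50, -60, 70, 80, 90], 3, 3,
                              [1, 2, 3, 4], 2, 2));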