From 84f20975a17a3799003f3aef4a45c45748849a37 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 4 Oct 2023 16:05:02 -0700 Subject: [PATCH] Added FusedConv --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 5 +- .../test/data/ops/fused-conv-transpose.jsonc | 55 +++++++++ js/web/test/data/ops/fused-conv.jsonc | 112 ++++++++++++++++++ onnxruntime/contrib_ops/js/fused_conv.cc | 20 ++++ .../contrib_ops/js/fused_conv_transpose.cc | 13 +- .../contrib_ops/js/fused_conv_transpose.h | 24 ---- .../contrib_ops/js/js_contrib_kernels.cc | 2 + .../core/providers/js/operators/conv.cc | 4 +- .../core/providers/js/operators/conv.h | 19 ++- .../providers/js/operators/conv_transpose.cc | 4 +- .../providers/js/operators/conv_transpose.h | 8 +- 13 files changed, 220 insertions(+), 48 deletions(-) create mode 100644 js/web/test/data/ops/fused-conv.jsonc create mode 100644 onnxruntime/contrib_ops/js/fused_conv.cc delete mode 100644 onnxruntime/contrib_ops/js/fused_conv_transpose.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index ddb7832a4e315..8d2d641964a07 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -38,6 +38,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| FusedConv | com.microsoft(1+) | | | FusedConvTranspose | com.microsoft(1+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | GatherElements | ai.onnx(11-12,13+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 77270ab913002..c56fe66acb1fc 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -62,6 +62,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['FusedConv', [conv, parseConvAttributes]], ['FusedConvTranspose', [convTranspose, parseConvTransposeAttributes]], ['Gather', [gather, parseGatherAttributes]], ['GatherElements', [gatherElements, parseGatherElementsAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index e6d4039d8131b..9d231ed6bf266 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -243,8 +243,9 @@ export const createConv2DMatMulProgramInfo = ${declareFunctions} ${ conv2dCommonSnippet( - isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2], t)} + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, + attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1], + elementsSize[2], t)} ${ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : diff --git a/js/web/test/data/ops/fused-conv-transpose.jsonc b/js/web/test/data/ops/fused-conv-transpose.jsonc index 9a72e76db5105..b1a61c0162b44 100644 --- a/js/web/test/data/ops/fused-conv-transpose.jsonc +++ b/js/web/test/data/ops/fused-conv-transpose.jsonc @@ -53,5 +53,60 @@ ] } ] + }, + { + "name": "NHWC FusedConvTranspose without bias addition A", + "operator": "ConvTranspose", + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation", "data": "Relu", "type": "string" } + ], + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 2, 2, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 3, 3, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, -20, -30, 40], + "dims": [1, 2, 2, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 0, 0, 0, 0, 0, 0, 0, 160], + "dims": [1, 3, 3, 1], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc new file mode 100644 index 0000000000000..812e9d7c2def0 --- /dev/null +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -0,0 +1,112 @@ +[ + { + "name": "conv without bias addition A", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + }, + { + "name": "T[1]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv without bias addition A", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "Relu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 11 }, + "cases": [ + { + "name": "T[2]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [370, 470, 670, 770], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + }, + { + "name": "T[3]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/js/fused_conv.cc b/onnxruntime/contrib_ops/js/fused_conv.cc new file mode 100644 index 0000000000000..76402f0681976 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fused_conv.cc @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/operators/conv.h" +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + onnxruntime::js::Conv); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fused_conv_transpose.cc b/onnxruntime/contrib_ops/js/fused_conv_transpose.cc index 0c6f40c4088ec..3b1a56a211413 100644 --- a/onnxruntime/contrib_ops/js/fused_conv_transpose.cc +++ b/onnxruntime/contrib_ops/js/fused_conv_transpose.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "fused_conv_transpose.h" +#include "core/providers/js/operators/conv_transpose.h" namespace onnxruntime { namespace contrib { namespace js { @@ -13,16 +13,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, KernelDefBuilder() .TypeConstraint("T", DataTypeImpl::GetTensorType()), - FusedConvTranspose); - -ONNX_OPERATOR_KERNEL_EX( - FusedConvTranspose, - kMSInternalNHWCDomain, - 1, - kJsExecutionProvider, - KernelDefBuilder() - .TypeConstraint("T", DataTypeImpl::GetTensorType()), - FusedConvTranspose); + onnxruntime::js::ConvTranspose); } // namespace js } // namespace contrib diff --git a/onnxruntime/contrib_ops/js/fused_conv_transpose.h b/onnxruntime/contrib_ops/js/fused_conv_transpose.h deleted file mode 100644 index 47bc9117d4714..0000000000000 --- a/onnxruntime/contrib_ops/js/fused_conv_transpose.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/common/common.h" -#include "core/providers/js/js_kernel.h" -#include "core/providers/js/operators/conv_transpose.h" - -namespace onnxruntime { -namespace contrib { -namespace js { - -using onnxruntime::js::ConvTranspose; - -template -class FusedConvTranspose : public ConvTranspose { - public: - explicit FusedConvTranspose(const OpKernelInfo& info) : ConvTranspose(info, true) { - } -}; - -} // namespace js -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index a74e59c5a3cb1..fc4e4c0dab7e1 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -9,6 +9,7 @@ namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConvTranspose); template <> @@ -21,6 +22,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index 2e07124dcd901..3551122ca68f6 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -15,7 +15,8 @@ ONNX_OPERATOR_KERNEL_EX( 11, kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), - Conv); + Conv); + ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, @@ -23,6 +24,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), Conv); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Conv, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index fdf3e5b6c6b66..97f9a4930e837 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -template +template class Conv : public JsKernel { public: Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { @@ -17,6 +17,11 @@ class Conv : public JsKernel { if (conv_attrs_.kernel_shape_specified) { ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK()); } + if (has_activation) { + ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_attrs_.activation)); + } else { + conv_attrs_.activation = ""; + } int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); @@ -32,7 +37,8 @@ class Conv : public JsKernel { "kernel_shape" : [$4], "pads" : [ $5, $6 ], "strides" : [$7], - "w_is_const" : () JS_ARROW(!!HEAP8[$9]) + "w_is_const" : () JS_ARROW(!!HEAP8[$9]), + "activation" : UTF8ToString($10) }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), @@ -42,7 +48,8 @@ class Conv : public JsKernel { static_cast(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0), static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str()); } else { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $13 ? "NHWC" : "NCHW", @@ -52,7 +59,8 @@ class Conv : public JsKernel { "kernel_shape" : [ $5, $6 ], "pads" : [ $7, $8, $9, $10 ], "strides" : [ $11, $12 ], - "w_is_const" : () JS_ARROW(!!HEAP8[$14]) + "w_is_const" : () JS_ARROW(!!HEAP8[$14]), + "activation" : UTF8ToString($15) }), static_cast(conv_attrs_.auto_pad), static_cast(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0), @@ -67,7 +75,8 @@ class Conv : public JsKernel { static_cast(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0), static_cast(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0), static_cast(channels_last), - reinterpret_cast(&w_is_const_)); + reinterpret_cast(&w_is_const_), + conv_attrs_.activation.c_str()); } } diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 2228343e1e6e3..a71ec23fa63dc 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -14,7 +14,8 @@ ONNX_OPERATOR_KERNEL_EX( 11, kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), - ConvTranspose); + ConvTranspose); + ONNX_OPERATOR_KERNEL_EX( ConvTranspose, kOnnxDomain, @@ -22,6 +23,7 @@ ONNX_OPERATOR_KERNEL_EX( kJsExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ConvTranspose); + ONNX_OPERATOR_VERSIONED_KERNEL_EX( ConvTranspose, kOnnxDomain, diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index b49e8217d4fd2..0507ce3f16b2d 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -5,16 +5,16 @@ #include #include "core/common/gsl.h" -#include "core/providers/cpu/nn/conv_transpose_attributes.h" #include "core/providers/js/js_kernel.h" +#include "core/providers/cpu/nn/conv_transpose_attributes.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: - ConvTranspose(const OpKernelInfo& info, bool hasActivation = false) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { + ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { TensorShapeVector kernel_shape; - if (hasActivation) { + if (has_activation) { ORT_THROW_IF_ERROR(info.GetAttr("activation", &conv_transpose_attrs_.activation)); } else { conv_transpose_attrs_.activation = "";