Skip to content

Commit

Permalink
Added FusedConv
Browse files Browse the repository at this point in the history
  • Loading branch information
satyajandhyala committed Oct 4, 2023
1 parent 7a16434 commit 84f2097
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 48 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Do not modify directly.*
| Expand | ai.onnx(8-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
| FusedConvTranspose | com.microsoft(1+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherElements | ai.onnx(11-12,13+) | |
Expand Down
1 change: 1 addition & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['FusedConv', [conv, parseConvAttributes]],
['FusedConvTranspose', [convTranspose, parseConvTransposeAttributes]],
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
Expand Down
5 changes: 3 additions & 2 deletions js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,9 @@ export const createConv2DMatMulProgramInfo =
${declareFunctions}
${
conv2dCommonSnippet(
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0],
elementsSize[1], elementsSize[2], t)}
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias,
attributes.activation.toLowerCase() as Activation, false, elementsSize[0], elementsSize[1],
elementsSize[2], t)}
${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
Expand Down
55 changes: 55 additions & 0 deletions js/web/test/data/ops/fused-conv-transpose.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,60 @@
]
}
]
},
{
"name": "NHWC FusedConvTranspose without bias addition A",
"operator": "ConvTranspose",
"opset": { "domain": "com.ms.internal.nhwc", "version": 11 },
"attributes": [
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "activation", "data": "Relu", "type": "string" }
],
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [10, 20, 30, 40],
"dims": [1, 2, 2, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [10, 40, 40, 60, 200, 160, 90, 240, 160],
"dims": [1, 3, 3, 1],
"type": "float32"
}
]
},
{
"name": "T[3]",
"inputs": [
{
"data": [10, -20, -30, 40],
"dims": [1, 2, 2, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [10, 0, 0, 0, 0, 0, 0, 0, 160],
"dims": [1, 3, 3, 1],
"type": "float32"
}
]
}
]
}
]
112 changes: 112 additions & 0 deletions js/web/test/data/ops/fused-conv.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
[
{
"name": "conv without bias addition A",
"operator": "FusedConv",
"attributes": [
{ "name": "activation", "data": "Relu", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }
],
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [370, 470, 670, 770],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
},
{
"name": "T[1]",
"inputs": [
{
"data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 0, 390, 430],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "NHWC conv without bias addition A",
"operator": "Conv",
"attributes": [
{ "name": "activation", "data": "Relu", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }
],
"opset": { "domain": "com.ms.internal.nhwc", "version": 11 },
"cases": [
{
"name": "T[2]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 3, 3, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [370, 470, 670, 770],
"dims": [1, 2, 2, 1],
"type": "float32"
}
]
},
{
"name": "T[3]",
"inputs": [
{
"data": [10, 20, -30, -40, -50, -60, 70, 80, 90],
"dims": [1, 3, 3, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0, 0, 390, 430],
"dims": [1, 2, 2, 1],
"type": "float32"
}
]
}
]
}
]
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/js/fused_conv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/js/operators/conv.h"
namespace onnxruntime {
namespace contrib {
namespace js {

ONNX_OPERATOR_KERNEL_EX(
FusedConv,
kMSDomain,
1,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
onnxruntime::js::Conv<false, true>);

} // namespace js
} // namespace contrib
} // namespace onnxruntime
13 changes: 2 additions & 11 deletions onnxruntime/contrib_ops/js/fused_conv_transpose.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fused_conv_transpose.h"
#include "core/providers/js/operators/conv_transpose.h"
namespace onnxruntime {
namespace contrib {
namespace js {
Expand All @@ -13,16 +13,7 @@ ONNX_OPERATOR_KERNEL_EX(
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConvTranspose<false>);

ONNX_OPERATOR_KERNEL_EX(
FusedConvTranspose,
kMSInternalNHWCDomain,
1,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedConvTranspose<true>);
onnxruntime::js::ConvTranspose<false, true>);

} // namespace js
} // namespace contrib
Expand Down
24 changes: 0 additions & 24 deletions onnxruntime/contrib_ops/js/fused_conv_transpose.h

This file was deleted.

2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/js/js_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace js {

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConvTranspose);

template <>
Expand All @@ -21,6 +22,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConvTranspose)>};

for (auto& function_table_entry : function_table) {
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/js/operators/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@ ONNX_OPERATOR_KERNEL_EX(
11,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Conv<true>);
Conv<true, true>);

ONNX_OPERATOR_KERNEL_EX(
Conv,
kOnnxDomain,
11,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Conv<false>);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Conv,
kOnnxDomain,
Expand Down
19 changes: 14 additions & 5 deletions onnxruntime/core/providers/js/operators/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@
namespace onnxruntime {
namespace js {

template <bool is_channels_last>
template <bool is_channels_last, bool has_activation = false>
class Conv : public JsKernel {
public:
Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) {
TensorShapeVector kernel_shape;
if (conv_attrs_.kernel_shape_specified) {
ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK());
}
if (has_activation) {
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("activation", &conv_attrs_.activation));
} else {
conv_attrs_.activation = "";
}

int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);

Expand All @@ -32,7 +37,8 @@ class Conv : public JsKernel {
"kernel_shape" : [$4],
"pads" : [ $5, $6 ],
"strides" : [$7],
"w_is_const" : () JS_ARROW(!!HEAP8[$9])
"w_is_const" : () JS_ARROW(!!HEAP8[$9]),
"activation" : UTF8ToString($10)
}),
static_cast<int32_t>(conv_attrs_.auto_pad),
static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
Expand All @@ -42,7 +48,8 @@ class Conv : public JsKernel {
static_cast<int32_t>(conv_attrs_.pads.size() > 1 ? conv_attrs_.pads[1] : 0),
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_));
reinterpret_cast<int32_t>(&w_is_const_),
conv_attrs_.activation.c_str());
} else {
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
"format" : $13 ? "NHWC" : "NCHW",
Expand All @@ -52,7 +59,8 @@ class Conv : public JsKernel {
"kernel_shape" : [ $5, $6 ],
"pads" : [ $7, $8, $9, $10 ],
"strides" : [ $11, $12 ],
"w_is_const" : () JS_ARROW(!!HEAP8[$14])
"w_is_const" : () JS_ARROW(!!HEAP8[$14]),
"activation" : UTF8ToString($15)
}),
static_cast<int32_t>(conv_attrs_.auto_pad),
static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
Expand All @@ -67,7 +75,8 @@ class Conv : public JsKernel {
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_));
reinterpret_cast<int32_t>(&w_is_const_),
conv_attrs_.activation.c_str());
}
}

Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/js/operators/conv_transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ ONNX_OPERATOR_KERNEL_EX(
11,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
ConvTranspose<true>);
ConvTranspose<true, true>);

ONNX_OPERATOR_KERNEL_EX(
ConvTranspose,
kOnnxDomain,
11,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
ConvTranspose<false>);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
ConvTranspose,
kOnnxDomain,
Expand Down
Loading

0 comments on commit 84f2097

Please sign in to comment.