Skip to content

Commit

Permalink
[js/webgpu] Fix activation_params in FusedConv
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Dec 21, 2023
1 parent ffa6602 commit 617608b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
68 changes: 68 additions & 0 deletions js/web/test/data/ops/fused-conv.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,73 @@
]
}
]
},
{
"name": "fused conv with clip",
"operator": "FusedConv",
"attributes": [
{ "name": "activation", "data": "Clip", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "activation_params", "data": [0.0, 600.0], "type": "floats" }
],
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [370, 470, 600, 600],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "fused conv with clip",
"operator": "FusedConv",
"attributes": [
{ "name": "activation", "data": "Clip", "type": "string" },
{ "name": "kernel_shape", "data": [2, 2], "type": "ints" },
{ "name": "activation_params", "data": [400.0, 600.0], "type": "floats" }
],
"opset": { "domain": "com.microsoft", "version": 1 },
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
"dims": [1, 1, 3, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [400, 470, 600, 600],
"dims": [1, 1, 2, 2],
"type": "float32"
}
]
}
]
}
]
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/js/operators/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class ConvBase : public JsKernel {
}
if (is_fused_conv) {
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("activation", &conv_attrs_.activation));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("activation_params", activation_params));
if (conv_attrs_.activation == "Clip") {
ORT_THROW_IF_ERROR(info.GetAttrs<float>("activation_params", activation_params));
}
} else {
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
activation_params = info.GetAttrsOrDefault<float>("activation_params", activation_params);
Expand Down

0 comments on commit 617608b

Please sign in to comment.