From 9ecab2966fbe5d7cf8396ca3a34c3e97de187247 Mon Sep 17 00:00:00 2001 From: Xinghua Cao Date: Fri, 4 Aug 2023 15:31:09 +0800 Subject: [PATCH] [JS/WebGPU] support Concat.int32 operator --- js/web/test/data/ops/concat_int32.jsonc | 406 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 27 +- .../core/providers/js/operators/concat.cc | 12 +- 3 files changed, 428 insertions(+), 17 deletions(-) create mode 100644 js/web/test/data/ops/concat_int32.jsonc diff --git a/js/web/test/data/ops/concat_int32.jsonc b/js/web/test/data/ops/concat_int32.jsonc new file mode 100644 index 0000000000000..6e2ce18c6f7c5 --- /dev/null +++ b/js/web/test/data/ops/concat_int32.jsonc @@ -0,0 +1,406 @@ +[ + { + "name": "Concat 2D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "[4,4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16 + ], + "dims": [8, 4], + "type": "int32" + } + ] + }, + { + "name": "[2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8], + "dims": [2, 4], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [4, 4], + "type": "int32" + } + ] + }, + { + "name": "[2,3]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6], + "dims": [4, 3], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 2D axis=1", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "[4,4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 9, 10, 11, 12, 13, 14, 15, 16, 13, 14, 15, + 16 + ], + "dims": [4, 8], + "type": "int32" + } + ] + }, + { + "name": "[2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8], + "dims": [2, 4], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8], + "dims": [2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 5, 6, 1, 2, 3, 4, 3, 4, 7, 8, 5, 6, 7, 8], + "dims": [2, 8], + "type": "int32" + } + ] + }, + { + "name": "[2,3]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6], + "dims": [2, 6], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 3D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "[2,2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + }, + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, + 16 + ], + "dims": [4, 2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 3D axis=1", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "[2,2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + }, + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 9, 10, 13, 14, 11, 12, 15, + 16 + ], + "dims": [2, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 3D axis=2", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 2, "type": "int" }], + "cases": [ + { + "name": "[2,2,4]", + "inputs": [ + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + }, + { + "data": [1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16], + "dims": [2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 1, 2, 5, 6, 3, 4, 7, 8, 3, 4, 7, 8, 9, 10, 13, 14, 9, 10, 13, 14, 11, 12, 15, 16, 11, 12, 15, + 16 + ], + "dims": [2, 2, 8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 0, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, + 29, 30, 27, 28, 31, 32 + ], + "dims": [4, 2, 2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=1", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 1, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, + 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32, 17, 18, 21, 22, 19, 20, 23, 24, 25, + 26, 29, 30, 27, 28, 31, 32 + ], + "dims": [2, 4, 2, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=2", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 2, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 9, 10, 13, 14, 11, 12, 15, + 16, 17, 18, 21, 22, 19, 20, 23, 24, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, 28, 31, 32, 25, + 26, 29, 30, 27, 28, 31, 32 + ], + "dims": [2, 2, 4, 4], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Concat 4D axis=3", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": 3, "type": "int" }], + "cases": [ + { + "name": "[2,2,2,4]", + "inputs": [ + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + }, + { + "data": [ + 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16, 17, 18, 21, 22, 19, 20, 23, 24, 25, 26, 29, 30, 27, + 28, 31, 32 + ], + "dims": [2, 2, 2, 4], + "type": "int32" + } + ], + "outputs": [ + { + "data": [ + 1, 2, 5, 6, 1, 2, 5, 6, 3, 4, 7, 8, 3, 4, 7, 8, 9, 10, 13, 14, 9, 10, 13, 14, 11, 12, 15, 16, 11, 12, 15, + 16, 17, 18, 21, 22, 17, 18, 21, 22, 19, 20, 23, 24, 19, 20, 23, 24, 25, 26, 29, 30, 25, 26, 29, 30, 27, + 28, 31, 32, 27, 28, 31, 32 + ], + "dims": [2, 2, 2, 8], + "type": "int32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index aca3526115c7e..5bd264b16349e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -432,18 +432,18 @@ // // "test_compress_1", // // "test_compress_default_axis", // // "test_compress_negative_axis", - // "test_concat_1d_axis_0", - // "test_concat_1d_axis_negative_1", - // "test_concat_2d_axis_0", - // "test_concat_2d_axis_1", - // "test_concat_2d_axis_negative_1", - // "test_concat_2d_axis_negative_2", - // "test_concat_3d_axis_0", - // "test_concat_3d_axis_1", - // "test_concat_3d_axis_2", - // "test_concat_3d_axis_negative_1", - // "test_concat_3d_axis_negative_2", - // "test_concat_3d_axis_negative_3", + "test_concat_1d_axis_0", + "test_concat_1d_axis_negative_1", + "test_concat_2d_axis_0", + "test_concat_2d_axis_1", + "test_concat_2d_axis_negative_1", + "test_concat_2d_axis_negative_2", + "test_concat_3d_axis_0", + "test_concat_3d_axis_1", + "test_concat_3d_axis_2", + "test_concat_3d_axis_negative_1", + "test_concat_3d_axis_negative_2", + "test_concat_3d_axis_negative_3", "test_conv_with_autopad_same", "test_conv_with_strides_and_asymmetric_padding", "test_conv_with_strides_no_padding", @@ -1329,7 +1329,8 @@ //"and.jsonc", "asin.jsonc", "ceil.jsonc", - //"concat.jsonc", + "concat.jsonc", + "concat_int32.jsonc", "cast.jsonc", "conv.jsonc", "cos.jsonc", diff --git a/onnxruntime/core/providers/js/operators/concat.cc b/onnxruntime/core/providers/js/operators/concat.cc index 7d50d78c82851..3a6a7e1cafd7a 100644 --- a/onnxruntime/core/providers/js/operators/concat.cc +++ b/onnxruntime/core/providers/js/operators/concat.cc @@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 1, 3, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Concat); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -21,7 +22,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 4, 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Concat); ONNX_OPERATOR_VERSIONED_KERNEL_EX( @@ -30,7 +32,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 11, 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Concat); ONNX_OPERATOR_KERNEL_EX( @@ -39,7 +42,8 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", DataTypeImpl::GetTensorType()), + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), Concat); } // namespace js