diff --git a/js/web/test/data/ops/max-pool.jsonc b/js/web/test/data/ops/max-pool.jsonc new file mode 100644 index 0000000000000..e485f48e93eb4 --- /dev/null +++ b/js/web/test/data/ops/max-pool.jsonc @@ -0,0 +1,67 @@ +[ + { + "name": "MaxPool", + "operator": "MaxPool", + "attributes": [ + { "name": "kernel_shape", "data": [3], "type": "ints" }, + { "name": "dilations", "data": [1], "type": "ints" } + ], + "cases": [ + { + "name": "T[3,5,5] T[3,5,3]", + "inputs": [ + { + "data": [ + 1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238, + -0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334, + 0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876, + 0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989, + -2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197, 2.269754648208618, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100, + 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100 + ], + "dims": [3, 5, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.764052391052246, 2.2408931255340576, 2.2408931255340576, 0.9500884413719177, 0.9500884413719177, + 0.4105985164642334, 1.4542734622955322, 1.4542734622955322, 0.7610377073287964, 1.4940791130065918, + 1.4940791130065918, 0.3130677044391632, 0.8644362092018127, 0.8644362092018127, 2.269754648208618, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, + 100, 100 + ], + "dims": [3, 5, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MaxPool", + "operator": "MaxPool", + "attributes": [{ "name": "kernel_shape", "data": [3], "type": "ints" }], + "cases": [ + { + "name": "T[1,1,5] T[1,1,3]", + "inputs": [ + { + "data": [1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238], + "dims": [1, 1, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.764052391052246, 2.2408931255340576, 2.2408931255340576], + "dims": [1, 1, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index ede89f7557dd8..44b89142790ab 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1371,6 +1371,7 @@ "matmul.jsonc", "matmulnbits.jsonc", "matmul-broadcast.jsonc", + "max-pool.jsonc", "mul.jsonc", "mul_int32.jsonc", "multihead-attention.jsonc", diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 32e8e1facafcd..0357c2f02a7a2 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -48,7 +48,6 @@ class ConvBase : public JsKernel { std::vector activation_params = info.GetAttrsOrDefault("activation_params"); int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault("channels_last", 0); - // currently only support Conv 1D/2D. TODO: support Conv3D and other JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, @@ -65,8 +64,8 @@ class ConvBase : public JsKernel { JSEP_HEAP32_INDEX_START(dilations), JSEP_HEAP32_INDEX_END(dilations), static_cast(conv_attrs_.group), - JSEP_HEAP32_INDEX_START(kernel_shape), - JSEP_HEAP32_INDEX_END(kernel_shape), + JSEP_HEAP32_INDEX_START(kernel_shapes), + JSEP_HEAP32_INDEX_END(kernel_shapes), JSEP_HEAP32_INDEX_START(local_pads), JSEP_HEAP32_INDEX_END(local_pads), JSEP_HEAP32_INDEX_START(strides), diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 5723123c0c3b8..66bcde86020b6 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -9,38 +9,45 @@ namespace onnxruntime { namespace js { -#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ - "format" : $15 ? "NHWC" : "NCHW", \ - "auto_pad" : $1, \ - "ceil_mode" : $2, \ - "count_include_pad" : $3, \ - "storage_order" : $4, \ - "dilations" : [ $5, $6 ], \ - "kernel_shape" : [ $7, $8 ], \ - "pads" : [ $9, $10, $11, $12 ], \ - "strides" : [ $13, $14 ] \ +#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ + "format" : $13 ? "NHWC" : "NCHW", \ + "auto_pad" : $1, \ + "ceil_mode" : $2, \ + "count_include_pad" : $3, \ + "storage_order" : $4, \ + "dilations" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], \ + "kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \ + "pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \ + "strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \ }) -#define POOL_ATTRIBUTES_PARAM_LIST \ - static_cast(pool_attrs_.auto_pad), \ - static_cast(pool_attrs_.ceil_mode), \ - static_cast(pool_attrs_.count_include_pad), \ - static_cast(pool_attrs_.storage_order), \ - static_cast(pool_attrs_.dilations.size() > 0 ? pool_attrs_.dilations[0] : 0), \ - static_cast(pool_attrs_.dilations.size() > 1 ? pool_attrs_.dilations[1] : 0), \ - static_cast(pool_attrs_.kernel_shape.size() > 0 ? pool_attrs_.kernel_shape[0] : 0), \ - static_cast(pool_attrs_.kernel_shape.size() > 1 ? pool_attrs_.kernel_shape[1] : 0), \ - static_cast(pool_attrs_.pads.size() > 0 ? pool_attrs_.pads[0] : 0), \ - static_cast(pool_attrs_.pads.size() > 1 ? pool_attrs_.pads[1] : 0), \ - static_cast(pool_attrs_.pads.size() > 2 ? pool_attrs_.pads[2] : 0), \ - static_cast(pool_attrs_.pads.size() > 3 ? pool_attrs_.pads[3] : 0), \ - static_cast(pool_attrs_.strides.size() > 0 ? pool_attrs_.strides[0] : 0), \ - static_cast(pool_attrs_.strides.size() > 1 ? pool_attrs_.strides[1] : 0), \ +#define POOL_ATTRIBUTES_PARAM_LIST \ + static_cast(pool_attrs_.auto_pad), \ + static_cast(pool_attrs_.ceil_mode), \ + static_cast(pool_attrs_.count_include_pad), \ + static_cast(pool_attrs_.storage_order), \ + JSEP_HEAP32_INDEX_START(dilations), \ + JSEP_HEAP32_INDEX_END(dilations), \ + JSEP_HEAP32_INDEX_START(kernel_shapes), \ + JSEP_HEAP32_INDEX_END(kernel_shapes), \ + JSEP_HEAP32_INDEX_START(pads), \ + JSEP_HEAP32_INDEX_END(pads), \ + JSEP_HEAP32_INDEX_START(strides), \ + JSEP_HEAP32_INDEX_END(strides), \ static_cast(is_channels_last) #define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"}) #define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast(is_channels_last) +template +inline const std::vector CastTensorShapeVector(const TensorShapeVector& shape) { + std::vector castedShapes(shape.size(), 0); + for (size_t i = 0; i < shape.size(); ++i) { + castedShapes[i] = gsl::narrow_cast(shape[i]); + } + return castedShapes; +} + template class Pool : public JsKernel, public PoolBase { public: @@ -54,6 +61,10 @@ class Pool : public JsKernel, public PoolBase { // TODO: GlobalLpPool } } else { + auto kernel_shapes{CastTensorShapeVector(pool_attrs_.kernel_shape)}; + auto strides{CastTensorShapeVector(pool_attrs_.strides)}; + auto dilations{CastTensorShapeVector(pool_attrs_.dilations)}; + auto pads{CastTensorShapeVector(pool_attrs_.pads)}; if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) { JSEP_INIT_KERNEL_ATTRIBUTE(AveragePool, POOL_ATTRIBUTES_JS_OBJ_MAPPING, POOL_ATTRIBUTES_PARAM_LIST); } else if constexpr (PoolType::type == onnxruntime::PoolType::kMaxPool) {