Skip to content

Commit

Permalink
FP16 Pool & Reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 12, 2023
1 parent db558ef commit 0f4b700
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 217 deletions.
6 changes: 1 addition & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/pool.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
Expand All @@ -22,9 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (inputs[0].dims.length !== 4) {
throw new Error('Pool ops supports 2-D inputs only for now.');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Invalid input type.');
}
};

const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
Expand Down Expand Up @@ -248,7 +244,7 @@ const createAveragePoolProgramInfo =
const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape);

const x = inputVariable('x', input.dataType, input.dims);
const dataType = 'f32';
const dataType = x.type.value;

const op1 = 'value += x_val;';
let op2 = '';
Expand Down
8 changes: 2 additions & 6 deletions js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (inputs.length === 2 && inputs[1].dims.length !== 1) {
throw new Error('Invalid axes input dims.');
}

if (inputs[0].dataType !== DataType.float) {
throw new Error('Invalid input type.');
}
};

export interface ReduceAttributes extends AttributeWithCacheKey {
Expand Down Expand Up @@ -161,7 +157,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes):
export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => {
validateInputs(context.inputs);
const reduceOp: ReduceOp = (input, output) =>
[`var t = f32(0); var value = ${output.type.storage}(0);`,
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
'',
`t = ${input.getByOffset('inputOffset')}; value += (t * t);`,
'value = sqrt(value);',
Expand Down Expand Up @@ -266,7 +262,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes)
export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => {
validateInputs(context.inputs);
const reduceOp: ReduceOp = (input, output) =>
[`var t = f32(0); var value = ${output.type.storage}(0);`,
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
'',
`t = ${input.getByOffset('inputOffset')}; value += t * t;`,
'',
Expand Down
256 changes: 128 additions & 128 deletions onnxruntime/core/providers/js/js_execution_provider.cc

Large diffs are not rendered by default.

112 changes: 54 additions & 58 deletions onnxruntime/core/providers/js/operators/pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,69 +8,65 @@
namespace onnxruntime {
namespace js {

#define POOLING_KERNEL(op_name, domain, is_channels_last, data_type, pool_type, since_version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
op_name, \
domain, \
since_version, \
data_type, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Pool<data_type, pool_type, is_channels_last>);
#define POOLING_KERNEL(op_name, domain, is_channels_last, pool_type, since_version) \
ONNX_OPERATOR_KERNEL_EX( \
op_name, \
domain, \
since_version, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), \
Pool<pool_type, is_channels_last>);

#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
op_name, \
domain, \
since_version, \
end_version, \
data_type, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()), \
Pool<data_type, pool_type, is_channels_last>);
#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, pool_type, since_version, end_version) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
op_name, \
domain, \
since_version, \
end_version, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", JsepSupportedFloatTypes()), \
Pool<pool_type, is_channels_last>);

#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
op_name, \
domain, \
since_version, \
data_type, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<data_type, pool_type, is_channels_last>);
#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version) \
ONNX_OPERATOR_KERNEL_EX( \
op_name, \
domain, \
since_version, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", JsepSupportedFloatTypes()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<pool_type, is_channels_last>);

#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
op_name, \
domain, \
since_version, \
end_version, \
data_type, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<data_type>()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<data_type, pool_type, is_channels_last>);
#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version, end_version) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
op_name, \
domain, \
since_version, \
end_version, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", JsepSupportedFloatTypes()) \
.TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>()), \
Pool<pool_type, is_channels_last>);

POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 10, 10)
POOLING_KERNEL(AveragePool, kOnnxDomain, false, float, AveragePool, 11)
POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 11)
POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, float, AveragePool, 1)
POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 1)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9)
POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10)
POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11)
POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11)
POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1)
POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1)

POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, float, MaxPool<1>, 1, 7)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 8, 9)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 10, 10)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 11, 11)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 11, 11)
POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 12)
POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 12)
POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, float, MaxPool<1>, 1)
POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, float, MaxPool<1>, 1)
POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11)
POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11)
POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12)
POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 12)
POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, MaxPool<1>, 1)
POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1)

} // namespace js
} // namespace onnxruntime
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/js/operators/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace js {
#define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"})
#define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast<int32_t>(is_channels_last)

template <typename T, typename PoolType, bool is_channels_last>
template <typename PoolType, bool is_channels_last>
class Pool : public JsKernel, public PoolBase {
public:
Pool(const OpKernelInfo& info) : JsKernel(info), PoolBase(info) {
Expand All @@ -65,10 +65,10 @@ class Pool : public JsKernel, public PoolBase {
}
};

template <typename T, bool is_channels_last>
class Pool<T, MaxPool<8>, is_channels_last> final : public Pool<T, MaxPool<1>, is_channels_last> {
template <bool is_channels_last>
class Pool<MaxPool<8>, is_channels_last> final : public Pool<MaxPool<1>, is_channels_last> {
public:
Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, is_channels_last>(info) {}
Pool(const OpKernelInfo& info) : Pool<MaxPool<1>, is_channels_last>(info) {}
};

} // namespace js
Expand Down
28 changes: 13 additions & 15 deletions onnxruntime/core/providers/js/operators/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,30 @@ namespace onnxruntime {
namespace js {

#define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, sinceVersion, endVersion) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
ReduceOp, \
kOnnxDomain, \
sinceVersion, endVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), \
ReduceOp<float>);
.TypeConstraint("T", JsepSupportedFloatTypes()), \
ReduceOp<true>);

// macro REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL does not set .InputMemoryType(OrtMemTypeCPU, 1), so in future if
// a new opset version update applies to Reduce* operators, we may need to add another macro like
// REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type.
// i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased.

#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
ReduceOp, \
kOnnxDomain, \
sinceVersion, \
float, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
ReduceOp<float>);
#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \
ONNX_OPERATOR_KERNEL_EX( \
ReduceOp, \
kOnnxDomain, \
sinceVersion, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", JsepSupportedFloatTypes()) \
.InputMemoryType(OrtMemTypeCPU, 1), \
ReduceOp<true>);

REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10);
REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace onnxruntime {
namespace js {
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
template <typename T, bool allow_multi_axes = true> \
template <bool allow_multi_axes = true> \
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
public: \
using ReduceKernelBase<allow_multi_axes>::axes_; \
Expand Down

0 comments on commit 0f4b700

Please sign in to comment.