Register ScatterElements for opset > 11
wschin committed Jan 19, 2024
1 parent df116b8 commit c6ec60f
Showing 6 changed files with 118 additions and 17 deletions.
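In short: the CUDA execution provider previously registered ScatterElements only up to opset 12, plus a single unversioned opset-13 kernel. This commit re-registers the operator as versioned kernels for opsets 13-15 and 16-17 and an unversioned kernel for opset 18+, implements the "reduction" attribute (none, add, mul, max, min) through new device-side functors, and widens the existing tests to run at both opset 11 and opset 18.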
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1041,7 +1041,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax);
@@ -1249,6 +1249,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);

// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1264,6 +1265,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);

// Opset 19
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast);
@@ -1926,7 +1928,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose)>,
-BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,

[cpplint] onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1931: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax)>,
@@ -2134,6 +2136,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,

[cpplint] onnxruntime/core/providers/cuda/cuda_execution_provider.cc:2139: Lines should be <= 120 characters long [whitespace/line_length] [2]

// Opset 17
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
@@ -2148,6 +2151,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,

// Opset 19
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast)>,
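(The version split follows the ONNX operator history: ScatterElements gained a "reduction" attribute with "add"/"mul" in opset 16, and "max"/"min" were added in opset 18, so opsets 13-15, 16-17, and 18+ need separate kernel registrations.)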
48 changes: 46 additions & 2 deletions onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu
@@ -98,6 +98,34 @@ struct FuncAssignment {
__device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; }
};

+template <class T>
+struct FuncAdd {
+  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] += value; }
+};
+
+template <class T>
+struct FuncMul {
+  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] *= value; }
+};
+
+template <class T>
+struct FuncMax {
+  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+    if (start_addr[index] < value) {
+      start_addr[index] = value;
+    }
+  }
+};
+
+template <class T>
+struct FuncMin {
+  __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+    if (start_addr[index] > value) {
+      start_addr[index] = value;
+    }
+  }
+};
+
template <typename T, typename TIndex, bool IsGather, typename OffsetCalcT, typename TFunc>
__global__ void _GatherScatterElementsKernel(const T* src_data, const TIndex* indices_data, T* output_data,
const int64_t input_dim_along_axis, const int64_t input_stride_along_axis,
@@ -238,8 +266,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con
template <typename T, typename TIndex>
Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data,
T* output_data, const GatherScatterElementsArgs& args) {
-  return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
-                                     FuncAssignment<T>());
+  if (args.operation == GatherScatterElementsArgs::Operation::NONE) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncAssignment<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncAdd<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMul<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMax<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMin<T>());
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator.");
+  }
}

#define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
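The semantics these functors implement are easiest to see on the host. Below is a minimal, illustrative CPU version for the 1-D case (a sketch for this write-up, not part of the commit; the name ScatterElements1D is hypothetical). When indices contains duplicates, the functor decides how colliding updates combine:

#include <algorithm>
#include <cstdint>
#include <vector>

enum class Reduction { NONE, ADD, MUL, MAX, MIN };

// Illustrative 1-D ScatterElements (assumes axis == 0 and in-range indices).
// output starts as a copy of data; each update is combined into its target
// slot, so duplicate indices interact according to `reduction`.
template <typename T>
std::vector<T> ScatterElements1D(const std::vector<T>& data,
                                 const std::vector<int64_t>& indices,
                                 const std::vector<T>& updates,
                                 Reduction reduction) {
  std::vector<T> output = data;
  for (size_t i = 0; i < indices.size(); ++i) {
    T& slot = output[static_cast<size_t>(indices[i])];
    switch (reduction) {
      case Reduction::NONE: slot = updates[i]; break;                  // FuncAssignment
      case Reduction::ADD:  slot += updates[i]; break;                 // FuncAdd
      case Reduction::MUL:  slot *= updates[i]; break;                 // FuncMul
      case Reduction::MAX:  slot = std::max(slot, updates[i]); break;  // FuncMax
      case Reduction::MIN:  slot = std::min(slot, updates[i]); break;  // FuncMin
    }
  }
  return output;
}

For example, data = {1, 2, 3, 4}, indices = {1, 1}, updates = {10, 20} yields {1, 32, 3, 4} under ADD (2 + 10 + 20), but {1, 20, 3, 4} under NONE, where the last write wins.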
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
@@ -10,6 +10,14 @@ namespace onnxruntime {
namespace cuda {

struct GatherScatterElementsArgs {
+  enum class Operation {
+    NONE,
+    ADD,
+    MUL,
+    MAX,
+    MIN
+  };
+
int64_t rank;
int64_t axis;
int64_t input_size;
@@ -19,6 +27,9 @@ struct GatherScatterElementsArgs {
TArray<fast_divmod> indices_fdms;
TArray<int64_t> indices_strides;
int64_t indices_size;
+  // Operation used to combine values associated with the same
+  // memory location in the output tensor.
+  Operation operation;
};

template <typename T, typename TIndex>
32 changes: 31 additions & 1 deletion onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
@@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe
DataTypeImpl::GetTensorType<int64_t>()}),
ScatterElements);

-ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider,
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider,
+                                  (*KernelDefBuilder::Create())
+                                      .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+                                      .TypeConstraint("Tind",
+                                                      std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
+                                                                              DataTypeImpl::GetTensorType<int64_t>()}),
+                                  ScatterElements);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider,
+                                  (*KernelDefBuilder::Create())
+                                      .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+                                      .TypeConstraint("Tind",
+                                                      std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
+                                                                              DataTypeImpl::GetTensorType<int64_t>()}),
+                                  ScatterElements);
+
+ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
@@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector();
CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args);

+  if (reduction_ == "none") {
+    args.operation = GatherScatterElementsArgs::Operation::NONE;
+  } else if (reduction_ == "add") {
+    args.operation = GatherScatterElementsArgs::Operation::ADD;
+  } else if (reduction_ == "mul") {
+    args.operation = GatherScatterElementsArgs::Operation::MUL;
+  } else if (reduction_ == "min") {
+    args.operation = GatherScatterElementsArgs::Operation::MIN;
+  } else if (reduction_ == "max") {
+    args.operation = GatherScatterElementsArgs::Operation::MAX;
+  } else {
+    ORT_THROW("Unsupported reduction type");
+  }
+
// Use element size instead of concrete types so we can specialize fewer template functions to reduce binary size.
int dtype = GetElementType(input_tensor->DataType()->Size());
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
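(As the comment at the end of this hunk notes, kernels are dispatched by element size rather than concrete type, so, for example, float and int32_t scatter through the same 4-byte specialization.)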
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/scatter_elements.h
@@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel {
ScatterElements(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(info.GetAttr<int64_t>("axis", &axis_).IsOK(),
"Missing/Invalid 'axis' attribute value");
+    reduction_ = info.GetAttrOrDefault<std::string>("reduction", "none");
+
+    ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" ||
+                    reduction_ == "mul" || reduction_ == "max" ||
+                    reduction_ == "min",
+                "Invalid reduction attribute value of ", reduction_);
}
~ScatterElements() = default;
Status ComputeInternal(OpKernelContext* context) const override;
@@ -23,6 +29,10 @@
struct ComputeImpl;

int64_t axis_;
// "reduction" attribute has been defined since opset 13 but
// we never implemented it. Let's try to support them starting
// with opset 18.
std::string reduction_;

[cpplint] onnxruntime/core/providers/cuda/tensor/scatter_elements.h:35: Add #include <string> for string [build/include_what_you_use] [4]
};

} // namespace cuda
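A likely fix for the cpplint warning above is to add the missing header near the top of scatter_elements.h (a sketch; the file's include block is not shown in this diff):

#include <string>  // for the new std::string reduction_ member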
26 changes: 14 additions & 12 deletions onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -75,18 +75,20 @@ void RunTest(const std::vector<int64_t>& input_dims, const std::vector<int64_t>&
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
}

-  onnxruntime::test::OpTester test1("ScatterElements", 11);
-  if (has_axis) test1.AddAttribute<int64_t>("axis", axis);
-  test1.AddInput<T>("data", input_dims, input_data);
-  test1.AddInput<TIndex>("indices", indices_dims, indices_data);
-  test1.AddInput<T>("updates", indices_dims, updates_data);
-  test1.AddOutput<T>("y", input_dims, output_data);
-  if (std::is_same<T, int8_t>::value) {
-    test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
-  } else if (std::is_same<T, MLFloat16>::value) {
-    test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
-  } else {
-    test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
+  for (int opset : {11, 18}) {
+    onnxruntime::test::OpTester test1("ScatterElements", opset);
+    if (has_axis) test1.AddAttribute<int64_t>("axis", axis);
+    test1.AddInput<T>("data", input_dims, input_data);
+    test1.AddInput<TIndex>("indices", indices_dims, indices_data);
+    test1.AddInput<T>("updates", indices_dims, updates_data);
+    test1.AddOutput<T>("y", input_dims, output_data);
+    if (std::is_same<T, int8_t>::value) {
+      test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+    } else if (std::is_same<T, MLFloat16>::value) {
+      test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
+    } else {
+      test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});
+    }
+  }
}

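The updated loop still exercises only the default reduction. A test for the new reduction path could follow the same OpTester pattern; the sketch below is illustrative only (not part of this commit) and assumes ONNX ScatterElements-18 "add" semantics:

  // Hypothetical reduction="add" test: data = {1, 2, 3, 4}, indices = {1, 1},
  // updates = {10, 20}, so output[1] = 2 + 10 + 20 = 32.
  onnxruntime::test::OpTester test_add("ScatterElements", 18);
  test_add.AddAttribute<int64_t>("axis", 0);
  test_add.AddAttribute<std::string>("reduction", "add");
  test_add.AddInput<float>("data", {4}, {1.0f, 2.0f, 3.0f, 4.0f});
  test_add.AddInput<int64_t>("indices", {2}, {1, 1});
  test_add.AddInput<float>("updates", {2}, {10.0f, 20.0f});
  test_add.AddOutput<float>("y", {4}, {1.0f, 32.0f, 3.0f, 4.0f});
  test_add.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider});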
