diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index d8a0792209b0f..7fd15188b1868 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1041,7 +1041,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1249,6 +1249,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1264,6 +1265,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -1926,7 +1928,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2134,6 +2136,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2148,6 +2151,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 10c8625b39ef8..4dacceb6e6af7 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -98,6 +98,34 @@ struct FuncAssignment { __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } }; +template +struct FuncAdd { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] += value; } +}; + +template +struct FuncMul { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { 
start_addr[index] *= value; } +}; + +template +struct FuncMax { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + if (start_addr[index] < value) { + start_addr[index] = value; + } + } +}; + +template +struct FuncMin { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + if (start_addr[index] > value) { + start_addr[index] = value; + } + } +}; + template __global__ void _GatherScatterElementsKernel(const T* src_data, const TIndex* indices_data, T* output_data, const int64_t input_dim_along_axis, const int64_t input_stride_along_axis, @@ -238,8 +266,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con template Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data, T* output_data, const GatherScatterElementsArgs& args) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncAssignment()); + if (args.operation == GatherScatterElementsArgs::Operation::NONE) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAssignment()); + } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAdd()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMul()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMax()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMin()); + } 
else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); + } } #define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h index 631d0bf049c6f..7b1c88f1fc1cb 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h @@ -10,6 +10,14 @@ namespace onnxruntime { namespace cuda { struct GatherScatterElementsArgs { + enum class Operation { + NONE, + ADD, + MUL, + MAX, + MIN + }; + int64_t rank; int64_t axis; int64_t input_size; @@ -19,6 +27,9 @@ struct GatherScatterElementsArgs { TArray indices_fdms; TArray indices_strides; int64_t indices_size; + // operation used to combine values associated with the same + // memory location in the output tensor. + Operation operation; }; template diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index e4d145154971e..42a9f50001103 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExecutionProvider, DataTypeImpl::GetTensorType()}), ScatterElements); -ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + 
.TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), @@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector(); CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args); + if (reduction_ == "none") { + args.operation = GatherScatterElementsArgs::Operation::NONE; + } else if (reduction_ == "add") { + args.operation = GatherScatterElementsArgs::Operation::ADD; + } else if (reduction_ == "mul") { + args.operation = GatherScatterElementsArgs::Operation::MUL; + } else if (reduction_ == "min") { + args.operation = GatherScatterElementsArgs::Operation::MIN; + } else if (reduction_ == "max") { + args.operation = GatherScatterElementsArgs::Operation::MAX; + } else { + ORT_THROW("Unsupported reduction type"); + } + // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. 
int dtype = GetElementType(input_tensor->DataType()->Size()); if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h index 3e9e0ce041845..3884b716da308 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel { ScatterElements(const OpKernelInfo& info) : CudaKernel(info) { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); + reduction_ = info.GetAttrOrDefault("reduction", "none"); + + ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" || + reduction_ == "mul" || reduction_ == "max" || + reduction_ == "min", + "Invalid reduction attribute value of ", reduction_); } ~ScatterElements() = default; Status ComputeInternal(OpKernelContext* context) const override; @@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel { struct ComputeImpl; int64_t axis_; + // "reduction" attribute has been defined since opset 13 but + // we never implemented it. Let's try to support it starting + // with opset 18. 
+ std::string reduction_; }; } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 9b44bf400c05e..81dd306f6bff1 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -75,18 +75,20 @@ void RunTest(const std::vector& input_dims, const std::vector& test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } - onnxruntime::test::OpTester test1("ScatterElements", 11); - if (has_axis) test1.AddAttribute("axis", axis); - test1.AddInput("data", input_dims, input_data); - test1.AddInput("indices", indices_dims, indices_data); - test1.AddInput("updates", indices_dims, updates_data); - test1.AddOutput("y", input_dims, output_data); - if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); - } else if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); - } else { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + for (int opset : {11, 18}) { + onnxruntime::test::OpTester test1("ScatterElements", opset); + if (has_axis) test1.AddAttribute("axis", axis); + test1.AddInput("data", input_dims, input_data); + test1.AddInput("indices", indices_dims, indices_data); + test1.AddInput("updates", indices_dims, updates_data); + test1.AddOutput("y", input_dims, output_data); + if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + } else if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + } else { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + } } }