From f3b332890a2fcd79a9dae3841700574957261168 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 18 Jan 2024 16:18:21 -0800 Subject: [PATCH 01/17] Register ScatterElements for opset > 11 --- .../providers/cuda/cuda_execution_provider.cc | 8 +++- .../cuda/tensor/gather_elements_impl.cu | 48 ++++++++++++++++++- .../cuda/tensor/gather_elements_impl.h | 11 +++++ .../providers/cuda/tensor/scatter_elements.cc | 32 ++++++++++++- .../providers/cuda/tensor/scatter_elements.h | 10 ++++ .../providers/cpu/tensor/scatter_op_test.cc | 26 +++++----- 6 files changed, 118 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3fc4ed355a12b..77e682e05a2a4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1046,7 +1046,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1254,6 +1254,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1269,6 +1270,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); @@ -1937,7 +1939,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, @@ -2138,6 +2140,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2159,6 +2162,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 10c8625b39ef8..4dacceb6e6af7 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -98,6 +98,34 @@ struct FuncAssignment { __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } }; +template +struct FuncAdd { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] += value; } +}; + +template +struct FuncMul { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] *= value; } +}; + +template +struct FuncMax { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + if (start_addr[index] < value) { + start_addr[index] = value; + } + } +}; + +template +struct FuncMin { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + if (start_addr[index] > value) { + start_addr[index] = value; + } + } +}; + template __global__ void _GatherScatterElementsKernel(const T* src_data, const TIndex* indices_data, T* output_data, const int64_t input_dim_along_axis, const int64_t input_stride_along_axis, @@ -238,8 +266,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con template Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data, T* output_data, const GatherScatterElementsArgs& args) { - return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, - FuncAssignment()); + if (args.operation == GatherScatterElementsArgs::Operation::NONE) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAssignment()); + } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncAdd()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMul()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMax()); + } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) { + return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, + FuncMin()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); + } } #define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h index 
631d0bf049c6f..7b1c88f1fc1cb 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h @@ -10,6 +10,14 @@ namespace onnxruntime { namespace cuda { struct GatherScatterElementsArgs { + enum class Operation { + NONE, + ADD, + MUL, + MAX, + MIN + }; + int64_t rank; int64_t axis; int64_t input_size; @@ -19,6 +27,9 @@ struct GatherScatterElementsArgs { TArray indices_fdms; TArray indices_strides; int64_t indices_size; + // operation used to combine values associated the same + // memory location in the output tensor. + Operation operation; }; template diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index e4d145154971e..42a9f50001103 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe DataTypeImpl::GetTensorType()}), ScatterElements); -ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), @@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector(); CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args); + if (reduction_ == "none") { + args.operation = GatherScatterElementsArgs::Operation::NONE; + } else if (reduction_ == "add") { + args.operation = GatherScatterElementsArgs::Operation::ADD; + } else if (reduction_ == "mul") { + args.operation = GatherScatterElementsArgs::Operation::MUL; + } else if (reduction_ == "min") { + args.operation = GatherScatterElementsArgs::Operation::MIN; + } else if (reduction_ == "max") { + args.operation = GatherScatterElementsArgs::Operation::MAX; + } else { + ORT_THROW("Unsupported reduction type"); + } + // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. 
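For context on what the new `reduction` dispatch above has to compute: when `reduction` is not "none", every update that lands on the same output element is folded into it with the chosen operator, so duplicate indices accumulate instead of overwriting. A minimal host-side sketch of the per-element semantics for the 1-D case (float data, int64_t indices; illustrative only, not ONNX Runtime code):

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Reference semantics for ScatterElements along a single axis:
// each update is combined with the element selected by its index.
std::vector<float> ScatterElements1D(std::vector<float> data,
                                     const std::vector<int64_t>& indices,
                                     const std::vector<float>& updates,
                                     const std::string& reduction) {
  for (size_t i = 0; i < indices.size(); ++i) {
    float& out = data[static_cast<size_t>(indices[i])];
    const float upd = updates[i];
    if (reduction == "none") {
      out = upd;  // plain assignment; duplicate indices -> last write wins
    } else if (reduction == "add") {
      out += upd;
    } else if (reduction == "mul") {
      out *= upd;
    } else if (reduction == "max") {
      out = std::max(out, upd);
    } else if (reduction == "min") {
      out = std::min(out, upd);
    } else {
      throw std::invalid_argument("unsupported reduction: " + reduction);
    }
  }
  return data;
}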
int dtype = GetElementType(input_tensor->DataType()->Size()); if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h index 3e9e0ce041845..3884b716da308 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel { ScatterElements(const OpKernelInfo& info) : CudaKernel(info) { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); + reduction_ = info.GetAttrOrDefault("reduction", "none"); + + ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" || + reduction_ == "mul" || reduction_ == "max" || + reduction_ == "min", + "Invalid reduction attribute value of ", reduction_); } ~ScatterElements() = default; Status ComputeInternal(OpKernelContext* context) const override; @@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel { struct ComputeImpl; int64_t axis_; + // "reduction" attribute has been defined since opset 13 but + // we never implemented it. Let's try to support them starting + // with opset 18. + std::string reduction_; }; } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 9b44bf400c05e..81dd306f6bff1 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -75,18 +75,20 @@ void RunTest(const std::vector& input_dims, const std::vector& test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } - onnxruntime::test::OpTester test1("ScatterElements", 11); - if (has_axis) test1.AddAttribute("axis", axis); - test1.AddInput("data", input_dims, input_data); - test1.AddInput("indices", indices_dims, indices_data); - test1.AddInput("updates", indices_dims, updates_data); - test1.AddOutput("y", input_dims, output_data); - if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); - } else if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); - } else { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + for (int opset : {11, 18}) { + onnxruntime::test::OpTester test1("ScatterElements", opset); + if (has_axis) test1.AddAttribute("axis", axis); + test1.AddInput("data", input_dims, input_data); + test1.AddInput("indices", indices_dims, indices_data); + test1.AddInput("updates", indices_dims, updates_data); + test1.AddOutput("y", input_dims, output_data); + if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + } else if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + } else { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + } } } From 1e9a0a72246880320bfe6581e68761c121354d9e Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 25 Jan 2024 19:47:34 -0800 Subject: [PATCH 02/17] Fix atomic --- .../core/providers/cuda/atomic/common.cuh | 59 ++++++++++ .../cuda/tensor/gather_elements_impl.cu | 4 +- .../providers/cpu/tensor/scatter_op_test.cc | 106 +++++++++++++++--- 3 files changed, 153 
insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 14fa2d0706f73..c8d0b82a8502f 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -25,6 +25,16 @@ namespace onnxruntime { namespace cuda { +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t val) { + int* address_as_int = reinterpret_cast(address); + int old = *address_as_int, assumed; + do { + assumed = old; + old = atomicCAS(address_as_int, assumed, + static_cast(val) + assumed); + } while (assumed != old); +} + __device__ __forceinline__ void atomic_add(float *address, float value) { atomicAdd(address, value); } @@ -122,5 +132,54 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } +__device__ __forceinline__ void atomic_mul(half* address, half val) { + unsigned short int* address_as_short = reinterpret_cast(address); + unsigned short int old = *address_as_short, assumed; + do { + assumed = old; + old = atomicCAS(address_as_short, assumed, + __half_as_short(val * __short_as_half(assumed))); + } while (assumed != old); +} + +__device__ __forceinline__ void atomic_mul(float* address, float val) { + int* address_as_int = reinterpret_cast(address); + int old = *address_as_int, assumed; + do { + assumed = old; + old = atomicCAS(address_as_int, assumed, + __float_as_int(val * __int_as_float(assumed))); + } while (assumed != old); +} + +__device__ __forceinline__ void atomic_mul(double* address, double val) { + unsigned long long int* address_as_long_long = reinterpret_cast(address); + unsigned long long int old = *address_as_long_long, assumed; + do { + assumed = old; + old = atomicCAS(address_as_long_long, assumed, + __double_as_longlong(val * __longlong_as_double(assumed))); + } while (assumed != old); +} + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. 
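The float/double/half `atomic_mul` helpers added above all use the same pattern: reinterpret the value's bits as a same-width integer, compute the candidate result, and publish it with `atomicCAS`, retrying when another thread changed the word in between. A standalone CUDA sketch of that retry loop for float (illustrative; the function name is made up and is not part of this patch):

__device__ void AtomicMulFloatSketch(float* address, float value) {
  int* address_as_int = reinterpret_cast<int*>(address);
  int assumed;
  int observed = *address_as_int;
  do {
    assumed = observed;  // the word our computation is based on
    const float proposed = value * __int_as_float(assumed);
    // Publish only if *address still holds `assumed`; otherwise get the fresh word back.
    observed = atomicCAS(address_as_int, assumed, __float_as_int(proposed));
  } while (assumed != observed);  // lost the race: recompute with the fresh value
}

The `atomic_add(int8_t*)` introduced in this commit reinterprets the byte pointer as an `int*`; a later commit in this series replaces it with the 4-byte-aligned, byte-masked CAS helper.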
+__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) { + size_t offset = (size_t)address & 3; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = offset * 8; \ + uint32_t old_byte; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_byte = (old >> shift) & 0xff; \ + newval = static_cast(val * static_cast(old_byte)); \ + newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 4dacceb6e6af7..99c4bc671d000 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -100,12 +100,12 @@ struct FuncAssignment { template struct FuncAdd { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] += value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_add(start_addr + index, value); } }; template struct FuncMul { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] *= value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_mul(start_addr + index, value); } }; template diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 81dd306f6bff1..8f93fd78e5537 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -75,20 +75,18 @@ void RunTest(const std::vector& input_dims, const std::vector& test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } - for (int opset : {11, 18}) { - onnxruntime::test::OpTester test1("ScatterElements", opset); - if (has_axis) test1.AddAttribute("axis", axis); - test1.AddInput("data", input_dims, input_data); - test1.AddInput("indices", indices_dims, indices_data); - test1.AddInput("updates", indices_dims, updates_data); - test1.AddOutput("y", input_dims, output_data); - if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); - } else if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); - } else { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); - } + onnxruntime::test::OpTester test1("ScatterElements", 11); + if (has_axis) test1.AddAttribute("axis", axis); + test1.AddInput("data", input_dims, input_data); + test1.AddInput("indices", indices_dims, indices_data); + test1.AddInput("updates", indices_dims, updates_data); + test1.AddOutput("y", input_dims, output_data); + if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + } else if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + } else { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } } @@ -304,5 +302,85 @@ TEST(Scatter, BoolInputWithAxis) { scatter_bool_with_axis_tests("ScatterElements", 
11); } +TEST(ScatterElements, AddReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "add"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, AddReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "add"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MulReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "mul"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MulReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "mul"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MinReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + } // namespace test } // namespace onnxruntime From 5274efe16337ebe33d7b584df2f7b0e3b530fd32 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 
Jan 2024 09:59:34 -0800 Subject: [PATCH 03/17] documents --- .../core/providers/cuda/atomic/common.cuh | 112 +++++++++++++++--- 1 file changed, 97 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index c8d0b82a8502f..a7520558922ed 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -164,21 +164,103 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) { // Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) { - size_t offset = (size_t)address & 3; \ - uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ - uint32_t old = *address_as_ui; \ - uint32_t shift = offset * 8; \ - uint32_t old_byte; \ - uint32_t newval; \ - uint32_t assumed; \ - \ - do { \ - assumed = old; \ - old_byte = (old >> shift) & 0xff; \ - newval = static_cast(val * static_cast(old_byte)); \ - newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ - old = atomicCAS(address_as_ui, assumed, newval); \ - } while (assumed != old); \ + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... + // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. 
+ // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & 0xff; + // Use + for atomic addition, * for atomic multiplication, / for atomic division. + newval = static_cast(val * static_cast(old_byte)); + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(0x000000ff << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); } } // namespace cuda From 469245e0b62631f9bc93873605b4e54ca0018016 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 12:48:13 -0800 Subject: [PATCH 04/17] Implement int8_t using template function --- .../core/providers/cuda/atomic/common.cuh | 70 +++++++++++++++---- 1 file changed, 57 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index a7520558922ed..955c583bf8b54 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -25,16 +25,6 @@ namespace onnxruntime { namespace cuda { -__device__ __forceinline__ void atomic_add(int8_t* address, int8_t val) { - int* address_as_int = reinterpret_cast(address); - int old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - static_cast(val) + assumed); - } while (assumed != old); -} - __device__ __forceinline__ void atomic_add(float *address, float value) { atomicAdd(address, value); } @@ -162,8 +152,18 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) { } while (assumed != old); } -// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. -__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) { +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// +// E.g., Making ByteLargeType int8_t and BinaryFunc +// struct AddFunc { +// __device__ __forceinline__ ByteLargeType operator()(ByteLargeType a, ByteLargeType b) const { +// return a + b; +// } +// makes this function atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* address, ByteLargeType val) { // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -235,7 +235,7 @@ __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) { // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & 0xff; // Use + for atomic addition, * for atomic multiplication, / for atomic division. 
- newval = static_cast(val * static_cast(old_byte)); + newval = static_cast(val * static_cast(old_byte)); // Journey of a 32-bit value (cont'd): // // old @@ -263,5 +263,49 @@ __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) { } while (assumed != old); } +struct AddFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +struct MulFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a * b; + } +}; + +struct MaxFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a > b ? a : b; + } +}; + +struct MinFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? a : b; + } +}; + +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value); +} + +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value); +} + +__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value); +} + +__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value); +} + } // namespace cuda } // namespace onnxruntime From ce51bc6278fbe187f14fe373d24907bad0ddd9fd Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 12:49:57 -0800 Subject: [PATCH 05/17] Rename a type --- onnxruntime/core/providers/cuda/atomic/common.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 955c583bf8b54..bb15882745bb3 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -156,14 +156,14 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) { // // This function compute 8-bit atomic binary operation using 32-bit atomicCAS. // -// E.g., Making ByteLargeType int8_t and BinaryFunc +// E.g., Making OneByteType int8_t and BinaryFunc // struct AddFunc { -// __device__ __forceinline__ ByteLargeType operator()(ByteLargeType a, ByteLargeType b) const { +// __device__ __forceinline__ OneByteType operator()(OneByteType a, OneByteType b) const { // return a + b; // } // makes this function atomic_add for int8_t. -template -__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* address, ByteLargeType val) { +template +__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val) { // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -235,7 +235,7 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* a // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & 0xff; // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(val * static_cast(old_byte)); + newval = static_cast(val * static_cast(old_byte)); // Journey of a 32-bit value (cont'd): // // old @@ -280,14 +280,14 @@ struct MulFunc { struct MaxFunc { template __device__ __forceinline__ T operator()(T a, T b) const { - return a > b ? a : b; + return b > a ? 
b : a; } }; struct MinFunc { template __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? a : b; + return b < a ? b : a; } }; From 9ed1c5e74f1a11f1d469a274993427bdf3c68c21 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 15:41:28 -0800 Subject: [PATCH 06/17] General impl for atomic binary ops --- .../core/providers/cuda/atomic/common.cuh | 158 +++++++++++++----- 1 file changed, 119 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index bb15882745bb3..4f5223a59543a 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -122,39 +122,11 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } -__device__ __forceinline__ void atomic_mul(half* address, half val) { - unsigned short int* address_as_short = reinterpret_cast(address); - unsigned short int old = *address_as_short, assumed; - do { - assumed = old; - old = atomicCAS(address_as_short, assumed, - __half_as_short(val * __short_as_half(assumed))); - } while (assumed != old); -} - -__device__ __forceinline__ void atomic_mul(float* address, float val) { - int* address_as_int = reinterpret_cast(address); - int old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - __float_as_int(val * __int_as_float(assumed))); - } while (assumed != old); -} - -__device__ __forceinline__ void atomic_mul(double* address, double val) { - unsigned long long int* address_as_long_long = reinterpret_cast(address); - unsigned long long int old = *address_as_long_long, assumed; - do { - assumed = old; - old = atomicCAS(address_as_long_long, assumed, - __double_as_longlong(val * __longlong_as_double(assumed))); - } while (assumed != old); -} - // Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. // // This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). // // E.g., Making OneByteType int8_t and BinaryFunc // struct AddFunc { @@ -163,7 +135,9 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) { // } // makes this function atomic_add for int8_t. template -__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val) { +__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val, BinaryFunc func) { + static_assert(sizeof(OneByteType) == 1, "OneByteType must be 1 byte for the following bit-level manipulations."); + // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -235,7 +209,7 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & 0xff; // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(val * static_cast(old_byte)); + newval = static_cast(func(val, static_cast(old_byte))); // Journey of a 32-bit value (cont'd): // // old @@ -263,6 +237,84 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add } while (assumed != old); } +// Disable default template instantiation. 
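The shift/mask bookkeeping documented in the hunk above can be checked without a GPU: the same byte surgery on a 32-bit word works in plain C++. A host-side sketch (illustrative only; names and the multiply operator are chosen just for the demonstration):

#include <cstdint>
#include <cstdio>

// Replace the byte at `offset` (0..3) inside a 32-bit word with old_byte * val,
// mirroring the shift/mask steps used by the CUDA helper above.
uint32_t ReplaceByte(uint32_t word, size_t offset, uint8_t val) {
  const uint32_t shift = static_cast<uint32_t>(offset) * 8;
  const uint32_t old_byte = (word >> shift) & 0xffu;               // select the byte
  const uint32_t new_byte = static_cast<uint8_t>(val * old_byte);  // the binary op
  return (word & ~(0xffu << shift)) | (new_byte << shift);         // splice it back in
}

int main() {
  const uint32_t word = 0x04030201u;            // bytes 3..0 hold 4, 3, 2, 1
  const uint32_t out = ReplaceByte(word, 2, 5);  // byte 2 (value 3) times 5 -> 15 = 0x0f
  std::printf("%08x\n", static_cast<unsigned>(out));  // prints 040f0201
  return 0;
}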
+// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; +}; + +template<> +class AtomicCasType { + public: + using type = int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; +}; + +// It accumulates `val` into the `address` using the `func`. +// This function function is thread-safe (i.e., atomic). +template +__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { + ValueType observed = *address, new_value; + using CasType = typename AtomicCasType::type; + static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS."); + auto address_as_cas_type = reinterpret_cast(address); + do { + // Compute expected new value. + new_value = func(observed, val); + + // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. + // 4 + // 8 + auto observed_as_cas_type = *reinterpret_cast(&observed); + auto new_value_as_cas_type = *reinterpret_cast(&new_value); + + // Call atomicCAS as if the 2-byte type variables are all unsigned short int. + // 4 unsigned int (or int) + // 8 unsigned long long int + auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + + // Cast the freshly observed value in memory back to the TwoByteType. + observed = *reinterpret_cast(&cas_observed_as_cas_type); + + // Two cases: + // 1. compare-and-swap success + // a. `address` holds `new_value` + // b. `observed` becomes the new value after the assignment. + // Thus, the following `observed != new_value` is false, + // and the loop terminates. + // 2. compare-and-swap fails + // a. `address` holds a value different from `observed`, thus, + // the `new_value` is stale. + // b. `observed` becomes the fresh value observed in `address`. + // Thus, the following (observed != new_value) is true, + // and the loop continues. In the next iteration, the + // `new_value` is computed again using the fresh `observed`. 
+ } while (observed != new_value); +} + struct AddFunc { template __device__ __forceinline__ T operator()(T a, T b) const { @@ -292,20 +344,48 @@ struct MinFunc { }; __device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, AddFunc()); } - __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, MulFunc()); } - __device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, MaxFunc()); } - __device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(half* address, half value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(half* address, half value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + atomic_binary_func(address, value, MaxFunc()); } +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + } // namespace cuda } // namespace onnxruntime From 9a59ff4cd1eb57d4345dd6c57087c99847c61727 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 16:28:55 -0800 Subject: [PATCH 07/17] Changes: 1. Add more tests for ScatterElements with max/min reduction 2. 
Copy atomic binary function to RCOM --- .../core/providers/cpu/tensor/scatter.cc | 14 - .../core/providers/cuda/atomic/common.cuh | 2 +- .../cuda/tensor/gather_elements_impl.cu | 20 +- .../core/providers/rocm/atomic/common.cuh | 264 ++++++++++++++++++ .../providers/cpu/tensor/scatter_op_test.cc | 56 +++- 5 files changed, 330 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 8844b7e7a26c4..c7a2005924836 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -198,13 +198,6 @@ struct Func_Min { } }; -template <> -struct Func_Min { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); - } -}; - template <> struct Func_Min { void operator()(BFloat16*, const BFloat16*) const { @@ -233,13 +226,6 @@ struct Func_Max { } }; -template <> -struct Func_Max { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); - } -}; - template <> struct Func_Max { void operator()(BFloat16*, const BFloat16*) const { diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 4f5223a59543a..cc57c2ae16181 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -274,7 +274,7 @@ class AtomicCasType { }; // It accumulates `val` into the `address` using the `func`. -// This function function is thread-safe (i.e., atomic). +// This function is thread-safe (i.e., atomic). 
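`atomic_binary_func` folds the earlier per-type helpers into one functor-driven compare-and-swap loop. The shape of that loop is easy to see on the host with `std::atomic`, whose `compare_exchange_weak` hands back the freshly observed value on failure, much like `atomicCAS` returning the current word (host-side analogue, illustrative only; not ORT code):

#include <atomic>
#include <cstdio>

// Host-side analogue of a functor-driven CAS loop: keep proposing
// func(observed, val) until no other thread has raced us.
template <typename T, typename BinaryFunc>
void AtomicApply(std::atomic<T>& cell, T val, BinaryFunc func) {
  T observed = cell.load();
  T proposed = func(observed, val);
  // On failure, compare_exchange_weak reloads `observed` with the current value.
  while (!cell.compare_exchange_weak(observed, proposed)) {
    proposed = func(observed, val);
  }
}

int main() {
  std::atomic<float> x{3.0f};
  AtomicApply(x, 4.0f, [](float a, float b) { return a > b ? a : b; });  // max
  std::printf("%f\n", x.load());  // 4.000000
  return 0;
}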
template __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { ValueType observed = *address, new_value; diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 99c4bc671d000..b710e8a1b48c2 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -95,34 +95,36 @@ struct OffsetCalculatorFor2D { template struct FuncAssignment { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + start_addr[index] = value; + } }; template struct FuncAdd { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_add(start_addr + index, value); } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_add(start_addr + index, value); + } }; template struct FuncMul { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_mul(start_addr + index, value); } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_mul(start_addr + index, value); + } }; template struct FuncMax { __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { - if (start_addr[index] < value) { - start_addr[index] = value; - } + atomic_max(start_addr + index, value); } }; template struct FuncMin { __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { - if (start_addr[index] > value) { - start_addr[index] = value; - } + atomic_min(start_addr + index, value); } }; diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh index 4e235702028c6..5f51034231164 100644 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ b/onnxruntime/core/providers/rocm/atomic/common.cuh @@ -59,5 +59,269 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz atomic_add(start_addr + index, value); } +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). +// +// E.g., Making OneByteType int8_t and BinaryFunc +// struct AddFunc { +// __device__ __forceinline__ OneByteType operator()(OneByteType a, OneByteType b) const { +// return a + b; +// } +// makes this function atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val, BinaryFunc func) { + static_assert(sizeof(OneByteType) == 1, "OneByteType must be 1 byte for the following bit-level manipulations."); + + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. 
It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... + // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & 0xff; + // Use + for atomic addition, * for atomic multiplication, / for atomic division. + newval = static_cast(func(val, static_cast(old_byte))); + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(0x000000ff << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); +} + +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. 
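`AtomicCasType` is essentially a traits table: for each value type it names an unsigned integer type of the same width to hand to `atomicCAS` (and, in the later CUDA revision, the bit mask covering that width). A host-side sketch of the same mapping with compile-time checks (illustrative only, not the definitions used in this patch):

#include <cstdint>
#include <type_traits>

// Map a value type to a same-width unsigned integer carrier, the way the
// AtomicCasType specializations pick the argument type for atomicCAS.
template <typename T, typename = void>
struct CasCarrier;  // no default: unsupported sizes fail to compile

template <typename T>
struct CasCarrier<T, std::enable_if_t<sizeof(T) == 2>> { using type = uint16_t; };
template <typename T>
struct CasCarrier<T, std::enable_if_t<sizeof(T) == 4>> { using type = uint32_t; };
template <typename T>
struct CasCarrier<T, std::enable_if_t<sizeof(T) == 8>> { using type = uint64_t; };

static_assert(std::is_same_v<CasCarrier<float>::type, uint32_t>, "float is CAS'd as a 32-bit word");
static_assert(std::is_same_v<CasCarrier<double>::type, uint64_t>, "double is CAS'd as a 64-bit word");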
+template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; +}; + +template<> +class AtomicCasType { + public: + using type = int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; +}; + +// It accumulates `val` into the `address` using the `func`. +// This function is thread-safe (i.e., atomic). +template +__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { + ValueType observed = *address, new_value; + using CasType = typename AtomicCasType::type; + static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS."); + auto address_as_cas_type = reinterpret_cast(address); + do { + // Compute expected new value. + new_value = func(observed, val); + + // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. + // 4 + // 8 + auto observed_as_cas_type = *reinterpret_cast(&observed); + auto new_value_as_cas_type = *reinterpret_cast(&new_value); + + // Call atomicCAS as if the 2-byte type variables are all unsigned short int. + // 4 unsigned int (or int) + // 8 unsigned long long int + auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + + // Cast the freshly observed value in memory back to the TwoByteType. + observed = *reinterpret_cast(&cas_observed_as_cas_type); + + // Two cases: + // 1. compare-and-swap success + // a. `address` holds `new_value` + // b. `observed` becomes the new value after the assignment. + // Thus, the following `observed != new_value` is false, + // and the loop terminates. + // 2. compare-and-swap fails + // a. `address` holds a value different from `observed`, thus, + // the `new_value` is stale. + // b. `observed` becomes the fresh value observed in `address`. + // Thus, the following (observed != new_value) is true, + // and the loop continues. In the next iteration, the + // `new_value` is computed again using the fresh `observed`. + } while (observed != new_value); +} + +struct AddFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +struct MulFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return a * b; + } +}; + +struct MaxFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b > a ? b : a; + } +}; + +struct MinFunc { + template + __device__ __forceinline__ T operator()(T a, T b) const { + return b < a ? 
b : a; + } +}; + +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value, AddFunc()); +} +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { + atomic_byte_func_with_4byte_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(half* address, half value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(half* address, half value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 8f93fd78e5537..94a02c08281a8 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -356,7 +356,20 @@ TEST(ScatterElements, MulReductionAxis1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); } -TEST(ScatterElements, MaxReduction) { +TEST(ScatterElements, MaxReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f})); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); + test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f})); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_Float) { OpTester test("ScatterElements", 18); test.AddAttribute("axis", 0); test.AddAttribute("reduction", "max"); @@ -369,7 +382,33 @@ TEST(ScatterElements, MaxReduction) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); } -TEST(ScatterElements, MinReduction) { +TEST(ScatterElements, MaxReduction_Double) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 
5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f})); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); + test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f})); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_Float) { OpTester test("ScatterElements", 18); test.AddAttribute("axis", 0); test.AddAttribute("reduction", "min"); @@ -382,5 +421,18 @@ TEST(ScatterElements, MinReduction) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); } +TEST(ScatterElements, MinReduction_Double) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + } // namespace test } // namespace onnxruntime From 42eac4f2d607a098b7021435cb50dd93e412bc29 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 19:18:20 -0800 Subject: [PATCH 08/17] Implement for old cuda --- .../core/providers/cuda/atomic/common.cuh | 174 ++++++++++++------ 1 file changed, 119 insertions(+), 55 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index cc57c2ae16181..f3ec49b6e6458 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -122,22 +122,71 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + // Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. // // This function compute 8-bit atomic binary operation using 32-bit atomicCAS. // It accumulate `val` into the `address` using the `func`. // The accumulation is atomic (i.e., thread-safe). 
// -// E.g., Making OneByteType int8_t and BinaryFunc -// struct AddFunc { -// __device__ __forceinline__ OneByteType operator()(OneByteType a, OneByteType b) const { -// return a + b; -// } -// makes this function atomic_add for int8_t. -template -__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val, BinaryFunc func) { - static_assert(sizeof(OneByteType) == 1, "OneByteType must be 1 byte for the following bit-level manipulations."); - +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, "ValueType must be 1 byte for the following bit-level manipulations."); // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -207,9 +256,9 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add // V // // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... - old_byte = (old >> shift) & 0xff; + old_byte = (old >> shift) & AtomicCasType::mask; // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(func(val, static_cast(old_byte))); + newval = static_cast(func(val, static_cast(old_byte))); // Journey of a 32-bit value (cont'd): // // old @@ -232,56 +281,23 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add // // (old & ~(0x000000ff << shift)) | (newval << shift) // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... - newval = (old & ~(0x000000ff << shift)) | (newval << shift); + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); old = atomicCAS(address_as_ui, assumed, newval); } while (assumed != old); } -// Disable default template instantiation. -// For every type T, we need to define a specialization -// to select the right type for calling atomicCAS. -template -class AtomicCasType; - -template<> -class AtomicCasType { - public: - using type = unsigned short int; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned int; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned long long int; -}; - -template<> -class AtomicCasType { - public: - using type = int; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned long long int; -}; - // It accumulates `val` into the `address` using the `func`. // This function is thread-safe (i.e., atomic). template __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { - ValueType observed = *address, new_value; + ValueType observed = *address, assumed, new_value; using CasType = typename AtomicCasType::type; static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS."); auto address_as_cas_type = reinterpret_cast(address); do { + // Record the value used to compute new value. + assumed = observed; + // Compute expected new value. 
new_value = func(observed, val); @@ -312,7 +328,43 @@ __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType // Thus, the following (observed != new_value) is true, // and the loop continues. In the next iteration, the // `new_value` is computed again using the fresh `observed`. - } while (observed != new_value); + } while (observed != assumed); +} + +// This function is similar to atomic_binary_func. +// This function uses unsigned int to contain `val` and +// call atomicCAS. +// `ValueType` can be, for example, int8_t, int16_t, and half. +// Comparing with atomic_binary_func, this function +// adds several bit-level manipulations to +// treat `val` as an unsigned int. +template +__device__ __forceinline__ void atomic_binary_func_with_unsigned_int_cas(ValueType* address, ValueType val, BinaryFunc func) { + using CasType = unsigned int; + + static_assert(sizeof(ValueType) == 8 | sizeof(ValueType) == 16 | sizeof(ValueType) == 32, "ValueType and CasType must have the same size for calling atomicCAS."); + // How many bytes the `address` is higher than + // the closest 4-byte aligned address. + const size_t distance_to_aligned_address_in_byte = (size_t)address & 3; + // Compute the closest 4-byte aligned address lower than `address`. + auto aligned_address_as_cas_type = reinterpret_cast((char*)address - distance_to_aligned_address_in_byte); + // How many bits the `address` is higher than + // the closest 4-byte aligned address. + const size_t distance_to_aligned_address_in_bit = distance_to_aligned_address_in_byte * 8; + + CasType observed_as_cas_type , new_value_as_cas_type, cas_observed_as_cas_type; + do { + observed_as_cas_type = *aligned_address_as_cas_type; + // Extract ValueType from the CasType's binary representation. + ValueType observed = static_cast((observed_as_cas_type >> distance_to_aligned_address_in_bit) & AtomicCasType::mask); + // Compute new value in ValueType world. + ValueType new_value = func(observed, val); + // Prepare new value in CasType world. + CasType clean_observed_as_cas_type = observed_as_cas_type & ~(AtomicCasType::mask << distance_to_aligned_address_in_bit); + // Complete new value in CasType world. 
+ new_value_as_cas_type = clean_observed_as_cas_type | (static_cast(new_value) << distance_to_aligned_address_in_bit); + cas_observed_as_cas_type = atomicCAS(aligned_address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); + } while (observed_as_cas_type != cas_observed_as_cas_type); } struct AddFunc { @@ -344,26 +396,38 @@ struct MinFunc { }; __device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, AddFunc()); + atomic_byte_func_with_unit32_cas(address, value, AddFunc()); } __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, MulFunc()); + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); } __device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, MaxFunc()); + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); } __device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, MinFunc()); + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); } __device__ __forceinline__ void atomic_mul(half* address, half value) { +#if __CUDA_ARCH__ >= 600 atomic_binary_func(address, value, MulFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +#endif } __device__ __forceinline__ void atomic_max(half* address, half value) { +#if __CUDA_ARCH__ >= 600 atomic_binary_func(address, value, MaxFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +#endif } __device__ __forceinline__ void atomic_min(half* address, half value) { +#if __CUDA_ARCH__ >= 600 atomic_binary_func(address, value, MinFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +#endif } __device__ __forceinline__ void atomic_mul(float* address, float value) { From 736391af35c48b0b2b17dc441e55749ee4999fe9 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 19:18:38 -0800 Subject: [PATCH 09/17] Only keep one impl --- .../core/providers/cuda/atomic/common.cuh | 36 ------------------- 1 file changed, 36 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index f3ec49b6e6458..64f1ad5d29d78 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -331,42 +331,6 @@ __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType } while (observed != assumed); } -// This function is similar to atomic_binary_func. -// This function uses unsigned int to contain `val` and -// call atomicCAS. -// `ValueType` can be, for example, int8_t, int16_t, and half. -// Comparing with atomic_binary_func, this function -// adds several bit-level manipulations to -// treat `val` as an unsigned int. -template -__device__ __forceinline__ void atomic_binary_func_with_unsigned_int_cas(ValueType* address, ValueType val, BinaryFunc func) { - using CasType = unsigned int; - - static_assert(sizeof(ValueType) == 8 | sizeof(ValueType) == 16 | sizeof(ValueType) == 32, "ValueType and CasType must have the same size for calling atomicCAS."); - // How many bytes the `address` is higher than - // the closest 4-byte aligned address. - const size_t distance_to_aligned_address_in_byte = (size_t)address & 3; - // Compute the closest 4-byte aligned address lower than `address`. 
- auto aligned_address_as_cas_type = reinterpret_cast((char*)address - distance_to_aligned_address_in_byte); - // How many bits the `address` is higher than - // the closest 4-byte aligned address. - const size_t distance_to_aligned_address_in_bit = distance_to_aligned_address_in_byte * 8; - - CasType observed_as_cas_type , new_value_as_cas_type, cas_observed_as_cas_type; - do { - observed_as_cas_type = *aligned_address_as_cas_type; - // Extract ValueType from the CasType's binary representation. - ValueType observed = static_cast((observed_as_cas_type >> distance_to_aligned_address_in_bit) & AtomicCasType::mask); - // Compute new value in ValueType world. - ValueType new_value = func(observed, val); - // Prepare new value in CasType world. - CasType clean_observed_as_cas_type = observed_as_cas_type & ~(AtomicCasType::mask << distance_to_aligned_address_in_bit); - // Complete new value in CasType world. - new_value_as_cas_type = clean_observed_as_cas_type | (static_cast(new_value) << distance_to_aligned_address_in_bit); - cas_observed_as_cas_type = atomicCAS(aligned_address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); - } while (observed_as_cas_type != cas_observed_as_cas_type); -} - struct AddFunc { template __device__ __forceinline__ T operator()(T a, T b) const { From c8c97a83867197d9ac6d05dfebccd61acd14457b Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 19:19:27 -0800 Subject: [PATCH 10/17] Copy change to ROCM --- .../core/providers/rocm/atomic/common.cuh | 139 +++++++++++------- 1 file changed, 84 insertions(+), 55 deletions(-) diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh index 5f51034231164..d24250b3c1dda 100644 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ b/onnxruntime/core/providers/rocm/atomic/common.cuh @@ -59,22 +59,71 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz atomic_add(start_addr + index, value); } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + // Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. // // This function compute 8-bit atomic binary operation using 32-bit atomicCAS. // It accumulate `val` into the `address` using the `func`. // The accumulation is atomic (i.e., thread-safe). // -// E.g., Making OneByteType int8_t and BinaryFunc -// struct AddFunc { -// __device__ __forceinline__ OneByteType operator()(OneByteType a, OneByteType b) const { -// return a + b; -// } -// makes this function atomic_add for int8_t. 
-template -__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val, BinaryFunc func) { - static_assert(sizeof(OneByteType) == 1, "OneByteType must be 1 byte for the following bit-level manipulations."); - +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, "ValueType must be 1 byte for the following bit-level manipulations."); // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -144,9 +193,9 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add // V // // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... - old_byte = (old >> shift) & 0xff; + old_byte = (old >> shift) & AtomicCasType::mask; // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(func(val, static_cast(old_byte))); + newval = static_cast(func(val, static_cast(old_byte))); // Journey of a 32-bit value (cont'd): // // old @@ -169,56 +218,23 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add // // (old & ~(0x000000ff << shift)) | (newval << shift) // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... - newval = (old & ~(0x000000ff << shift)) | (newval << shift); + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); old = atomicCAS(address_as_ui, assumed, newval); } while (assumed != old); } -// Disable default template instantiation. -// For every type T, we need to define a specialization -// to select the right type for calling atomicCAS. -template -class AtomicCasType; - -template<> -class AtomicCasType { - public: - using type = unsigned short int; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned int; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned long long int; -}; - -template<> -class AtomicCasType { - public: - using type = int; -}; - -template<> -class AtomicCasType { - public: - using type = unsigned long long int; -}; - // It accumulates `val` into the `address` using the `func`. // This function is thread-safe (i.e., atomic). template __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { - ValueType observed = *address, new_value; + ValueType observed = *address, assumed, new_value; using CasType = typename AtomicCasType::type; static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS."); auto address_as_cas_type = reinterpret_cast(address); do { + // Record the value used to compute new value. + assumed = observed; + // Compute expected new value. new_value = func(observed, val); @@ -249,7 +265,7 @@ __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType // Thus, the following (observed != new_value) is true, // and the loop continues. In the next iteration, the // `new_value` is computed again using the fresh `observed`. 
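The retry loop these comments describe can also be written as a standalone sketch for a type whose width already matches a native atomicCAS; float multiply is used here and the function name is illustrative only. The loop exits only when the value returned by atomicCAS equals the value that was assumed, which is exactly the condition the hunk below switches to.

__device__ void atomic_mul_float_sketch(float* address, float val) {
  unsigned int* address_as_u32 = reinterpret_cast<unsigned int*>(address);
  unsigned int observed = *address_as_u32, assumed;
  do {
    assumed = observed;                                 // value we expect to still be in memory
    float new_value = __uint_as_float(assumed) * val;   // compute the update from the snapshot
    // atomicCAS writes new_value only if memory still holds `assumed` and returns
    // what was actually there; a mismatch means another thread won the race, so retry.
    observed = atomicCAS(address_as_u32, assumed, __float_as_uint(new_value));
  } while (observed != assumed);
}

Doing the final comparison on the raw 32-bit pattern, as in this sketch, also sidesteps the NaN != NaN subtlety that a value-level comparison can run into.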
- } while (observed != new_value); + } while (observed != assumed); } struct AddFunc { @@ -281,26 +297,38 @@ struct MinFunc { }; __device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, AddFunc()); + atomic_byte_func_with_unit32_cas(address, value, AddFunc()); } __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, MulFunc()); + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); } __device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, MaxFunc()); + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); } __device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value, MinFunc()); + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); } __device__ __forceinline__ void atomic_mul(half* address, half value) { +#if __CUDA_ARCH__ >= 600 atomic_binary_func(address, value, MulFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +#endif } __device__ __forceinline__ void atomic_max(half* address, half value) { +#if __CUDA_ARCH__ >= 600 atomic_binary_func(address, value, MaxFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +#endif } __device__ __forceinline__ void atomic_min(half* address, half value) { +#if __CUDA_ARCH__ >= 600 atomic_binary_func(address, value, MinFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +#endif } __device__ __forceinline__ void atomic_mul(float* address, float value) { @@ -323,5 +351,6 @@ __device__ __forceinline__ void atomic_min(double* address, double value) { atomic_binary_func(address, value, MinFunc()); } + } // namespace rocm } // namespace onnxruntime From a37393e30f903b9d569ece6f08bc47d0448a8485 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 21:27:38 -0800 Subject: [PATCH 11/17] Try fix support version of atomicCAS --- onnxruntime/core/providers/cuda/atomic/common.cuh | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 64f1ad5d29d78..0e825129d445a 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -186,7 +186,9 @@ class AtomicCasType { // This function becomes atomic_add for int8_t. template __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { - static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, "ValueType must be 1 byte for the following bit-level manipulations."); + // Assert to ensure the following bit-wise manipulation is correct. + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); // Number of bytes to the lower 4-byte aligned address. 
// If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -292,7 +294,8 @@ template __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { ValueType observed = *address, assumed, new_value; using CasType = typename AtomicCasType::type; - static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS."); + static_assert(sizeof(ValueType) == sizeof(CasType), + "ValueType and CasType must have the same size for calling atomicCAS."); auto address_as_cas_type = reinterpret_cast(address); do { // Record the value used to compute new value. @@ -373,21 +376,21 @@ __device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { } __device__ __forceinline__ void atomic_mul(half* address, half value) { -#if __CUDA_ARCH__ >= 600 +#if __CUDA_ARCH__ >= 700 atomic_binary_func(address, value, MulFunc()); #else atomic_byte_func_with_unit32_cas(address, value, MulFunc()); #endif } __device__ __forceinline__ void atomic_max(half* address, half value) { -#if __CUDA_ARCH__ >= 600 +#if __CUDA_ARCH__ >= 700 atomic_binary_func(address, value, MaxFunc()); #else atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); #endif } __device__ __forceinline__ void atomic_min(half* address, half value) { -#if __CUDA_ARCH__ >= 600 +#if __CUDA_ARCH__ >= 700 atomic_binary_func(address, value, MinFunc()); #else atomic_byte_func_with_unit32_cas(address, value, MinFunc()); From d456ee0200105dd64536ab4aaa047073a75b3d81 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 21:28:01 -0800 Subject: [PATCH 12/17] exclude openvino & tensorrt for new tests --- .../providers/cpu/tensor/scatter_op_test.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 94a02c08281a8..30e27bb15fa57 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -312,7 +312,7 @@ TEST(ScatterElements, AddReduction) { test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, AddReductionAxis1) { @@ -326,7 +326,7 @@ TEST(ScatterElements, AddReductionAxis1) { test.AddInput("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f}); test.AddOutput("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MulReduction) { @@ -339,7 +339,7 @@ TEST(ScatterElements, MulReduction) { test.AddInput("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f}); test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", 
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MulReductionAxis1) { @@ -353,7 +353,7 @@ TEST(ScatterElements, MulReductionAxis1) { test.AddInput("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); test.AddOutput("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MaxReduction_MLFloat16) { @@ -366,7 +366,7 @@ TEST(ScatterElements, MaxReduction_MLFloat16) { test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f})); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MaxReduction_Float) { @@ -379,7 +379,7 @@ TEST(ScatterElements, MaxReduction_Float) { test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MaxReduction_Double) { @@ -392,7 +392,7 @@ TEST(ScatterElements, MaxReduction_Double) { test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MinReduction_MLFloat16) { @@ -405,7 +405,7 @@ TEST(ScatterElements, MinReduction_MLFloat16) { test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f})); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MinReduction_Float) { @@ -418,7 +418,7 @@ TEST(ScatterElements, MinReduction_Float) { test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(ScatterElements, MinReduction_Double) { @@ -431,7 +431,7 @@ TEST(ScatterElements, MinReduction_Double) { test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } } // namespace test From 2970dc8d4862c965c0c5d46f2f56704d64b99039 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 21:28:23 -0800 Subject: [PATCH 13/17] define ScatterElements in ROCM --- .../core/providers/rocm/rocm_execution_provider.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 
deletions(-) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index fff3d14b763d5..aace78ea858ce 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1069,7 +1069,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1290,6 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1302,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); - +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2004,7 +2005,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2225,6 +2226,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2237,7 +2239,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 19 From c77f23630ef364e25f270079356c10391fc5b30d Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 22:28:17 -0800 Subject: [PATCH 14/17] fix copy-and-paste --- onnxruntime/core/providers/rocm/rocm_execution_provider.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index aace78ea858ce..ee3578326ac6d 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1290,7 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1303,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 From 53e07672976d6254aa72d9d84df58dcfc7633cc5 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 23:19:10 -0800 Subject: [PATCH 15/17] ROCm doesn't support 16-bit atomicCAS. Let's fallback. --- onnxruntime/core/providers/rocm/atomic/common.cuh | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh index d24250b3c1dda..7048d6aff6428 100644 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ b/onnxruntime/core/providers/rocm/atomic/common.cuh @@ -310,25 +310,13 @@ __device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { } __device__ __forceinline__ void atomic_mul(half* address, half value) { -#if __CUDA_ARCH__ >= 600 - atomic_binary_func(address, value, MulFunc()); -#else atomic_byte_func_with_unit32_cas(address, value, MulFunc()); -#endif } __device__ __forceinline__ void atomic_max(half* address, half value) { -#if __CUDA_ARCH__ >= 600 - atomic_binary_func(address, value, MaxFunc()); -#else atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); -#endif } __device__ __forceinline__ void atomic_min(half* address, half value) { -#if __CUDA_ARCH__ >= 600 - atomic_binary_func(address, value, MinFunc()); -#else atomic_byte_func_with_unit32_cas(address, value, MinFunc()); -#endif } __device__ __forceinline__ void atomic_mul(float* address, float value) { From 55cc23ad9bc2b4578f8ef99445a0017506df7b37 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 23:28:33 -0800 Subject: [PATCH 16/17] Hack to add doc change since doc generation code doesn't work on my end. 
--- docs/OperatorKernels.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9d9b266355335..2ea557b7d61fe 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -744,7 +744,9 @@ Do not modify directly.* |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8|**I** = tensor(int64)
<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |Scatter|*in* data:**T**<br/> *in* indices:**Tind**<br/> *in* updates:**T**<br/> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
-|ScatterElements|*in* data:**T**<br/> *in* indices:**Tind**<br/> *in* updates:**T**<br/> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|*in* data:**T**<br/> *in* indices:**Tind**<br/> *in* updates:**T**<br/> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
 |ScatterND|*in* data:**T**<br/> *in* indices:**tensor(int64)**<br/> *in* updates:**T**<br/>
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| From 1b038805b09ccbe6531b6e2fe46997de4ef1f5ef Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Sat, 27 Jan 2024 01:52:12 -0800 Subject: [PATCH 17/17] Fix a test --- .../core/providers/cuda/atomic/common.cuh | 19 ++++++++++++-- .../core/providers/rocm/atomic/common.cuh | 26 ++++++++++++++++--- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 0e825129d445a..170aa3a2d8d0c 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -259,8 +259,23 @@ __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* addr // // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & AtomicCasType::mask; - // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(func(val, static_cast(old_byte))); + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; // Journey of a 32-bit value (cont'd): // // old diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh index 7048d6aff6428..b5d01b91c70ed 100644 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ b/onnxruntime/core/providers/rocm/atomic/common.cuh @@ -123,7 +123,9 @@ class AtomicCasType { // This function becomes atomic_add for int8_t. template __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { - static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, "ValueType must be 1 byte for the following bit-level manipulations."); + // Assert to ensure the following bit-wise manipulation is correct. + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -194,8 +196,23 @@ __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* addr // // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & AtomicCasType::mask; - // Use + for atomic addition, * for atomic multiplication, / for atomic division. 
- newval = static_cast(func(val, static_cast(old_byte))); + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; // Journey of a 32-bit value (cont'd): // // old @@ -229,7 +246,8 @@ template __device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { ValueType observed = *address, assumed, new_value; using CasType = typename AtomicCasType::type; - static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS."); + static_assert(sizeof(ValueType) == sizeof(CasType), + "ValueType and CasType must have the same size for calling atomicCAS."); auto address_as_cas_type = reinterpret_cast(address); do { // Record the value used to compute new value.
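The switch above from a value cast to a bit reinterpretation of the extracted bits matters most for half: converting the integer value of the bits and reinterpreting the bits give completely different numbers. A small illustrative device function (not part of the patch; names are made up) makes the difference concrete, using -9.0, one of the values in the new tests.

#include <cuda_fp16.h>

__device__ void reinterpret_vs_convert_sketch(unsigned int old_word, unsigned int shift) {
  // Suppose the 16 extracted bits are 0xC880, the half encoding of -9.0.
  unsigned int old_bits = (old_word >> shift) & 0xffffu;

  // Value conversion treats 0xC880 as the integer 51328 and yields half(51328.0).
  half converted = __float2half(static_cast<float>(old_bits));

  // Bit reinterpretation recovers the stored half(-9.0), which is what the reduction must see.
  half reinterpreted = __ushort_as_half(static_cast<unsigned short>(old_bits));

  (void)converted;
  (void)reinterpreted;
}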