From 1e9a0a72246880320bfe6581e68761c121354d9e Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 25 Jan 2024 19:47:34 -0800 Subject: [PATCH] Fix atomic --- .../core/providers/cuda/atomic/common.cuh | 59 ++++++++++ .../cuda/tensor/gather_elements_impl.cu | 4 +- .../providers/cpu/tensor/scatter_op_test.cc | 106 +++++++++++++++--- 3 files changed, 153 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 14fa2d0706f73..c8d0b82a8502f 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -25,6 +25,16 @@ namespace onnxruntime { namespace cuda { +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t val) { + int* address_as_int = reinterpret_cast(address); + int old = *address_as_int, assumed; + do { + assumed = old; + old = atomicCAS(address_as_int, assumed, + static_cast(val) + assumed); + } while (assumed != old); +} + __device__ __forceinline__ void atomic_add(float *address, float value) { atomicAdd(address, value); } @@ -122,5 +132,54 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } +__device__ __forceinline__ void atomic_mul(half* address, half val) { + unsigned short int* address_as_short = reinterpret_cast(address); + unsigned short int old = *address_as_short, assumed; + do { + assumed = old; + old = atomicCAS(address_as_short, assumed, + __half_as_short(val * __short_as_half(assumed))); + } while (assumed != old); +} + +__device__ __forceinline__ void atomic_mul(float* address, float val) { + int* address_as_int = reinterpret_cast(address); + int old = *address_as_int, assumed; + do { + assumed = old; + old = atomicCAS(address_as_int, assumed, + __float_as_int(val * __int_as_float(assumed))); + } while (assumed != old); +} + +__device__ __forceinline__ void atomic_mul(double* address, double val) { + unsigned long long int* address_as_long_long = reinterpret_cast(address); + unsigned long long int old = *address_as_long_long, assumed; + do { + assumed = old; + old = atomicCAS(address_as_long_long, assumed, + __double_as_longlong(val * __longlong_as_double(assumed))); + } while (assumed != old); +} + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) { + size_t offset = (size_t)address & 3; \ + uint32_t * address_as_ui = (uint32_t *)((char *)address - offset); \ + uint32_t old = *address_as_ui; \ + uint32_t shift = offset * 8; \ + uint32_t old_byte; \ + uint32_t newval; \ + uint32_t assumed; \ + \ + do { \ + assumed = old; \ + old_byte = (old >> shift) & 0xff; \ + newval = static_cast(val * static_cast(old_byte)); \ + newval = (old & ~(0x000000ff << shift)) | (newval << shift); \ + old = atomicCAS(address_as_ui, assumed, newval); \ + } while (assumed != old); \ +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 4dacceb6e6af7..99c4bc671d000 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -100,12 +100,12 @@ struct FuncAssignment { template struct FuncAdd { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] += value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_add(start_addr + index, value); } }; template struct FuncMul { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] *= value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { atomic_mul(start_addr + index, value); } }; template diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 81dd306f6bff1..8f93fd78e5537 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -75,20 +75,18 @@ void RunTest(const std::vector& input_dims, const std::vector& test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } - for (int opset : {11, 18}) { - onnxruntime::test::OpTester test1("ScatterElements", opset); - if (has_axis) test1.AddAttribute("axis", axis); - test1.AddInput("data", input_dims, input_data); - test1.AddInput("indices", indices_dims, indices_data); - test1.AddInput("updates", indices_dims, updates_data); - test1.AddOutput("y", input_dims, output_data); - if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); - } else if (std::is_same::value) { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); - } else { - test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); - } + onnxruntime::test::OpTester test1("ScatterElements", 11); + if (has_axis) test1.AddAttribute("axis", axis); + test1.AddInput("data", input_dims, input_data); + test1.AddInput("indices", indices_dims, indices_data); + test1.AddInput("updates", indices_dims, updates_data); + test1.AddOutput("y", input_dims, output_data); + if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + } else if (std::is_same::value) { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + } else { + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } } @@ -304,5 +302,85 @@ TEST(Scatter, BoolInputWithAxis) { scatter_bool_with_axis_tests("ScatterElements", 11); } +TEST(ScatterElements, AddReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "add"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, AddReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "add"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MulReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "mul"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MulReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "mul"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + +TEST(ScatterElements, MinReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider}); +} + } // namespace test } // namespace onnxruntime