diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index bb15882745bb3..4f5223a59543a 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -122,39 +122,11 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } -__device__ __forceinline__ void atomic_mul(half* address, half val) { - unsigned short int* address_as_short = reinterpret_cast(address); - unsigned short int old = *address_as_short, assumed; - do { - assumed = old; - old = atomicCAS(address_as_short, assumed, - __half_as_short(val * __short_as_half(assumed))); - } while (assumed != old); -} - -__device__ __forceinline__ void atomic_mul(float* address, float val) { - int* address_as_int = reinterpret_cast(address); - int old = *address_as_int, assumed; - do { - assumed = old; - old = atomicCAS(address_as_int, assumed, - __float_as_int(val * __int_as_float(assumed))); - } while (assumed != old); -} - -__device__ __forceinline__ void atomic_mul(double* address, double val) { - unsigned long long int* address_as_long_long = reinterpret_cast(address); - unsigned long long int old = *address_as_long_long, assumed; - do { - assumed = old; - old = atomicCAS(address_as_long_long, assumed, - __double_as_longlong(val * __longlong_as_double(assumed))); - } while (assumed != old); -} - // Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. // // This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulates `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). // // E.g., Making OneByteType int8_t and BinaryFunc // struct AddFunc { @@ -163,7 +135,9 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) { // } // makes this function atomic_add for int8_t. 
template -__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val) { +__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val, BinaryFunc func) { + static_assert(sizeof(OneByteType) == 1, "OneByteType must be 1 byte for the following bit-level manipulations."); + // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -235,7 +209,7 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & 0xff; // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(val * static_cast(old_byte)); + newval = static_cast(func(val, static_cast(old_byte))); // Journey of a 32-bit value (cont'd): // // old @@ -263,6 +237,84 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* add } while (assumed != old); } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; +}; + +template<> +class AtomicCasType { + public: + using type = int; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; +}; + +// It accumulates `val` into the `address` using the `func`. +// This function is thread-safe (i.e., atomic). 
+template <typename ValueType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
+  // Atomically replaces `*address` with `func(*address, val)` using a retry loop
+  // around atomicCAS on CasType, the same-size integer type from AtomicCasType.
+  using CasType = typename AtomicCasType<ValueType>::type;
+  static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS.");
+  auto address_as_cas_type = reinterpret_cast<CasType*>(address);
+  // Bit pattern most recently seen at `address`, and the one the current
+  // attempt was computed from. Comparing bit patterns rather than ValueType
+  // values keeps the exit test exact even when the stored value is NaN.
+  CasType observed_as_cas_type = *address_as_cas_type, assumed_as_cas_type;
+  do {
+    // Remember which value this attempt is based on.
+    assumed_as_cas_type = observed_as_cas_type;
+
+    // Reinterpret the bits as ValueType and compute the value to store.
+    ValueType assumed = *reinterpret_cast<ValueType*>(&assumed_as_cas_type);
+    ValueType new_value = func(assumed, val);
+    auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
+
+    // atomicCAS stores `new_value` only if `address` still holds the assumed
+    // bits, and it always returns the bits that were in memory BEFORE the call
+    // (i.e., the old value, never the value just written).
+    observed_as_cas_type = atomicCAS(address_as_cas_type, assumed_as_cas_type, new_value_as_cas_type);
+
+    // Two cases:
+    //  1. compare-and-swap succeeds
+    //    a. `address` held the assumed bits and now holds `new_value`.
+    //    b. The returned bits equal the assumed bits, so the loop
+    //       condition is false and the loop terminates after exactly
+    //       one application of `func`.
+    //  2. compare-and-swap fails
+    //    a. Another thread changed `address` in the meantime, thus,
+    //       the `new_value` is stale and nothing was written.
+    //    b. The returned bits differ from the assumed bits, so the loop
+    //       continues. In the next iteration, the `new_value` is
+    //       computed again using the fresh bits.
+  } while (observed_as_cas_type != assumed_as_cas_type);
+}
+
 struct AddFunc { template __device__ __forceinline__ T operator()(T a, T b) const { @@ -292,20 +344,48 @@ struct MinFunc { }; __device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, AddFunc()); } - __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, MulFunc()); } - __device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, MaxFunc()); } - __device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { - atomic_byte_func_with_4byte_cas(address, value); + atomic_byte_func_with_4byte_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(half* address, half value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(half* address, half value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + atomic_binary_func(address, value, MaxFunc()); } +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void 
atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + } // namespace cuda } // namespace onnxruntime