From 469245e0b62631f9bc93873605b4e54ca0018016 Mon Sep 17 00:00:00 2001
From: Wei-Sheng Chin
Date: Fri, 26 Jan 2024 12:48:13 -0800
Subject: [PATCH] Implement int8_t using template function

---
 .../core/providers/cuda/atomic/common.cuh     | 70 +++++++++++++++----
 1 file changed, 57 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh
index a7520558922ed..955c583bf8b54 100644
--- a/onnxruntime/core/providers/cuda/atomic/common.cuh
+++ b/onnxruntime/core/providers/cuda/atomic/common.cuh
@@ -25,16 +25,6 @@ namespace onnxruntime {
 namespace cuda {
 
-__device__ __forceinline__ void atomic_add(int8_t* address, int8_t val) {
-  int* address_as_int = reinterpret_cast<int*>(address);
-  int old = *address_as_int, assumed;
-  do {
-    assumed = old;
-    old = atomicCAS(address_as_int, assumed,
-                    static_cast<int>(val) + assumed);
-  } while (assumed != old);
-}
-
 __device__ __forceinline__ void atomic_add(float *address, float value) {
   atomicAdd(address, value);
 }
 
@@ -162,8 +152,18 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) {
   } while (assumed != old);
 }
 
-// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
-__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) {
+// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
+//
+// This function computes an 8-bit atomic binary operation using a 32-bit atomicCAS.
+// E.g., making ByteLargeType int8_t and BinaryFunc
+//   struct AddFunc {
+//     __device__ __forceinline__ ByteLargeType operator()(ByteLargeType a, ByteLargeType b) const {
+//       return a + b;
+//     }
+//   };
+// makes this function atomic_add for int8_t.
+template <typename ByteLargeType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* address, ByteLargeType val) {
   // Number of bytes to the lower 4-byte aligned address.
   // If the current address is b1010"10", then offset = b10 = 2,
   // which means the current address is 2 bytes away from
@@ -235,7 +235,7 @@ __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) {
     // 00000000 | 00000000 | 00000000 | ..... byte 2 .....
     old_byte = (old >> shift) & 0xff;
     // Use + for atomic addition, * for atomic multiplication, / for atomic division.
-    newval = static_cast<uint8_t>(val * static_cast<int8_t>(old_byte));
+    newval = static_cast<uint8_t>(BinaryFunc()(val, static_cast<ByteLargeType>(old_byte)));
     // Journey of a 32-bit value (cont'd):
     //
     // old
@@ -263,5 +263,49 @@
   } while (assumed != old);
 }
 
+struct AddFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+struct MulFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a * b;
+  }
+};
+
+struct MaxFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a > b ? a : b;
+  }
+};
+
+struct MinFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a < b ? a : b;
+  }
+};
+
+__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
+  atomic_byte_func_with_4byte_cas<int8_t, AddFunc>(address, value);
+}
+
+__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
+  atomic_byte_func_with_4byte_cas<int8_t, MulFunc>(address, value);
+}
+
+__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
+  atomic_byte_func_with_4byte_cas<int8_t, MaxFunc>(address, value);
+}
+
+__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
+  atomic_byte_func_with_4byte_cas<int8_t, MinFunc>(address, value);
+}
+
 }  // namespace cuda
 }  // namespace onnxruntime