Implement int8_t using template function
wschin committed Jan 26, 2024
1 parent 5274efe commit 469245e
Showing 1 changed file with 57 additions and 13 deletions.
70 changes: 57 additions & 13 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
@@ -25,16 +25,6 @@
namespace onnxruntime {
namespace cuda {

__device__ __forceinline__ void atomic_add(int8_t* address, int8_t val) {
int* address_as_int = reinterpret_cast<int*>(address);
int old = *address_as_int, assumed;
do {
assumed = old;
old = atomicCAS(address_as_int, assumed,
static_cast<int>(val) + assumed);
} while (assumed != old);
}

__device__ __forceinline__ void atomic_add(float *address, float value) {
atomicAdd(address, value);
}
@@ -162,8 +152,18 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) {
} while (assumed != old);
}

// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) {
// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
//
// This function computes an 8-bit atomic binary operation using a 32-bit atomicCAS.
//
// E.g., making ByteLargeType int8_t and BinaryFunc
//   struct AddFunc {
//     __device__ __forceinline__ ByteLargeType operator()(ByteLargeType a, ByteLargeType b) const {
//       return a + b;
//     }
//   };
// makes this function an atomic_add for int8_t.
template<typename ByteLargeType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* address, ByteLargeType val) {
// Number of bytes to the lower 4-byte aligned address.
// If the current address is b1010"10", then offset = b10 = 2,
// which means the current address is 2 bytes away from
@@ -235,7 +235,7 @@ __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) {
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & 0xff;
// The actual binary operation is supplied by BinaryFunc
// (e.g., AddFunc for atomic addition, MulFunc for atomic multiplication).
newval = static_cast<uint32_t>(val * static_cast<int8_t>(old_byte));
newval = static_cast<uint32_t>(BinaryFunc()(val, static_cast<ByteLargeType>(old_byte)));
// Journey of a 32-bit value (cont'd):
//
// old
@@ -263,5 +263,49 @@ __device__ __forceinline__ void atomic_mul(int8_t* address, int8_t val) {
} while (assumed != old);
}
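
For reference, the byte arithmetic described in the comments above can be exercised on its own. The sketch below is illustrative only and not part of this commit; the helper name simulate_byte_update and the example values are made up. It shows how the shift derived from the byte offset extracts one byte from a 32-bit word, applies the binary operation (addition here, as AddFunc would), and splices the updated byte back in:

#include <cstdint>
#include <cstdio>

// Illustrative helper (not in the commit): given the 32-bit word read from the
// 4-byte-aligned address and the byte offset of the target int8_t inside it,
// extract that byte, add `val` to it, and merge the result back into the word.
uint32_t simulate_byte_update(uint32_t old_word, uint32_t offset, int8_t val) {
  uint32_t shift = offset * 8;                     // bit position of the target byte
  uint32_t old_byte = (old_word >> shift) & 0xff;  // extract the target byte
  uint32_t new_byte =
      static_cast<uint32_t>(val + static_cast<int8_t>(old_byte)) & 0xff;  // AddFunc-style update
  uint32_t mask = ~(0xffu << shift);               // clear only the target byte
  return (old_word & mask) | (new_byte << shift);  // splice the updated byte back in
}

int main() {
  // The word 0x44332211 holds bytes 0x11, 0x22, 0x33, 0x44 at offsets 0..3.
  uint32_t updated = simulate_byte_update(0x44332211u, 2, 1);  // add 1 to byte 2 (0x33)
  printf("%08x\n", updated);                                   // prints 44342211
  return 0;
}

In the device function, this is the word update that the atomicCAS loop keeps retrying until the word it read is still the word in memory.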

struct AddFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};

struct MulFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a * b;
}
};

struct MaxFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a > b ? a : b;
}
};

struct MinFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? a : b;
}
};

__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
atomic_byte_func_with_4byte_cas<int8_t, AddFunc>(address, value);
}

__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
atomic_byte_func_with_4byte_cas<int8_t, MulFunc>(address, value);
}

__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
atomic_byte_func_with_4byte_cas<int8_t, MaxFunc>(address, value);
}

__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
atomic_byte_func_with_4byte_cas<int8_t, MinFunc>(address, value);
}

} // namespace cuda
} // namespace onnxruntime
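
With these overloads in place, int8_t goes through the same atomic_add / atomic_mul / atomic_max / atomic_min entry points as the wider types. A minimal, hypothetical usage sketch follows (the kernel name, the include path, and the launch shape are assumptions for illustration, not part of this commit):

#include <cstdint>
#include "core/providers/cuda/atomic/common.cuh"  // assumed include path for this header

// Hypothetical kernel: each thread folds its int8_t value into one of four
// accumulator slots. Neighboring slots live in the same 32-bit word, which is
// exactly the case atomic_byte_func_with_4byte_cas handles safely.
__global__ void AccumulateInt8Kernel(const int8_t* values, int8_t* accumulators, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    onnxruntime::cuda::atomic_add(accumulators + (i % 4), values[i]);
  }
}

// Launch example: AccumulateInt8Kernel<<<(n + 255) / 256, 256>>>(values, accumulators, n);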
