Skip to content

Commit

Permalink
Rename a type
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Jan 26, 2024
1 parent 469245e commit ce51bc6
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) {
//
// This function compute 8-bit atomic binary operation using 32-bit atomicCAS.
//
// E.g., Making ByteLargeType int8_t and BinaryFunc
// E.g., Making OneByteType int8_t and BinaryFunc
// struct AddFunc {
// __device__ __forceinline__ ByteLargeType operator()(ByteLargeType a, ByteLargeType b) const {
// __device__ __forceinline__ OneByteType operator()(OneByteType a, OneByteType b) const {
// return a + b;
// }
// makes this function atomic_add for int8_t.
template<typename ByteLargeType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* address, ByteLargeType val) {
template<typename OneByteType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val) {
// Number of bytes to the lower 4-byte aligned address.
// If the current address is b1010"10", then offset = b10 = 2,
// which means the current address is 2 bytes away from
Expand Down Expand Up @@ -235,7 +235,7 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* a
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & 0xff;
// Use + for atomic addition, * for atomic multiplication, / for atomic division.
newval = static_cast<uint32_t>(val * static_cast<ByteLargeType>(old_byte));
newval = static_cast<uint32_t>(val * static_cast<OneByteType>(old_byte));
// Journey of a 32-bit value (cont'd):
//
// old
Expand Down Expand Up @@ -280,14 +280,14 @@ struct MulFunc {
struct MaxFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a > b ? a : b;
return b > a ? b : a;
}
};

struct MinFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? a : b;
return b < a ? b : a;
}
};

Expand Down

0 comments on commit ce51bc6

Please sign in to comment.