From ce51bc6278fbe187f14fe373d24907bad0ddd9fd Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 26 Jan 2024 12:49:57 -0800 Subject: [PATCH] Rename a type --- onnxruntime/core/providers/cuda/atomic/common.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 955c583bf8b54..bb15882745bb3 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -156,14 +156,14 @@ __device__ __forceinline__ void atomic_mul(double* address, double val) { // // This function compute 8-bit atomic binary operation using 32-bit atomicCAS. // -// E.g., Making ByteLargeType int8_t and BinaryFunc +// E.g., Making OneByteType int8_t and BinaryFunc // struct AddFunc { -// __device__ __forceinline__ ByteLargeType operator()(ByteLargeType a, ByteLargeType b) const { +// __device__ __forceinline__ OneByteType operator()(OneByteType a, OneByteType b) const { // return a + b; // } // makes this function atomic_add for int8_t. -template -__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* address, ByteLargeType val) { +template +__device__ __forceinline__ void atomic_byte_func_with_4byte_cas(OneByteType* address, OneByteType val) { // Number of bytes to the lower 4-byte aligned address. // If the current address is b1010"10", then offset = b10 = 2, // which means the current address is 2 bytes away from @@ -235,7 +235,7 @@ __device__ __forceinline__ void atomic_byte_func_with_4byte_cas(ByteLargeType* a // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... old_byte = (old >> shift) & 0xff; // Use + for atomic addition, * for atomic multiplication, / for atomic division. - newval = static_cast(val * static_cast(old_byte)); + newval = static_cast(val * static_cast(old_byte)); // Journey of a 32-bit value (cont'd): // // old @@ -280,14 +280,14 @@ struct MulFunc { struct MaxFunc { template __device__ __forceinline__ T operator()(T a, T b) const { - return a > b ? a : b; + return b > a ? b : a; } }; struct MinFunc { template __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? a : b; + return b < a ? b : a; } };