Skip to content

Commit

Permalink
Fix a test
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Jan 27, 2024
1 parent 55cc23a commit 1b03880
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
19 changes: 17 additions & 2 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,23 @@ __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* addr
//
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
// Use + for atomic addition, * for atomic multiplication, / for atomic division.
newval = static_cast<uint32_t>(func(val, static_cast<ValueType>(old_byte)));
// Compute the new int8_t value and store it in newrawvalue.
// Journey of a 32-bit value (cont'd):
//
// newrawvalue
// ... new byte 2 ...
auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
// Put the new int8_t value back into a 32-bit word.
// Also ensure that the bits not occupied by the int8_t value are 0s.
//
// Journey of a 32-bit value (cont'd):
//
// reinterpret_cast<uint32_t&>(newrawvalue)
// random values | random values | random values | ... new byte 2 ...
//
// reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
// 00000000 | 00000000 | 00000000 | ... new byte 2 ...
newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
// Journey of a 32-bit value (cont'd):
//
// old
Expand Down
26 changes: 22 additions & 4 deletions onnxruntime/core/providers/rocm/atomic/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ class AtomicCasType<int64_t> {
// This function becomes atomic_add for int8_t.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, "ValueType must be 1 byte for the following bit-level manipulations.");
// Assert to ensure the following bit-wise manipulation is correct.
static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4,
"ValueType must be 1-byte, 2-byte or 4-byte large.");
// Number of bytes to the lower 4-byte aligned address.
// If the current address is b1010"10", then offset = b10 = 2,
// which means the current address is 2 bytes away from
Expand Down Expand Up @@ -194,8 +196,23 @@ __device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* addr
//
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
// Use + for atomic addition, * for atomic multiplication, / for atomic division.
newval = static_cast<uint32_t>(func(val, static_cast<ValueType>(old_byte)));
// Compute the new int8_t value and store it in newrawvalue.
// Journey of a 32-bit value (cont'd):
//
// newrawvalue
// ... new byte 2 ...
auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
// Put the new int8_t value back into a 32-bit word.
// Also ensure that the bits not occupied by the int8_t value are 0s.
//
// Journey of a 32-bit value (cont'd):
//
// reinterpret_cast<uint32_t&>(newrawvalue)
// random values | random values | random values | ... new byte 2 ...
//
// reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
// 00000000 | 00000000 | 00000000 | ... new byte 2 ...
newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
// Journey of a 32-bit value (cont'd):
//
// old
Expand Down Expand Up @@ -229,7 +246,8 @@ template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
ValueType observed = *address, assumed, new_value;
using CasType = typename AtomicCasType<ValueType>::type;
static_assert(sizeof(ValueType) == sizeof(CasType), "ValueType and CasType must have the same size for calling atomicCAS.");
static_assert(sizeof(ValueType) == sizeof(CasType),
"ValueType and CasType must have the same size for calling atomicCAS.");
auto address_as_cas_type = reinterpret_cast<CasType*>(address);
do {
// Record the value used to compute new value.
Expand Down

0 comments on commit 1b03880

Please sign in to comment.