Update ScatterElements to Support Opset 13, 15, 18 (#19198)
`ScatterElements` in opset 18 has been around for a while. However, the
highest opset of `ScatterElements` supported in ORT is 13. This PR
implements this op in the CUDA EP by replacing the `assignment` in the current
CUDA kernel with an `atomic reduction` (e.g., atomic add, atomic max). A
series of fundamental atomic functions (e.g., atomic max for int8_t and
half) is implemented in `common.cuh`; the implementation is general
enough to cover both old and new CUDA versions.
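
For illustration only (this sketch is not part of the commit), the core idea is that a scatter kernel which previously wrote `output[indices[i]] = updates[i]` now applies an atomic reduction at the target element. The kernel name below is hypothetical:

```cuda
// Hypothetical sketch of the idea behind this PR, not the shipped kernel:
// a plain assignment per scattered element is replaced by an atomic reduction,
// so duplicate indices accumulate correctly instead of racing.
__global__ void ScatterElementsAddSketch(float* output,
                                         const int64_t* indices,
                                         const float* updates,
                                         int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i >= n) return;
  // Old behavior (reduction="none"): output[indices[i]] = updates[i];
  // New behavior (reduction="add"): atomic accumulation.
  atomicAdd(&output[indices[i]], updates[i]);
}
```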

- The core changes are in `cuda/atomic/common.cuh`, with detailed
documentation including visualizations of the bit-wise operations. They are
also copied to `rocm/atomic/common.cuh` to support AMD GPUs.
- `/cuda/tensor/gather_elements_impl.cu` contains small changes that call
the new atomic functions to support the new `reduction` behavior of
`ScatterElements` (a dispatch sketch follows this list).
- The new `ScatterElements` opsets are registered in `rocm_execution_provider.cc` and
`cuda_execution_provider.cc`.
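
A minimal dispatch sketch (hypothetical; the enum and function names below are illustrative, not the actual code in `gather_elements_impl.cu`), assuming the `atomic_*` helper overloads added in `common.cuh` are available for `T`:

```cuda
// Illustrative only: how a per-element write might dispatch on the reduction
// attribute to the new atomic helpers. Enum and function names are made up.
enum class ScatterReduction { None, Add, Mul, Max, Min };

template <typename T>
__device__ __forceinline__ void ScatterWrite(T* dst, T value, ScatterReduction reduction) {
  switch (reduction) {
    case ScatterReduction::Add: atomic_add(dst, value); break;
    case ScatterReduction::Mul: atomic_mul(dst, value); break;
    case ScatterReduction::Max: atomic_max(dst, value); break;
    case ScatterReduction::Min: atomic_min(dst, value); break;
    default: *dst = value; break;  // reduction="none": the original assignment
  }
}
```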
wschin authored Jan 30, 2024
1 parent 3e17ca3 commit ffc3431
Showing 11 changed files with 858 additions and 25 deletions.
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -744,7 +744,9 @@ Do not modify directly.*
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
14 changes: 0 additions & 14 deletions onnxruntime/core/providers/cpu/tensor/scatter.cc
@@ -198,13 +198,6 @@ struct Func_Min<std::string> {
}
};

template <>
struct Func_Min<MLFloat16> {
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};

template <>
struct Func_Min<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
@@ -233,13 +226,6 @@ struct Func_Max<std::string> {
}
};

template <>
struct Func_Max<MLFloat16> {
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};

template <>
struct Func_Max<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
311 changes: 311 additions & 0 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
@@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd<half>(half* start_addr, size_t index,
#endif
}

// Disable default template instantiation.
// For every type T, we need to define a specialization
// to select the right type for calling atomicCAS.
template <typename T>
class AtomicCasType;

template<>
class AtomicCasType<int8_t> {
public:
using type = unsigned short int;
static const unsigned int mask = 0xffu;
};

template<>
class AtomicCasType<half> {
public:
using type = unsigned short int;
static const unsigned int mask = 0xffffu;
};

template<>
class AtomicCasType<float> {
public:
using type = unsigned int;
static const unsigned int mask = 0xffffffffu;
};

template<>
class AtomicCasType<double> {
public:
using type = unsigned long long int;
static const unsigned int mask = 0xffffffffu;
};

template<>
class AtomicCasType<int> {
public:
using type = int;
static const unsigned int mask = 0xffffffffu;
};

template<>
class AtomicCasType<int64_t> {
public:
using type = unsigned long long int;
static const unsigned int mask = 0xffffffffu;
};

// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
//
// This function computes an 8-bit atomic binary operation using 32-bit atomicCAS.
// It accumulates `val` into `address` using `func`.
// The accumulation is atomic (i.e., thread-safe).
//
// E.g., assume ValueType is
//   int8_t
// and BinaryFunc is
//   struct AddFunc {
//     __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
//       return a + b;
//     }
//   };
// Then this function becomes an atomic add for int8_t.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
// Assert to ensure the following bit-wise manipulation is correct.
static_assert(sizeof(ValueType) == 1 || sizeof(ValueType) == 2 || sizeof(ValueType) == 4,
"ValueType must be 1-byte, 2-byte or 4-byte large.");
// Number of bytes to the lower 4-byte aligned address.
// If the current address is b1010"10", then offset = b10 = 2,
// which means the current address is 2 bytes away from
// the lower 4-byte aligned address b1010"00".
size_t offset = (size_t)address & 3;
// Find a new 4-byte aligned address `address_as_ui` lower than
// or equal to `address`. Lower than `address` so that the actual
// int8_t byte is in the 4-byte word that we load.
//
// This address has the following properties:
// 1. It is 4-byte aligned.
// 2. It is lower than or equal to `address`.
// 3. De-referencing this address may return
// a uint32_t value that contains the same int8_t
// value indicated by `address`.
//
// E.g.,
// address = b101010
// offset = b101010 & b000011 = b10 = 2
// (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
// which is (32-bit aligned).
uint32_t * address_as_ui = (uint32_t*)((char*)address - offset);
uint32_t old = *address_as_ui;
// E.g., offset = 2.
// address_as_ui is an address 2 bytes lower than `address`.
//
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
// ^ ^ ^
// | | |
// | address <--- offset * 8 (bit)-----> address_as_ui
// | ^
// | |
// ------------------------- *address_as_ui -----------------------
//
// This visualization shows
// 1. the 32-bit word at address_as_ui.
// 2. the gap between address_as_ui and address.
// 3. *address_as_ui contains the int8_t value at `address`.
uint32_t shift = offset * 8;
uint32_t old_byte;
uint32_t newval;
uint32_t assumed;
do {
assumed = old;
// Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so
// we want to select the 3rd byte (byte 2 below) from the word.
//
// Journey of a 32-bit value:
//
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
//
// |
// | old >> offset * 8, where offset = 2.
// | Effectively, push lower two bytes
// | out of the word.
// V
//
// 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 .....
//
// | apply bit-wise AND,
// | & 0xff (i.e., & b11111111),
// | so that we only keep
// | the byte of interest.
// | Otherwise, overflow may
// | happen when casting this
// | 32-bit value to int8_t.
// V
//
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
// Compute new int8_t value and store it to newrawvalue.
// Journey of a 32-bit value (cont'd):
//
// newrawvalue
// ... new byte 2 ...
auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
// Put the new int8_t value back to 32-bit word.
// Also ensure that bits not occupied by the int8_t value are 0s.
//
// Journey of a 32-bit value (cont'd):
//
// reinterpret_cast<uint32_t&>(newrawvalue)
// random values | random values | random values | ... new byte 2 ...
//
// reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
// 00000000 | 00000000 | 00000000 | ... new byte 2 ...
newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
// Journey of a 32-bit value (cont'd):
//
// old
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
//
// 0x000000ff
// 00000000 | 00000000 | 00000000 | 11111111
//
// 0x000000ff << shift
// 00000000 | 11111111 | 00000000 | 00000000
//
// ~(0x000000ff << shift)
// 11111111 | 00000000 | 11111111 | 11111111
//
// old & ~(0x000000ff << shift)
// ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
//
// newval << shift
// 00000000 | ... new byte 2 ... | 00000000 | 00000000
//
// (old & ~(0x000000ff << shift)) | (newval << shift)
// ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
old = atomicCAS(address_as_ui, assumed, newval);
} while (assumed != old);
}

// It accumulates `val` into the `address` using the `func`.
// This function is thread-safe (i.e., atomic).
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
ValueType observed = *address, assumed, new_value;
using CasType = typename AtomicCasType<ValueType>::type;
static_assert(sizeof(ValueType) == sizeof(CasType),
"ValueType and CasType must have the same size for calling atomicCAS.");
auto address_as_cas_type = reinterpret_cast<CasType*>(address);
do {
// Record the value used to compute new value.
assumed = observed;

// Compute expected new value.
new_value = func(observed, val);

// Bit-cast `observed` and `new_value` to the integer type required by atomicCAS:
//   2-byte types -> unsigned short int
//   4-byte types -> unsigned int (or int)
//   8-byte types -> unsigned long long int
auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);

// Call atomicCAS on the reinterpreted values; the overload is selected by
// CasType (unsigned short int, unsigned int/int, or unsigned long long int).
auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);

// Bit-cast the value freshly observed in memory back to ValueType.
observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);

// Two cases:
// 1. compare-and-swap succeeds:
//    a. `address` now holds `new_value`.
//    b. atomicCAS returned the previous content of `address`, which equals
//       `assumed`, so the check `observed != assumed` below is false and
//       the loop terminates.
// 2. compare-and-swap fails:
//    a. `address` held a value different from `assumed`, so `new_value`
//       is stale and was not written.
//    b. `observed` becomes the fresh value read from `address`, the check
//       `observed != assumed` is true, and the loop continues; the next
//       iteration recomputes `new_value` from the fresh `observed`.
} while (observed != assumed);
}

struct AddFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};

struct MulFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a * b;
}
};

struct MaxFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return b > a ? b : a;
}
};

struct MinFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return b < a ? b : a;
}
};

__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, AddFunc());
}
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}

__device__ __forceinline__ void atomic_mul(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MulFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
#endif
}
__device__ __forceinline__ void atomic_max(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MaxFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
#endif
}
__device__ __forceinline__ void atomic_min(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MinFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
#endif
}

__device__ __forceinline__ void atomic_mul(float* address, float value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(float* address, float value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(float* address, float value) {
atomic_binary_func(address, value, MinFunc());
}

__device__ __forceinline__ void atomic_mul(double* address, double value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(double* address, double value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(double* address, double value) {
atomic_binary_func(address, value, MinFunc());
}


} // namespace cuda
} // namespace onnxruntime
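
To show how these new helpers might be used (illustrative only, not part of the commit), a toy kernel that reduces values into buckets with `atomic_max`. It assumes the translation unit includes the header above so the float overload is visible:

```cuda
// Toy usage sketch: duplicate bucket_ids are safe because atomic_max retries
// its CAS loop until the maximum value is stored.
// Assumes core/providers/cuda/atomic/common.cuh is included in this file.
using onnxruntime::cuda::atomic_max;

__global__ void BucketMaxSketch(float* buckets, const int64_t* bucket_ids,
                                const float* values, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) {
    atomic_max(&buckets[bucket_ids[i]], values[i]);
  }
}
```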