Update ScatterElements to Support Opset 13, 15, 18 #19198

Merged: 17 commits, Jan 30, 2024
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -744,7 +744,9 @@ Do not modify directly.*
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
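
For orientation (editor's sketch, values hypothetical, not from this diff): the new ScatterElements rows reflect that opset 16 introduced the reduction attribute (add, mul) and opset 18 extended it with min and max. With data = [1, 2, 3, 4], indices = [1, 1], updates = [10, 20], axis = 0:

  reduction = "max" (opset 18+):  output[1] = max(2, 10, 20)  ->  [1, 20, 3, 4]
  reduction = "add" (opset 16+):  output[1] = 2 + 10 + 20     ->  [1, 32, 3, 4]
  reduction = "none":             duplicate indices leave the result undefined
                                  per the ONNX spec.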
14 changes: 0 additions & 14 deletions onnxruntime/core/providers/cpu/tensor/scatter.cc
@@ -198,13 +198,6 @@ struct Func_Min<std::string> {
}
};

template <>
struct Func_Min<MLFloat16> {
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
}
};

template <>
struct Func_Min<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
@@ -233,13 +226,6 @@ struct Func_Max<std::string> {
}
};

template <>
struct Func_Max<MLFloat16> {
void operator()(MLFloat16*, const MLFloat16*) const {
ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
}
};

template <>
struct Func_Max<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
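
For context on the two deletions above (editor's sketch, not shown in this diff): removing the MLFloat16 specializations lets MLFloat16 fall through to the primary Func_Min/Func_Max templates, so min/max reduction now works for MLFloat16 on CPU. The primary templates are assumed to look roughly like:

template <class T>
struct Func_Min {
  void operator()(T* a, const T* b) const {
    if (*b < *a) *a = *b;  // keep the smaller of the stored and incoming values
  }
};

template <class T>
struct Func_Max {
  void operator()(T* a, const T* b) const {
    if (*a < *b) *a = *b;  // keep the larger value
  }
};

This fall-through requires MLFloat16 to be comparable with operator<, which is presumably what the removed ORT_NOT_IMPLEMENTED stubs were guarding before.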
311 changes: 311 additions & 0 deletions onnxruntime/core/providers/cuda/atomic/common.cuh
@@ -122,5 +122,316 @@
#endif
}

// Disable default template instantiation.
// For every type T, we need to define a specialization
// to select the right type for calling atomicCAS.
template <typename T>
class AtomicCasType;

template<>
class AtomicCasType<int8_t> {
public:
using type = unsigned short int;

static const unsigned int mask = 0xffu;
};

template<>
class AtomicCasType<half> {
public:
using type = unsigned short int;

static const unsigned int mask = 0xffffu;
};

template<>
class AtomicCasType<float> {
public:
using type = unsigned int;
static const unsigned int mask = 0xffffffffu;
};

template<>
class AtomicCasType<double> {
public:
using type = unsigned long long int;

static const unsigned int mask = 0xffffffffu;
};

template<>
class AtomicCasType<int> {
public:
using type = int;
static const unsigned int mask = 0xffffffffu;
};

template<>
class AtomicCasType<int64_t> {
public:
using type = unsigned long long int;

static const unsigned int mask = 0xffffffffu;
};

// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
//
// This function computes an 8-bit (or other sub-word) atomic binary operation
// using 32-bit atomicCAS. It accumulates `val` into `address` using `func`.
// The accumulation is atomic (i.e., thread-safe).
//
// E.g., assume ValueType is
//   int8_t
// and BinaryFunc is
//   struct AddFunc {
//     __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
//       return a + b;
//     }
//   };
// Then this function becomes atomic_add for int8_t.
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
// Assert to ensure the following bit-wise manipulation is correct.
static_assert(sizeof(ValueType) == 1 || sizeof(ValueType) == 2 || sizeof(ValueType) == 4,
"ValueType must be 1-byte, 2-byte or 4-byte large.");
// Number of bytes to the lower 4-byte aligned address.
// If the current address is b1010"10", then offset = b10 = 2,
// which means the current address is 2 bytes away from
// the lower 4-byte aligned address b1010"00".
size_t offset = reinterpret_cast<size_t>(address) & 3;

// Find a new 4-byte aligned address `address_as_ui` lower than
// or equal to `address`. Lower than `address` so that the actual
// int8_t byte is in the 4-byte word that we load.
//
// This address has the following properties:
// 1. It is 4-byte aligned.
// 2. It is lower than or equal to `address`.
// 3. De-referencing this address may return
// a uint32_t value that contains the same int8_t
// value indicated by `address`.
//
// E.g.,
// address = b101010
// offset = b101010 & b000011 = b10 = 2
// (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
// which is (32-bit aligned).
uint32_t* address_as_ui = reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(address) - offset);

uint32_t old = *address_as_ui;
// E.g., offset = 2.
// address_as_ui is an address 2 bytes lower than `address`.
//
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
// ^ ^ ^
// | | |
// | address <--- offset * 8 (bit)-----> address_as_ui
// | ^
// | |
// ------------------------- *address_as_ui -----------------------
//
// This visualization shows
// 1. the 32-bit word at address_as_ui.
// 2. the gap between address_as_ui and address.
// 3. *address_as_ui contains the int8_t value at `address`.
uint32_t shift = offset * 8;
uint32_t old_byte;
uint32_t newval;
uint32_t assumed;
do {
assumed = old;
// Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so
// we want to select the 3rd byte (byte 2 below) from the word.
//
// Journey of a 32-bit value:
//
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
//
// |
// | old >> offset * 8, where offset = 2.
// | Effectively, push lower two bytes
// | out of the word.
// V
//
// 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 .....
//
// | apply bit-wise AND,
// | & 0xff (i.e., & b11111111),
// | so that we only keep
// | the byte of interest.
// | Otherwise, overflow may
// | happen when casting this
// | 32-bit value to int8_t.
// V
//
// 00000000 | 00000000 | 00000000 | ..... byte 2 .....
old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
// Compute new int8_t value and store it to newrawvalue.
// Journey of a 32-bit value (cont'd):
//
// newrawvalue
// ... new byte 2 ...
auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
// Put the new int8_t value back to 32-bit word.
// Also ensure that bits not occupied by the int8_t value are 0s.
//
// Journey of a 32-bit value (cont'd):
//
// reinterpret_cast<uint32_t&>(newrawvalue)
// random values | random values | random values | ... new byte 2 ...
//
// reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
// 00000000 | 00000000 | 00000000 | ... new byte 2 ...
newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
// Journey of a 32-bit value (cont'd):
//
// old
// ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
//
// 0x000000ff
// 00000000 | 00000000 | 00000000 | 11111111
//
// 0x000000ff << shift
// 00000000 | 11111111 | 00000000 | 00000000
//
// ~(0x000000ff << shift)
// 11111111 | 00000000 | 11111111 | 11111111
//
// old & ~(0x000000ff << shift)
// ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
//
// newval << shift
// 00000000 | ... new byte 2 ... | 00000000 | 00000000
//
// (old & ~(0x000000ff << shift)) | (newval << shift)
// ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
old = atomicCAS(address_as_ui, assumed, newval);
} while (assumed != old);
}
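
// ---------------------------------------------------------------------------
// Editor's sketch (not part of this PR): a host-side walk-through of the
// mask/shift arithmetic above, assuming offset = 2, an int8_t value, and an
// "add 1" func. All values are hypothetical.
//
//   uint32_t old = 0x44332211u;  // byte 0 = 0x11, ..., byte 2 = 0x33
//   uint32_t shift = 2 * 8;      // offset = 2 bytes -> 16 bits
//   uint32_t mask = 0xffu;       // AtomicCasType<int8_t>::mask
//   uint32_t old_byte = (old >> shift) & mask;                         // 0x33
//   uint32_t new_byte = (old_byte + 1) & mask;                         // 0x34
//   uint32_t newval = (old & ~(mask << shift)) | (new_byte << shift);  // 0x44342211
//
// Only byte 2 changed, which is exactly what one atomicCAS iteration writes.
// ---------------------------------------------------------------------------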

// It accumulates `val` into `address` using `func`.
// This function is thread-safe (i.e., atomic).
template<typename ValueType, typename BinaryFunc>
__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
ValueType observed = *address, assumed, new_value;
using CasType = typename AtomicCasType<ValueType>::type;
static_assert(sizeof(ValueType) == sizeof(CasType),
"ValueType and CasType must have the same size for calling atomicCAS.");
auto address_as_cas_type = reinterpret_cast<CasType*>(address);
do {
// Record the value used to compute new value.
assumed = observed;

// Compute expected new value.
new_value = func(observed, val);

// Reinterpret the ValueType bits as the same-sized integer type used with
// atomicCAS: unsigned short int for 2-byte types, unsigned int (or int) for
// 4-byte, unsigned long long int for 8-byte.
auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);

// Call atomicCAS on the reinterpreted values; the overload taken is
// unsigned short int for 2-byte types, unsigned int (or int) for 4-byte,
// and unsigned long long int for 8-byte.
auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);

// Cast the freshly observed value in memory back to ValueType.
observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);

// Two cases:
// 1. compare-and-swap succeeded:
//    a. `address` now holds `new_value`.
//    b. atomicCAS returned the prior content of `address`, which equals
//       `assumed`, so the following (observed != assumed) is false
//       and the loop terminates.
// 2. compare-and-swap failed:
//    a. `address` held a value different from `assumed`, so the computed
//       `new_value` is stale and was not written.
//    b. `observed` becomes the fresh value read from `address`, so
//       (observed != assumed) is true and the loop retries; the next
//       iteration recomputes `new_value` from the fresh `observed`.
} while (observed != assumed);
}
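
// ---------------------------------------------------------------------------
// Editor's sketch (not part of this PR): how atomic_binary_func might be
// exercised. Kernel and buffer names are hypothetical; atomic_max(float*,
// float) defined below simply forwards here with MaxFunc.
//
//   __global__ void max_reduce(const float* in, float* out, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) {
//       atomic_max(out, in[i]);  // caller pre-initializes *out, e.g. to -FLT_MAX
//     }
//   }
//   // Launch: max_reduce<<<(n + 255) / 256, 256>>>(d_in, d_out, n);
// ---------------------------------------------------------------------------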

struct AddFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};

struct MulFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return a * b;
}
};

struct MaxFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return b > a ? b : a;
}
};

struct MinFunc {
template <typename T>
__device__ __forceinline__ T operator()(T a, T b) const {
return b < a ? b : a;
}
};

__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, AddFunc());
}
__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
}

__device__ __forceinline__ void atomic_mul(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MulFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MulFunc());
#endif
}
__device__ __forceinline__ void atomic_max(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MaxFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
#endif
}
__device__ __forceinline__ void atomic_min(half* address, half value) {
#if __CUDA_ARCH__ >= 700
atomic_binary_func(address, value, MinFunc());
#else
atomic_byte_func_with_unit32_cas(address, value, MinFunc());
#endif
}
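
// Editor's note (assumption, not stated in this diff): the __CUDA_ARCH__ >= 700
// guards above likely exist because 16-bit atomicCAS (unsigned short int) is
// only available on compute capability 7.0 and newer, so older architectures
// fall back to the 32-bit byte-CAS path.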

__device__ __forceinline__ void atomic_mul(float* address, float value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(float* address, float value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(float* address, float value) {
atomic_binary_func(address, value, MinFunc());
}

__device__ __forceinline__ void atomic_mul(double* address, double value) {
atomic_binary_func(address, value, MulFunc());
}
__device__ __forceinline__ void atomic_max(double* address, double value) {
atomic_binary_func(address, value, MaxFunc());
}
__device__ __forceinline__ void atomic_min(double* address, double value) {
atomic_binary_func(address, value, MinFunc());
}
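
// ---------------------------------------------------------------------------
// Editor's sketch (not part of this PR): a hypothetical check of the int8_t
// wrappers above. Names are illustrative only.
//
//   __global__ void add_all(int8_t* acc, const int8_t* vals, int n) {
//     int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n) atomic_add(acc, vals[i]);  // 8-bit add via 32-bit atomicCAS
//   }
//   // With *acc = 0 and n values of 1, *acc ends up as (int8_t)n
//   // (i.e., n mod 256, since int8_t wraps).
// ---------------------------------------------------------------------------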


} // namespace cuda
} // namespace onnxruntime