diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 9d9b266355335..2ea557b7d61fe 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -744,7 +744,9 @@ Do not modify directly.*
|||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||8|**I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Scatter|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
-|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
+|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc
index 8844b7e7a26c4..c7a2005924836 100644
--- a/onnxruntime/core/providers/cpu/tensor/scatter.cc
+++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc
@@ -198,13 +198,6 @@ struct Func_Min {
}
};
-template <>
-struct Func_Min<MLFloat16> {
- void operator()(MLFloat16*, const MLFloat16*) const {
- ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'.");
- }
-};
-
template <>
struct Func_Min<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
@@ -233,13 +226,6 @@ struct Func_Max {
}
};
-template <>
-struct Func_Max<MLFloat16> {
- void operator()(MLFloat16*, const MLFloat16*) const {
- ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'.");
- }
-};
-
template <>
struct Func_Max<BFloat16> {
void operator()(BFloat16*, const BFloat16*) const {
diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh
index 14fa2d0706f73..170aa3a2d8d0c 100644
--- a/onnxruntime/core/providers/cuda/atomic/common.cuh
+++ b/onnxruntime/core/providers/cuda/atomic/common.cuh
@@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index,
#endif
}
+// Disable default template instantiation.
+// For every type T, we need to define a specialization
+// to select the right type for calling atomicCAS.
+template <typename T>
+class AtomicCasType;
+
+template<>
+class AtomicCasType<int8_t> {
+ public:
+ using type = unsigned short int;
+ static const unsigned int mask = 0xffu;
+};
+
+template<>
+class AtomicCasType<half> {
+ public:
+ using type = unsigned short int;
+ static const unsigned int mask = 0xffffu;
+};
+
+template<>
+class AtomicCasType<float> {
+ public:
+ using type = unsigned int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+template<>
+class AtomicCasType<double> {
+ public:
+ using type = unsigned long long int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+template<>
+class AtomicCasType<int32_t> {
+ public:
+ using type = int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+template<>
+class AtomicCasType<int64_t> {
+ public:
+ using type = unsigned long long int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
+//
+// This function computes an 8-bit atomic binary operation using 32-bit atomicCAS.
+// It accumulates `val` into the `address` using the `func`.
+// The accumulation is atomic (i.e., thread-safe).
+//
+// E.g., Assume ValueType is
+// int8_t
+// and BinaryFunc is
+// struct AddFunc {
+// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
+// return a + b;
+// }
+// This function becomes atomic_add for int8_t.
+template <typename ValueType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
+ // Assert to ensure the following bit-wise manipulation is correct.
+ static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4,
+ "ValueType must be 1-byte, 2-byte or 4-byte large.");
+ // Number of bytes to the lower 4-byte aligned address.
+ // If the current address is b1010"10", then offset = b10 = 2,
+ // which means the current address is 2 bytes away from
+ // the lower 4-byte aligned address b1010"00".
+ size_t offset = (size_t)address & 3;
+  // Find a new 4-byte aligned address `address_as_ui` lower than
+ // or equal to `address`. Lower than `address` so that the actual
+ // int8_t byte is in the 4-byte word that we load.
+ //
+ // This address has the following properties:
+ // 1. It is 4-byte aligned.
+ // 2. It is lower than or equal to `address`.
+ // 3. De-referencing this address may return
+ // a uint32_t value that contains the same int8_t
+ // value indicated by `address`.
+ //
+ // E.g.,
+ // address = b101010
+ // offset = b101010 & b000011 = b10 = 2
+ // (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
+ // which is (32-bit aligned).
+ uint32_t * address_as_ui = (uint32_t*)((char*)address - offset);
+ uint32_t old = *address_as_ui;
+ // E.g., offset = 2.
+ // address_as_ui is an address 2 bytes lower than `address`.
+ //
+ // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
+ // ^ ^ ^
+ // | | |
+ // | address <--- offset * 8 (bit)-----> address_as_ui
+ // | ^
+ // | |
+ // ------------------------- *address_as_ui -----------------------
+ //
+ // This visualization shows
+ // 1. the 32-bit word at address_as_ui.
+ // 2. the gap between address_as_ui and address.
+ // 3. *address_as_ui contains the int8_t value at `address`.
+ uint32_t shift = offset * 8;
+ uint32_t old_byte;
+ uint32_t newval;
+ uint32_t assumed;
+ do {
+ assumed = old;
+ // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so
+ // we want to select the 3rd byte (byte 2 below) from the word.
+ //
+ // Journey of a 32-bit value:
+ //
+ // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
+ //
+ // |
+ // | old >> offset * 8, where offset = 2.
+ // | Effectively, push lower two bytes
+ // | out of the word.
+ // V
+ //
+ // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 .....
+ //
+ // | apply bit-wise AND,
+ // | & 0xff (i.e., & b11111111),
+ // | so that we only keep
+ // | the byte of interest.
+ // | Otherwise, overflow may
+ // | happen when casting this
+ // | 32-bit value to int8_t.
+ // V
+ //
+ // 00000000 | 00000000 | 00000000 | ..... byte 2 .....
+    old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
+ // Compute new int8_t value and store it to newrawvalue.
+ // Journey of a 32-bit value (cont'd):
+ //
+ // newrawvalue
+ // ... new byte 2 ...
+    auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
+ // Put the new int8_t value back to 32-bit word.
+ // Also ensure that bits not occupied by the int8_t value are 0s.
+ //
+ // Journey of a 32-bit value (cont'd):
+ //
+    // reinterpret_cast<uint32_t&>(newrawvalue)
+    //   random values | random values | random values | ... new byte 2 ...
+    //
+    // reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
+    //   00000000 | 00000000 | 00000000 | ... new byte 2 ...
+    newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
+ // Journey of a 32-bit value (cont'd):
+ //
+ // old
+ // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
+ //
+ // 0x000000ff
+ // 00000000 | 00000000 | 00000000 | 11111111
+ //
+ // 0x000000ff << shift
+ // 00000000 | 11111111 | 00000000 | 00000000
+ //
+ // ~(0x000000ff << shift)
+ // 11111111 | 00000000 | 11111111 | 11111111
+ //
+ // old & ~(0x000000ff << shift)
+ // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
+ //
+ // newval << shift
+ // 00000000 | ... new byte 2 ... | 00000000 | 00000000
+ //
+ // (old & ~(0x000000ff << shift)) | (newval << shift)
+ // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
+    newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
+ old = atomicCAS(address_as_ui, assumed, newval);
+ } while (assumed != old);
+}
+
+// It accumulates `val` into the `address` using the `func`.
+// This function is thread-safe (i.e., atomic).
+template <typename ValueType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
+  ValueType observed = *address, assumed, new_value;
+  using CasType = typename AtomicCasType<ValueType>::type;
+  static_assert(sizeof(ValueType) == sizeof(CasType),
+                "ValueType and CasType must have the same size for calling atomicCAS.");
+  auto address_as_cas_type = reinterpret_cast<CasType*>(address);
+ do {
+ // Record the value used to compute new value.
+ assumed = observed;
+
+ // Compute expected new value.
+ new_value = func(observed, val);
+
+    // Reinterpret the ValueType bits as the integer type accepted by atomicCAS:
+    //   sizeof(ValueType) == 2 -> unsigned short int
+    //   sizeof(ValueType) == 4 -> unsigned int (or int)
+    //   sizeof(ValueType) == 8 -> unsigned long long int
+    auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
+    auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
+
+    // Call atomicCAS as if the variables were of the matching unsigned integer
+    // type (unsigned short int, unsigned int (or int), or unsigned long long int).
+ auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
+
+    // Cast the freshly observed value in memory back to ValueType.
+    observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
+
+    // Two cases:
+    // 1. compare-and-swap succeeds
+    //    a. `address` now holds `new_value`.
+    //    b. atomicCAS returns the value it observed before the swap, which
+    //       equals `assumed`, so `observed != assumed` below is false and
+    //       the loop terminates.
+    // 2. compare-and-swap fails
+    //    a. `address` held a value different from `assumed`, so `new_value`
+    //       is stale and was not written.
+    //    b. `observed` becomes the fresh value read from `address`, so
+    //       `observed != assumed` below is true and the loop continues.
+    //       In the next iteration, `new_value` is computed again using the
+    //       fresh `observed`.
+  } while (observed != assumed);
+}
+
+struct AddFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a + b;
+ }
+};
+
+struct MulFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a * b;
+ }
+};
+
+struct MaxFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return b > a ? b : a;
+ }
+};
+
+struct MinFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return b < a ? b : a;
+ }
+};
+
+__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, AddFunc());
+}
+__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(half* address, half value) {
+#if __CUDA_ARCH__ >= 700
+ atomic_binary_func(address, value, MulFunc());
+#else
+ atomic_byte_func_with_unit32_cas(address, value, MulFunc());
+#endif
+}
+__device__ __forceinline__ void atomic_max(half* address, half value) {
+#if __CUDA_ARCH__ >= 700
+ atomic_binary_func(address, value, MaxFunc());
+#else
+ atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
+#endif
+}
+__device__ __forceinline__ void atomic_min(half* address, half value) {
+#if __CUDA_ARCH__ >= 700
+ atomic_binary_func(address, value, MinFunc());
+#else
+ atomic_byte_func_with_unit32_cas(address, value, MinFunc());
+#endif
+}
+
+__device__ __forceinline__ void atomic_mul(float* address, float value) {
+ atomic_binary_func(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(float* address, float value) {
+ atomic_binary_func(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(float* address, float value) {
+ atomic_binary_func(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(double* address, double value) {
+ atomic_binary_func(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(double* address, double value) {
+ atomic_binary_func(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(double* address, double value) {
+ atomic_binary_func(address, value, MinFunc());
+}
+
+
} // namespace cuda
} // namespace onnxruntime
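Note on the byte-level CAS emulation above: the core of atomic_byte_func_with_unit32_cas is the extract/apply/merge step on an aligned 32-bit word. The following host-side C++ sketch (not part of the patch; names chosen for illustration) shows that same step for one int8_t lane with AddFunc, assuming a little-endian host and leaving out the atomicCAS retry loop and the __device__ qualifiers:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // Four int8_t lanes packed into one 4-byte aligned word (bytes 0..3).
  int8_t lanes[4] = {10, -5, 7, 100};
  uint32_t word;
  std::memcpy(&word, lanes, sizeof(word));

  const size_t offset = 2;      // update byte 2 (value 7), as in the comments above
  const uint32_t shift = offset * 8;
  const uint32_t mask = 0xffu;  // AtomicCasType<int8_t>::mask

  // Extract the old byte, apply the binary op (AddFunc with val = 3), merge it back.
  uint32_t old_byte = (word >> shift) & mask;
  int8_t updated = static_cast<int8_t>(old_byte) + 3;
  uint32_t newval = static_cast<uint8_t>(updated) & mask;
  word = (word & ~(mask << shift)) | (newval << shift);

  std::memcpy(lanes, &word, sizeof(word));
  assert(lanes[2] == 10);                                       // 7 + 3
  assert(lanes[0] == 10 && lanes[1] == -5 && lanes[3] == 100);  // other lanes untouched
  return 0;
}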
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 3fc4ed355a12b..77e682e05a2a4 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1046,7 +1046,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax);
@@ -1254,6 +1254,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1269,6 +1270,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
@@ -1937,7 +1939,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2138,6 +2140,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
// Opset 17
BuildKernelCreateInfo,
@@ -2159,6 +2162,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu
index 10c8625b39ef8..b710e8a1b48c2 100644
--- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu
@@ -95,7 +95,37 @@ struct OffsetCalculatorFor2D {
template <typename T>
struct FuncAssignment {
- __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; }
+ __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+ start_addr[index] = value;
+ }
+};
+
+template <typename T>
+struct FuncAdd {
+ __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+ atomic_add(start_addr + index, value);
+ }
+};
+
+template <typename T>
+struct FuncMul {
+ __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+ atomic_mul(start_addr + index, value);
+ }
+};
+
+template <typename T>
+struct FuncMax {
+ __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+ atomic_max(start_addr + index, value);
+ }
+};
+
+template <typename T>
+struct FuncMin {
+ __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const {
+ atomic_min(start_addr + index, value);
+ }
};
template
@@ -238,8 +268,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con
template <typename T, typename TIndex>
Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data,
T* output_data, const GatherScatterElementsArgs& args) {
-  return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
-                                     FuncAssignment<T>());
+  if (args.operation == GatherScatterElementsArgs::Operation::NONE) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncAssignment<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncAdd<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMul<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMax<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMin<T>());
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator.");
+  }
}
#define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
index 631d0bf049c6f..7b1c88f1fc1cb 100644
--- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
@@ -10,6 +10,14 @@ namespace onnxruntime {
namespace cuda {
struct GatherScatterElementsArgs {
+ enum class Operation {
+ NONE,
+ ADD,
+ MUL,
+ MAX,
+ MIN
+ };
+
int64_t rank;
int64_t axis;
int64_t input_size;
@@ -19,6 +27,9 @@ struct GatherScatterElementsArgs {
TArray indices_fdms;
TArray indices_strides;
int64_t indices_size;
+  // Operation used to combine values associated with the same
+  // memory location in the output tensor.
+ Operation operation;
};
template
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
index e4d145154971e..42a9f50001103 100755
--- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
@@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe
DataTypeImpl::GetTensorType<int64_t>()}),
ScatterElements);
-ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider,
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+ .TypeConstraint("Tind",
+                    std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
+                                            DataTypeImpl::GetTensorType<int64_t>()}),
+ ScatterElements);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+ .TypeConstraint("Tind",
+                    std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
+                                            DataTypeImpl::GetTensorType<int64_t>()}),
+ ScatterElements);
+
+ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(),
@@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector();
CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args);
+ if (reduction_ == "none") {
+ args.operation = GatherScatterElementsArgs::Operation::NONE;
+ } else if (reduction_ == "add") {
+ args.operation = GatherScatterElementsArgs::Operation::ADD;
+ } else if (reduction_ == "mul") {
+ args.operation = GatherScatterElementsArgs::Operation::MUL;
+ } else if (reduction_ == "min") {
+ args.operation = GatherScatterElementsArgs::Operation::MIN;
+ } else if (reduction_ == "max") {
+ args.operation = GatherScatterElementsArgs::Operation::MAX;
+ } else {
+ ORT_THROW("Unsupported reduction type");
+ }
+
// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
int dtype = GetElementType(input_tensor->DataType()->Size());
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h
index 3e9e0ce041845..3884b716da308 100755
--- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h
@@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel {
ScatterElements(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(),
"Missing/Invalid 'axis' attribute value");
+    reduction_ = info.GetAttrOrDefault<std::string>("reduction", "none");
+
+ ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" ||
+ reduction_ == "mul" || reduction_ == "max" ||
+ reduction_ == "min",
+ "Invalid reduction attribute value of ", reduction_);
}
~ScatterElements() = default;
Status ComputeInternal(OpKernelContext* context) const override;
@@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel {
struct ComputeImpl;
int64_t axis_;
+  // The "reduction" attribute has been defined since opset 13, but
+  // we never implemented it. Let's try to support it starting
+  // with opset 18.
+ std::string reduction_;
};
} // namespace cuda
diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh
index 4e235702028c6..b5d01b91c70ed 100644
--- a/onnxruntime/core/providers/rocm/atomic/common.cuh
+++ b/onnxruntime/core/providers/rocm/atomic/common.cuh
@@ -59,5 +59,304 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz
atomic_add(start_addr + index, value);
}
+// Disable default template instantiation.
+// For every type T, we need to define a specialization
+// to select the right type for calling atomicCAS.
+template <typename T>
+class AtomicCasType;
+
+template<>
+class AtomicCasType<int8_t> {
+ public:
+ using type = unsigned short int;
+ static const unsigned int mask = 0xffu;
+};
+
+template<>
+class AtomicCasType<half> {
+ public:
+ using type = unsigned short int;
+ static const unsigned int mask = 0xffffu;
+};
+
+template<>
+class AtomicCasType<float> {
+ public:
+ using type = unsigned int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+template<>
+class AtomicCasType<double> {
+ public:
+ using type = unsigned long long int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+template<>
+class AtomicCasType<int32_t> {
+ public:
+ using type = int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+template<>
+class AtomicCasType<int64_t> {
+ public:
+ using type = unsigned long long int;
+ static const unsigned int mask = 0xffffffffu;
+};
+
+// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh.
+//
+// This function computes an 8-bit atomic binary operation using 32-bit atomicCAS.
+// It accumulates `val` into the `address` using the `func`.
+// The accumulation is atomic (i.e., thread-safe).
+//
+// E.g., Assume ValueType is
+// int8_t
+// and BinaryFunc is
+// struct AddFunc {
+// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const {
+// return a + b;
+// }
+// This function becomes atomic_add for int8_t.
+template <typename ValueType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) {
+ // Assert to ensure the following bit-wise manipulation is correct.
+ static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4,
+ "ValueType must be 1-byte, 2-byte or 4-byte large.");
+ // Number of bytes to the lower 4-byte aligned address.
+ // If the current address is b1010"10", then offset = b10 = 2,
+ // which means the current address is 2 bytes away from
+ // the lower 4-byte aligned address b1010"00".
+ size_t offset = (size_t)address & 3;
+  // Find a new 4-byte aligned address `address_as_ui` lower than
+ // or equal to `address`. Lower than `address` so that the actual
+ // int8_t byte is in the 4-byte word that we load.
+ //
+ // This address has the following properties:
+ // 1. It is 4-byte aligned.
+ // 2. It is lower than or equal to `address`.
+ // 3. De-referencing this address may return
+ // a uint32_t value that contains the same int8_t
+ // value indicated by `address`.
+ //
+ // E.g.,
+ // address = b101010
+ // offset = b101010 & b000011 = b10 = 2
+ // (char*)address - offset => (char*)b101010 - b000010 => b1010"00",
+ // which is (32-bit aligned).
+ uint32_t * address_as_ui = (uint32_t*)((char*)address - offset);
+ uint32_t old = *address_as_ui;
+ // E.g., offset = 2.
+ // address_as_ui is an address 2 bytes lower than `address`.
+ //
+ // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
+ // ^ ^ ^
+ // | | |
+ // | address <--- offset * 8 (bit)-----> address_as_ui
+ // | ^
+ // | |
+ // ------------------------- *address_as_ui -----------------------
+ //
+ // This visualization shows
+ // 1. the 32-bit word at address_as_ui.
+ // 2. the gap between address_as_ui and address.
+ // 3. *address_as_ui contains the int8_t value at `address`.
+ uint32_t shift = offset * 8;
+ uint32_t old_byte;
+ uint32_t newval;
+ uint32_t assumed;
+ do {
+ assumed = old;
+ // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so
+ // we want to select the 3rd byte (byte 2 below) from the word.
+ //
+ // Journey of a 32-bit value:
+ //
+ // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
+ //
+ // |
+ // | old >> offset * 8, where offset = 2.
+ // | Effectively, push lower two bytes
+ // | out of the word.
+ // V
+ //
+ // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 .....
+ //
+ // | apply bit-wise AND,
+ // | & 0xff (i.e., & b11111111),
+ // | so that we only keep
+ // | the byte of interest.
+ // | Otherwise, overflow may
+ // | happen when casting this
+ // | 32-bit value to int8_t.
+ // V
+ //
+ // 00000000 | 00000000 | 00000000 | ..... byte 2 .....
+    old_byte = (old >> shift) & AtomicCasType<ValueType>::mask;
+ // Compute new int8_t value and store it to newrawvalue.
+ // Journey of a 32-bit value (cont'd):
+ //
+ // newrawvalue
+ // ... new byte 2 ...
+    auto newrawvalue = func(val, reinterpret_cast<ValueType&>(old_byte));
+ // Put the new int8_t value back to 32-bit word.
+ // Also ensure that bits not occupied by the int8_t value are 0s.
+ //
+ // Journey of a 32-bit value (cont'd):
+ //
+    // reinterpret_cast<uint32_t&>(newrawvalue)
+    //   random values | random values | random values | ... new byte 2 ...
+    //
+    // reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask
+    //   00000000 | 00000000 | 00000000 | ... new byte 2 ...
+    newval = reinterpret_cast<uint32_t&>(newrawvalue) & AtomicCasType<ValueType>::mask;
+ // Journey of a 32-bit value (cont'd):
+ //
+ // old
+ // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 .....
+ //
+ // 0x000000ff
+ // 00000000 | 00000000 | 00000000 | 11111111
+ //
+ // 0x000000ff << shift
+ // 00000000 | 11111111 | 00000000 | 00000000
+ //
+ // ~(0x000000ff << shift)
+ // 11111111 | 00000000 | 11111111 | 11111111
+ //
+ // old & ~(0x000000ff << shift)
+ // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
+ //
+ // newval << shift
+ // 00000000 | ... new byte 2 ... | 00000000 | 00000000
+ //
+ // (old & ~(0x000000ff << shift)) | (newval << shift)
+ // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
+    newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
+ old = atomicCAS(address_as_ui, assumed, newval);
+ } while (assumed != old);
+}
+
+// It accumulates `val` into the `address` using the `func`.
+// This function is thread-safe (i.e., atomic).
+template <typename ValueType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
+  ValueType observed = *address, assumed, new_value;
+  using CasType = typename AtomicCasType<ValueType>::type;
+  static_assert(sizeof(ValueType) == sizeof(CasType),
+                "ValueType and CasType must have the same size for calling atomicCAS.");
+  auto address_as_cas_type = reinterpret_cast<CasType*>(address);
+ do {
+ // Record the value used to compute new value.
+ assumed = observed;
+
+ // Compute expected new value.
+ new_value = func(observed, val);
+
+    // Reinterpret the ValueType bits as the integer type accepted by atomicCAS:
+    //   sizeof(ValueType) == 2 -> unsigned short int
+    //   sizeof(ValueType) == 4 -> unsigned int (or int)
+    //   sizeof(ValueType) == 8 -> unsigned long long int
+    auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
+    auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
+
+    // Call atomicCAS as if the variables were of the matching unsigned integer
+    // type (unsigned short int, unsigned int (or int), or unsigned long long int).
+ auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
+
+    // Cast the freshly observed value in memory back to ValueType.
+    observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
+
+    // Two cases:
+    // 1. compare-and-swap succeeds
+    //    a. `address` now holds `new_value`.
+    //    b. atomicCAS returns the value it observed before the swap, which
+    //       equals `assumed`, so `observed != assumed` below is false and
+    //       the loop terminates.
+    // 2. compare-and-swap fails
+    //    a. `address` held a value different from `assumed`, so `new_value`
+    //       is stale and was not written.
+    //    b. `observed` becomes the fresh value read from `address`, so
+    //       `observed != assumed` below is true and the loop continues.
+    //       In the next iteration, `new_value` is computed again using the
+    //       fresh `observed`.
+  } while (observed != assumed);
+}
+
+struct AddFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a + b;
+ }
+};
+
+struct MulFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return a * b;
+ }
+};
+
+struct MaxFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return b > a ? b : a;
+ }
+};
+
+struct MinFunc {
+  template <typename T>
+ __device__ __forceinline__ T operator()(T a, T b) const {
+ return b < a ? b : a;
+ }
+};
+
+__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, AddFunc());
+}
+__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
+ atomic_byte_func_with_unit32_cas(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(half* address, half value) {
+ atomic_byte_func_with_unit32_cas(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(half* address, half value) {
+ atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(half* address, half value) {
+ atomic_byte_func_with_unit32_cas(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(float* address, float value) {
+ atomic_binary_func(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(float* address, float value) {
+ atomic_binary_func(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(float* address, float value) {
+ atomic_binary_func(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(double* address, double value) {
+ atomic_binary_func(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(double* address, double value) {
+ atomic_binary_func(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(double* address, double value) {
+ atomic_binary_func(address, value, MinFunc());
+}
+
+
} // namespace rocm
} // namespace onnxruntime
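Note on atomic_binary_func (present in both the CUDA and ROCm headers above): it is the standard CAS retry loop, i.e. read, compute, try to swap, and recompute if another thread won the race. A minimal host-side sketch (not part of the patch; names chosen for illustration), with std::atomic<uint32_t>::compare_exchange_weak standing in for atomicCAS and float with MulFunc as the example, looks like this:

#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstring>

// Atomically multiply `val` into the float whose bit pattern is stored in `cell`.
float atomic_mul_host(std::atomic<uint32_t>& cell, float val) {
  uint32_t observed_bits = cell.load();
  uint32_t new_bits;
  float observed, new_value;
  do {
    std::memcpy(&observed, &observed_bits, sizeof(float));  // CasType -> ValueType
    new_value = observed * val;                             // MulFunc
    std::memcpy(&new_bits, &new_value, sizeof(float));      // ValueType -> CasType
    // On failure, compare_exchange_weak refreshes observed_bits with the value
    // currently stored in the cell, just like the return value of atomicCAS.
  } while (!cell.compare_exchange_weak(observed_bits, new_bits));
  return new_value;
}

int main() {
  float init = 3.0f;
  uint32_t bits;
  std::memcpy(&bits, &init, sizeof(float));
  std::atomic<uint32_t> cell{bits};

  atomic_mul_host(cell, 2.0f);

  bits = cell.load();
  float result;
  std::memcpy(&result, &bits, sizeof(float));
  assert(result == 6.0f);
  return 0;
}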
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index fff3d14b763d5..ee3578326ac6d 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1069,7 +1069,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax);
@@ -1290,6 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1302,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
-
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split);
// Opset 19
@@ -2004,7 +2005,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2225,6 +2226,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
// Opset 17
BuildKernelCreateInfo,
@@ -2237,7 +2239,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
-
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo,
// Opset 19
diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc
index 48f58add8237b..a5a165e150cf1 100644
--- a/onnxruntime/core/util/thread_utils.cc
+++ b/onnxruntime/core/util/thread_utils.cc
@@ -7,6 +7,7 @@
#ifdef _WIN32
#include <Windows.h>
+#include <VersionHelpers.h>
#endif
#include
#include "core/session/ort_apis.h"
@@ -98,7 +99,16 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) {
}
options.thread_pool_size = static_cast(default_affinities.size());
if (options.auto_set_affinity) {
+#ifdef _WIN32
+    // Only set thread affinity on Windows Server when auto affinity is requested.
+    // On client SKUs it is best to let the OS scheduler handle thread placement.
+    // On hybrid big (P-Core) / little (E-Core) CPU designs, pinning affinity overrides
+    // the OS QoS heuristics and can noticeably increase power usage.
+ if (IsWindowsServer()) {
+ to.affinities = std::move(default_affinities);
+ }
+#else
to.affinities = std::move(default_affinities);
+#endif
}
}
if (options.thread_pool_size <= 1) {
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index f470e9f6b6ed1..0bbcee12ea5cf 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -659,7 +659,12 @@ static bool CheckIfInputIsSequenceType(const std::string& name_input,
if (!temp) {
throw std::runtime_error("Corresponding type_proto is null");
} else {
- type_proto = *temp;
+ if (temp->has_optional_type()) {
+ const ::onnx::TypeProto_Optional& optional_type_proto = temp->optional_type();
+ type_proto = optional_type_proto.elem_type();
+ } else {
+ type_proto = *temp;
+ }
}
return type_proto.has_sequence_type();
diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
index 9b44bf400c05e..30e27bb15fa57 100644
--- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -302,5 +302,137 @@ TEST(Scatter, BoolInputWithAxis) {
scatter_bool_with_axis_tests("ScatterElements", 11);
}
+TEST(ScatterElements, AddReduction) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "add");
+
+  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
+  test.AddInput<int64_t>("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
+  test.AddInput<float>("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f});
+  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, AddReductionAxis1) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 1);
+  test.AddAttribute("reduction", "add");
+
+  // update's slice shape is {2, 1}
+  test.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
+  test.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
+  test.AddInput<float>("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f});
+  test.AddOutput<float>("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MulReduction) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "mul");
+
+  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<float>("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f});
+  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MulReductionAxis1) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 1);
+  test.AddAttribute("reduction", "mul");
+
+  // update's slice shape is {2, 1}
+  test.AddInput<float>("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f});
+  test.AddInput<int64_t>("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1});
+  test.AddInput<float>("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
+  test.AddOutput<float>("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MaxReduction_MLFloat16) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "max");
+
+  test.AddInput<MLFloat16>("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}));
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<MLFloat16>("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f}));
+  test.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}));
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MaxReduction_Float) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "max");
+
+  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
+  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MaxReduction_Double) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "max");
+
+  test.AddInput<double>("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<double>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
+  test.AddOutput<double>("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MinReduction_MLFloat16) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "min");
+
+  test.AddInput<MLFloat16>("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}));
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<MLFloat16>("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f}));
+  test.AddOutput<MLFloat16>("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}));
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MinReduction_Float) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "min");
+
+  test.AddInput<float>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<float>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
+  test.AddOutput<float>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterElements, MinReduction_Double) {
+ OpTester test("ScatterElements", 18);
+  test.AddAttribute<int64_t>("axis", 0);
+  test.AddAttribute("reduction", "min");
+
+  test.AddInput<double>("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f});
+  test.AddInput<int64_t>("indices", {2, 3}, {1, 1, 1, 1, 1, 1});
+  test.AddInput<double>("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f});
+  test.AddOutput<double>("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
} // namespace test
} // namespace onnxruntime
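For hand-checking the expected outputs in the tests above, here is a single-threaded 2-D reference sketch of ScatterElements with the new reduction modes (not part of the patch; the helper name and signature are illustrative only):

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

std::vector<float> ScatterElements2D(std::vector<float> data, int64_t cols,
                                     const std::vector<int64_t>& indices,
                                     const std::vector<float>& updates,
                                     int64_t idx_rows, int64_t idx_cols,
                                     int64_t axis, const std::string& reduction) {
  for (int64_t i = 0; i < idx_rows; ++i) {
    for (int64_t j = 0; j < idx_cols; ++j) {
      int64_t idx = indices[i * idx_cols + j];
      int64_t r = (axis == 0) ? idx : i;
      int64_t c = (axis == 0) ? j : idx;
      float& out = data[r * cols + c];
      float upd = updates[i * idx_cols + j];
      if (reduction == "none") out = upd;  // last writer wins (order unspecified)
      else if (reduction == "add") out += upd;
      else if (reduction == "mul") out *= upd;
      else if (reduction == "max") out = upd > out ? upd : out;
      else if (reduction == "min") out = upd < out ? upd : out;
    }
  }
  return data;
}

int main() {
  // Mirrors ScatterElements.AddReductionAxis1: data {2, 3}, indices/updates {2, 4}, axis = 1.
  auto y = ScatterElements2D({9.f, 4.f, 1.f, 7.f, 3.f, 6.f}, /*cols=*/3,
                             {1, 1, 1, 1, 1, 1, 1, 1},
                             {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f},
                             /*idx_rows=*/2, /*idx_cols=*/4, /*axis=*/1, "add");
  assert(y[1] == 4.f + (2.f + 5.f + 3.f + 6.f));
  assert(y[4] == 3.f + (7.f + 9.f + 8.f + 10.f));
  return 0;
}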
diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py
index e210917e7ad9a..68e441c87860e 100644
--- a/onnxruntime/test/python/onnxruntime_test_python.py
+++ b/onnxruntime/test/python/onnxruntime_test_python.py
@@ -650,6 +650,14 @@ def do_test_get_and_set_tuning_results(ep):
if "ROCMExecutionProvider" in onnxrt.get_available_providers():
do_test_get_and_set_tuning_results("ROCMExecutionProvider")
+ def test_run_model_with_optional_sequence_input(self):
+ sess = onnxrt.InferenceSession(get_name("identity_opt.onnx"))
+ x = [np.array([1, 2, 3, 4, 5]).astype(np.float32)]
+ input_name = sess.get_inputs()[0].name
+ output_name = sess.get_outputs()[0].name
+ res = sess.run([output_name], {input_name: x})
+ np.testing.assert_allclose(res[0], x)
+
def test_run_model(self):
sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers)
x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
diff --git a/onnxruntime/test/testdata/identity_opt.onnx b/onnxruntime/test/testdata/identity_opt.onnx
new file mode 100644
index 0000000000000..24c05f7b7227f
Binary files /dev/null and b/onnxruntime/test/testdata/identity_opt.onnx differ