diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9d9b266355335..2ea557b7d61fe 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -744,7 +744,9 @@ Do not modify directly.* |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 8844b7e7a26c4..c7a2005924836 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -198,13 +198,6 @@ struct Func_Min { } }; -template <> -struct Func_Min { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); - } -}; - template <> struct Func_Min { void operator()(BFloat16*, const BFloat16*) const { @@ -233,13 +226,6 @@ struct Func_Max { } }; -template <> -struct Func_Max { - void operator()(MLFloat16*, const MLFloat16*) const { - ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); - } -}; - template <> struct Func_Max { void operator()(BFloat16*, const BFloat16*) const { diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh index 14fa2d0706f73..170aa3a2d8d0c 100644 --- a/onnxruntime/core/providers/cuda/atomic/common.cuh +++ b/onnxruntime/core/providers/cuda/atomic/common.cuh @@ -122,5 +122,316 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, #endif } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). +// +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + // Assert to ensure the following bit-wise manipulation is correct. 
+ static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... + // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & AtomicCasType::mask; + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... 
+    //
+    // 0x000000ff
+    //   00000000 | 00000000 | 00000000 | 11111111
+    //
+    // 0x000000ff << shift
+    //   00000000 | 11111111 | 00000000 | 00000000
+    //
+    // ~(0x000000ff << shift)
+    //   11111111 | 00000000 | 11111111 | 11111111
+    //
+    // old & ~(0x000000ff << shift)
+    //   ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 .....
+    //
+    // newval << shift
+    //   00000000 | ... new byte 2 ... | 00000000 | 00000000
+    //
+    // (old & ~(0x000000ff << shift)) | (newval << shift)
+    //   ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 .....
+    newval = (old & ~(AtomicCasType<ValueType>::mask << shift)) | (newval << shift);
+    old = atomicCAS(address_as_ui, assumed, newval);
+  } while (assumed != old);
+}
+
+// It accumulates `val` into the `address` using the `func`.
+// This function is thread-safe (i.e., atomic).
+template <typename ValueType, typename BinaryFunc>
+__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) {
+  ValueType observed = *address, assumed, new_value;
+  using CasType = typename AtomicCasType<ValueType>::type;
+  static_assert(sizeof(ValueType) == sizeof(CasType),
+                "ValueType and CasType must have the same size for calling atomicCAS.");
+  auto address_as_cas_type = reinterpret_cast<CasType*>(address);
+  do {
+    // Record the value used to compute the new value.
+    assumed = observed;
+
+    // Compute expected new value.
+    new_value = func(observed, val);
+
+    // Reinterpret the values as the integer type supported by atomicCAS:
+    //   sizeof(ValueType) == 2  ->  unsigned short int
+    //   sizeof(ValueType) == 4  ->  unsigned int (or int)
+    //   sizeof(ValueType) == 8  ->  unsigned long long int
+    auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
+    auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
+
+    // Call atomicCAS on the values reinterpreted as CasType.
+    auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
+
+    // Cast the freshly observed value in memory back to ValueType.
+    observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
+
+    // Two cases:
+    // 1. compare-and-swap succeeds:
+    //    a. `address` now holds `new_value`.
+    //    b. `observed` equals `assumed` (the value the swap was based on),
+    //       so the loop condition (observed != assumed) is false and the
+    //       loop terminates.
+    // 2. compare-and-swap fails:
+    //    a. `address` held a value different from `assumed`, so the computed
+    //       `new_value` is stale.
+    //    b. `observed` becomes the fresh value read from `address`, the loop
+    //       condition is true, and the next iteration recomputes `new_value`
+    //       from the fresh `observed`.
+  } while (observed != assumed);
+}
+
+struct AddFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+struct MulFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a * b;
+  }
+};
+
+struct MaxFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return b > a ? b : a;
+  }
+};
+
+struct MinFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return b < a ?
b : a; + } +}; + +__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, AddFunc()); +} +__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MulFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +#endif +} +__device__ __forceinline__ void atomic_max(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MaxFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +#endif +} +__device__ __forceinline__ void atomic_min(half* address, half value) { +#if __CUDA_ARCH__ >= 700 + atomic_binary_func(address, value, MinFunc()); +#else + atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +#endif +} + +__device__ __forceinline__ void atomic_mul(float* address, float value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(float* address, float value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(float* address, float value) { + atomic_binary_func(address, value, MinFunc()); +} + +__device__ __forceinline__ void atomic_mul(double* address, double value) { + atomic_binary_func(address, value, MulFunc()); +} +__device__ __forceinline__ void atomic_max(double* address, double value) { + atomic_binary_func(address, value, MaxFunc()); +} +__device__ __forceinline__ void atomic_min(double* address, double value) { + atomic_binary_func(address, value, MinFunc()); +} + + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3fc4ed355a12b..77e682e05a2a4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1046,7 +1046,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1254,6 +1254,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1269,6 +1270,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); @@ -1937,7 +1939,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2138,6 +2140,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2159,6 +2162,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index 10c8625b39ef8..b710e8a1b48c2 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -95,7 +95,37 @@ struct OffsetCalculatorFor2D { template struct FuncAssignment { - __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + start_addr[index] = value; + } +}; + +template +struct FuncAdd { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_add(start_addr + index, value); + } +}; + +template +struct FuncMul { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_mul(start_addr + index, value); + } +}; + +template +struct FuncMax { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_max(start_addr + index, value); + } +}; + +template +struct FuncMin { + __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { + atomic_min(start_addr + index, value); + } }; template @@ -238,8 +268,24 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con template Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* 
updates_data, T* output_data, const GatherScatterElementsArgs& args) {
-  return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
-                                     FuncAssignment<T>());
+  if (args.operation == GatherScatterElementsArgs::Operation::NONE) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncAssignment<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncAdd<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMul<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMax<T>());
+  } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) {
+    return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args,
+                                       FuncMin<T>());
+  } else {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator.");
+  }
 }

 #define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \
diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
index 631d0bf049c6f..7b1c88f1fc1cb 100644
--- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h
@@ -10,6 +10,14 @@ namespace onnxruntime {
 namespace cuda {

 struct GatherScatterElementsArgs {
+  enum class Operation {
+    NONE,
+    ADD,
+    MUL,
+    MAX,
+    MIN
+  };
+
   int64_t rank;
   int64_t axis;
   int64_t input_size;
@@ -19,6 +27,9 @@
   TArray<fast_divmod> indices_fdms;
   TArray<int64_t> indices_strides;
   int64_t indices_size;
+  // Operation used to combine values associated with the same
+  // memory location in the output tensor.
+ Operation operation; }; template diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index e4d145154971e..42a9f50001103 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -27,7 +27,23 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe DataTypeImpl::GetTensorType()}), ScatterElements); -ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", + std::vector{DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + ScatterElements); + +ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), @@ -106,6 +122,20 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector(); CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args); + if (reduction_ == "none") { + args.operation = GatherScatterElementsArgs::Operation::NONE; + } else if (reduction_ == "add") { + args.operation = GatherScatterElementsArgs::Operation::ADD; + } else if (reduction_ == "mul") { + args.operation = GatherScatterElementsArgs::Operation::MUL; + } else if (reduction_ == "min") { + args.operation = GatherScatterElementsArgs::Operation::MIN; + } else if (reduction_ == "max") { + args.operation = GatherScatterElementsArgs::Operation::MAX; + } else { + ORT_THROW("Unsupported reduction type"); + } + // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. int dtype = GetElementType(input_tensor->DataType()->Size()); if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) { diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h index 3e9e0ce041845..3884b716da308 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h @@ -14,6 +14,12 @@ class ScatterElements final : public CudaKernel { ScatterElements(const OpKernelInfo& info) : CudaKernel(info) { ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); + reduction_ = info.GetAttrOrDefault("reduction", "none"); + + ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" || + reduction_ == "mul" || reduction_ == "max" || + reduction_ == "min", + "Invalid reduction attribute value of ", reduction_); } ~ScatterElements() = default; Status ComputeInternal(OpKernelContext* context) const override; @@ -23,6 +29,10 @@ class ScatterElements final : public CudaKernel { struct ComputeImpl; int64_t axis_; + // "reduction" attribute has been defined since opset 13 but + // we never implemented it. 
Let's try to support them starting + // with opset 18. + std::string reduction_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh index 4e235702028c6..b5d01b91c70ed 100644 --- a/onnxruntime/core/providers/rocm/atomic/common.cuh +++ b/onnxruntime/core/providers/rocm/atomic/common.cuh @@ -59,5 +59,304 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz atomic_add(start_addr + index, value); } +// Disable default template instantiation. +// For every type T, we need to define a specialization +// to select the right type for calling atomicCAS. +template +class AtomicCasType; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned short int; + static const unsigned int mask = 0xffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = int; + static const unsigned int mask = 0xffffffffu; +}; + +template<> +class AtomicCasType { + public: + using type = unsigned long long int; + static const unsigned int mask = 0xffffffffu; +}; + +// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +// +// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +// It accumulate `val` into the `address` using the `func`. +// The accumulation is atomic (i.e., thread-safe). +// +// E.g., Assume ValueType is +// int8_t +// and BinaryFunc is +// struct AddFunc { +// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +// return a + b; +// } +// This function becomes atomic_add for int8_t. +template +__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { + // Assert to ensure the following bit-wise manipulation is correct. + static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, + "ValueType must be 1-byte, 2-byte or 4-byte large."); + // Number of bytes to the lower 4-byte aligned address. + // If the current address is b1010"10", then offset = b10 = 2, + // which means the current address is 2 bytes away from + // the lower 4-byte aligned address b1010"00". + size_t offset = (size_t)address & 3; + // Find an new 4-byte aligned address `address_as_ui` lower than + // or equal to `address`. Lower than `address` so that the actual + // int8_t byte is in the 4-byte word that we load. + // + // This address has the following properties: + // 1. It is 4-byte aligned. + // 2. It is lower than or equal to `address`. + // 3. De-referencing this address may return + // a uint32_t value that contains the same int8_t + // value indicated by `address`. + // + // E.g., + // address = b101010 + // offset = b101010 & b000011 = b10 = 2 + // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", + // which is (32-bit aligned). + uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); + uint32_t old = *address_as_ui; + // E.g., offset = 2. + // address_as_ui is an address 2 bytes lower than `address`. + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... 
+ // ^ ^ ^ + // | | | + // | address <--- offset * 8 (bit)-----> address_as_ui + // | ^ + // | | + // ------------------------- *address_as_ui ----------------------- + // + // This visualization shows + // 1. the 32-bit word at address_as_ui. + // 2. the gap between address_as_ui and address. + // 3. *address_as_ui contains the int8_t value at `address`. + uint32_t shift = offset * 8; + uint32_t old_byte; + uint32_t newval; + uint32_t assumed; + do { + assumed = old; + // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so + // we want to select the 3rd byte (byte 2 below) from the word. + // + // Journey of a 32-bit value: + // + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // | + // | old >> offset * 8, where offset = 2. + // | Effectively, push lower two bytes + // | out of the word. + // V + // + // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... + // + // | apply bit-wise AND, + // | & 0xff (i.e., & b11111111), + // | so that we only keep + // | the byte of interest. + // | Otherwise, overflow may + // | happen when casting this + // | 32-bit value to int8_t. + // V + // + // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... + old_byte = (old >> shift) & AtomicCasType::mask; + // Compute new int8_t value and store it to newrawvalue. + // Journey of a 32-bit value (cont'd): + // + // newrawvalue + // ... new byte 2 ... + auto newrawvalue = func(val, reinterpret_cast(old_byte)); + // Put the new int8_t value back to 32-bit word. + // Also ensure that bits not occupied by the int8_t value are 0s. + // + // Journey of a 32-bit value (cont'd): + // + // reinterpret_cast(newrawvalue) + // random values | random values | random values | ... new byte 2 ... + // + // reinterpret_cast(newrawvalue) & AtomicCasType::mask + // 00000000 | 00000000 | 00000000 | ... new byte 2 ... + newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; + // Journey of a 32-bit value (cont'd): + // + // old + // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... + // + // 0x000000ff + // 00000000 | 00000000 | 00000000 | 11111111 + // + // 0x000000ff << shift + // 00000000 | 11111111 | 00000000 | 00000000 + // + // ~(0x000000ff << shift) + // 11111111 | 00000000 | 11111111 | 11111111 + // + // old & ~(0x000000ff << shift) + // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... + // + // newval << shift + // 00000000 | ... new byte 2 ... | 00000000 | 00000000 + // + // (old & ~(0x000000ff << shift)) | (newval << shift) + // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... + newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); + old = atomicCAS(address_as_ui, assumed, newval); + } while (assumed != old); +} + +// It accumulates `val` into the `address` using the `func`. +// This function is thread-safe (i.e., atomic). +template +__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { + ValueType observed = *address, assumed, new_value; + using CasType = typename AtomicCasType::type; + static_assert(sizeof(ValueType) == sizeof(CasType), + "ValueType and CasType must have the same size for calling atomicCAS."); + auto address_as_cas_type = reinterpret_cast(address); + do { + // Record the value used to compute new value. + assumed = observed; + + // Compute expected new value. 
+    new_value = func(observed, val);
+
+    // Reinterpret the values as the integer type supported by atomicCAS:
+    //   sizeof(ValueType) == 2  ->  unsigned short int
+    //   sizeof(ValueType) == 4  ->  unsigned int (or int)
+    //   sizeof(ValueType) == 8  ->  unsigned long long int
+    auto observed_as_cas_type = *reinterpret_cast<CasType*>(&observed);
+    auto new_value_as_cas_type = *reinterpret_cast<CasType*>(&new_value);
+
+    // Call atomicCAS on the values reinterpreted as CasType.
+    auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type);
+
+    // Cast the freshly observed value in memory back to ValueType.
+    observed = *reinterpret_cast<ValueType*>(&cas_observed_as_cas_type);
+
+    // Two cases:
+    // 1. compare-and-swap succeeds:
+    //    a. `address` now holds `new_value`.
+    //    b. `observed` equals `assumed` (the value the swap was based on),
+    //       so the loop condition (observed != assumed) is false and the
+    //       loop terminates.
+    // 2. compare-and-swap fails:
+    //    a. `address` held a value different from `assumed`, so the computed
+    //       `new_value` is stale.
+    //    b. `observed` becomes the fresh value read from `address`, the loop
+    //       condition is true, and the next iteration recomputes `new_value`
+    //       from the fresh `observed`.
+  } while (observed != assumed);
+}
+
+struct AddFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a + b;
+  }
+};
+
+struct MulFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return a * b;
+  }
+};
+
+struct MaxFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return b > a ? b : a;
+  }
+};
+
+struct MinFunc {
+  template <typename T>
+  __device__ __forceinline__ T operator()(T a, T b) const {
+    return b < a ? b : a;
+  }
+};
+
+__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) {
+  atomic_byte_func_with_unit32_cas(address, value, AddFunc());
+}
+__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) {
+  atomic_byte_func_with_unit32_cas(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) {
+  atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) {
+  atomic_byte_func_with_unit32_cas(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(half* address, half value) {
+  atomic_byte_func_with_unit32_cas(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(half* address, half value) {
+  atomic_byte_func_with_unit32_cas(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(half* address, half value) {
+  atomic_byte_func_with_unit32_cas(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(float* address, float value) {
+  atomic_binary_func(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(float* address, float value) {
+  atomic_binary_func(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(float* address, float value) {
+  atomic_binary_func(address, value, MinFunc());
+}
+
+__device__ __forceinline__ void atomic_mul(double* address, double value) {
+  atomic_binary_func(address, value, MulFunc());
+}
+__device__ __forceinline__ void atomic_max(double* address, double value) {
+  atomic_binary_func(address, value, MaxFunc());
+}
+__device__ __forceinline__ void atomic_min(double* address, double value) {
+  atomic_binary_func(address, value, MinFunc());
+}
+
 }  // namespace rocm
 }
// namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index fff3d14b763d5..ee3578326ac6d 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1069,7 +1069,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax); @@ -1290,6 +1290,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1302,7 +1303,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); - +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2004,7 +2005,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2225,6 +2226,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2237,7 +2239,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index 48f58add8237b..a5a165e150cf1 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -7,6 +7,7 @@ #ifdef _WIN32 #include +#include #endif #include #include "core/session/ort_apis.h" @@ -98,7 +99,16 @@ 
CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { } options.thread_pool_size = static_cast(default_affinities.size()); if (options.auto_set_affinity) { +#ifdef _WIN32 + // Only set thread affinity on Server with auto affinity. + // On client best to let OS scheduler handle. + // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage + if (IsWindowsServer()) { + to.affinities = std::move(default_affinities); + } +#else to.affinities = std::move(default_affinities); +#endif } } if (options.thread_pool_size <= 1) { diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index f470e9f6b6ed1..0bbcee12ea5cf 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -659,7 +659,12 @@ static bool CheckIfInputIsSequenceType(const std::string& name_input, if (!temp) { throw std::runtime_error("Corresponding type_proto is null"); } else { - type_proto = *temp; + if (temp->has_optional_type()) { + const ::onnx::TypeProto_Optional& optional_type_proto = temp->optional_type(); + type_proto = optional_type_proto.elem_type(); + } else { + type_proto = *temp; + } } return type_proto.has_sequence_type(); diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 9b44bf400c05e..30e27bb15fa57 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -302,5 +302,137 @@ TEST(Scatter, BoolInputWithAxis) { scatter_bool_with_axis_tests("ScatterElements", 11); } +TEST(ScatterElements, AddReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "add"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, AddReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "add"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MulReduction) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "mul"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, 
kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MulReductionAxis1) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 1); + test.AddAttribute("reduction", "mul"); + + // update's slice shape is {2, 1} + test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); + test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); + test.AddOutput("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f})); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); + test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f})); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_Float) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MaxReduction_Double) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "max"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f})); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); + test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f})); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_Float) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterElements, MinReduction_Double) { + OpTester test("ScatterElements", 18); + 
test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "min"); + + test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); + test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); + test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); + test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e210917e7ad9a..68e441c87860e 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -650,6 +650,14 @@ def do_test_get_and_set_tuning_results(ep): if "ROCMExecutionProvider" in onnxrt.get_available_providers(): do_test_get_and_set_tuning_results("ROCMExecutionProvider") + def test_run_model_with_optional_sequence_input(self): + sess = onnxrt.InferenceSession(get_name("identity_opt.onnx")) + x = [np.array([1, 2, 3, 4, 5]).astype(np.float32)] + input_name = sess.get_inputs()[0].name + output_name = sess.get_outputs()[0].name + res = sess.run([output_name], {input_name: x}) + np.testing.assert_allclose(res[0], x) + def test_run_model(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) diff --git a/onnxruntime/test/testdata/identity_opt.onnx b/onnxruntime/test/testdata/identity_opt.onnx new file mode 100644 index 0000000000000..24c05f7b7227f Binary files /dev/null and b/onnxruntime/test/testdata/identity_opt.onnx differ