diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8fa67ee172733..635f65696eae2 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -774,7 +774,9 @@ Do not modify directly.* |||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| -|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3b1698773b85b..652626ce9e241 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1157,7 +1157,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); @@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 @@ -1312,6 +1313,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); @@ -2071,7 +2073,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2202,6 +2204,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 17 @@ -2225,6 +2228,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc index 42a9f50001103..bfe385af49dc4 100755 --- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc @@ -133,7 +133,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { } else if (reduction_ == "max") { args.operation = GatherScatterElementsArgs::Operation::MAX; } else { - ORT_THROW("Unsupported reduction type"); + ORT_THROW("Unsupported reduction type for ScatterElements."); } // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc index 6191715f79188..a270249da2b7f 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc @@ -3,6 +3,7 @@ #include "core/providers/cuda/tensor/scatter_nd.h" #include "core/providers/cuda/tensor/scatter_nd_impl.h" +#include "core/providers/cuda/tensor/scatter_nd_common.h" #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/providers/cpu/tensor/utils.h" @@ -16,18 +17,61 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .MayInplace(0, 0), - ScatterND); + ScatterNDDisjointAndNoReduction); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND, + kOnnxDomain, + 13, 15, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .MayInplace(0, 0), + ScatterNDWithAtomicReduction); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND, + kOnnxDomain, + 16, 17, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .MayInplace(0, 0), + ScatterNDWithAtomicReduction); ONNX_OPERATOR_KERNEL_EX(ScatterND, kOnnxDomain, - 13, + 18, kCudaExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) .MayInplace(0, 0), - ScatterND); + ScatterNDWithAtomicReduction); -Status ScatterND::ComputeInternal(OpKernelContext* context) const { +static Status InitiliazeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape, + ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, + CudaKernel::CudaAsyncBuffer& element_counts_and_input_dims_gpu, + onnxruntime::OpKernelContext* context) { + TensorPitches input_strides(input_shape); + + if (last_index_dimension < 6) { + element_counts_and_input_dims.gpu_ptr = nullptr; + for (int64_t i = 0; i < last_index_dimension; ++i) { + element_counts_and_input_dims.stack_ptr[i] = input_strides[i]; + element_counts_and_input_dims.stack_ptr[i + last_index_dimension] = input_shape[i]; + } + } else { + element_counts_and_input_dims_gpu.AllocCpuPtr(last_index_dimension * 2); + memset(element_counts_and_input_dims_gpu.CpuPtr(), 0, sizeof(int64_t) * last_index_dimension * 2); + for (int64_t i = 0; i < last_index_dimension; ++i) { + element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i]; + element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i]; + } + ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream())); + element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr(); + } + return Status::OK(); +} + +Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context) const { const auto* input_tensor = context->Input(0); const auto* indices_tensor = context->Input(1); const auto* updates_tensor = context->Input(2); @@ -44,8 +88,6 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const { const void* input_data = input_tensor->DataRaw(); void* output_data = output_tensor->MutableDataRaw(); - size_t element_size = input_tensor->DataType()->Size(); - if (input_data != output_data) { // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ? CUDA_RETURN_IF_ERROR( @@ -58,18 +100,17 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const { } auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1]; + size_t element_size = input_tensor->DataType()->Size(); // We need element counts for each dimension and the input dim value for each dimension // for the range [0, last_index_dimension). // To avoid multiple GPU data transfers, we combine this into one array and send it through - TensorPitches input_strides(input_shape); - std::vector element_counts_and_input_dims(last_index_dimension * 2, 0LL); - for (int64_t i = 0; i < last_index_dimension; ++i) { - element_counts_and_input_dims[i] = input_strides[i]; - element_counts_and_input_dims[i + last_index_dimension] = input_shape[i]; - } - CudaAsyncBuffer element_counts_and_input_dims_gpu(this, element_counts_and_input_dims); - ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream())); + ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims; + CudaAsyncBuffer element_counts_and_input_dims_gpu(this); + ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape, + element_counts_and_input_dims, + element_counts_and_input_dims_gpu, + context)); ORT_RETURN_IF_ERROR(ScatterNDImpl( Stream(context), @@ -78,12 +119,89 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const { indices_shape.Size() / static_cast(last_index_dimension), indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec last_index_dimension, - element_counts_and_input_dims_gpu.GpuPtr(), + element_counts_and_input_dims, updates_tensor->DataRaw(), input_shape.SizeFromDimension(last_index_dimension))); return Status::OK(); } +Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) const { + const auto* input_tensor = context->Input(0); + const auto* indices_tensor = context->Input(1); + const auto* updates_tensor = context->Input(2); + + const auto& input_shape = input_tensor->Shape(); + const auto& indices_shape = indices_tensor->Shape(); + const auto& updates_shape = updates_tensor->Shape(); + + // Validate input shapes + ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape)); + + auto* output_tensor = context->Output(0, input_shape); + + const void* input_data = input_tensor->DataRaw(); + void* output_data = output_tensor->MutableDataRaw(); + + if (input_data != output_data) { + // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will + // be faster than invoking cudaMemcpy ? + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream(context))); + } + + // Bail out early + if (indices_shape.Size() == 0) { + return Status::OK(); + } + + auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1]; + ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims; + CudaAsyncBuffer element_counts_and_input_dims_gpu(this); + ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape, + element_counts_and_input_dims, + element_counts_and_input_dims_gpu, + context)); + + switch (reduction_) { + case ScatterNDReduction::None: { + size_t element_size = input_tensor->DataType()->Size(); + ORT_RETURN_IF_ERROR(ScatterNDImpl( + Stream(context), + output_data, + element_size, + indices_shape.Size() / static_cast(last_index_dimension), + indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec + last_index_dimension, + element_counts_and_input_dims, + updates_tensor->DataRaw(), + input_shape.SizeFromDimension(last_index_dimension))); + } break; + case ScatterNDReduction::Add: + case ScatterNDReduction::Min: + case ScatterNDReduction::Max: + case ScatterNDReduction::Mul: { + auto element_type = input_tensor->DataType()->AsPrimitiveDataType()->GetDataType(); + ORT_RETURN_IF_ERROR(ScatterNDImplReduction( + Stream(context), + output_data, + element_type, + indices_shape.Size() / static_cast(last_index_dimension), + indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec + last_index_dimension, + element_counts_and_input_dims, + updates_tensor->DataRaw(), + input_shape.SizeFromDimension(last_index_dimension), + reduction_)); + } break; + default: + ORT_THROW("ScatterND not supported for other reduction than Add, None."); + break; + } + + return Status::OK(); +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h index 07df5ab552c3c..6d8bbe6f463fd 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h @@ -3,18 +3,63 @@ #pragma once +#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/tensor/scatter_nd_kind.h" #include "core/providers/cpu/tensor/scatter_nd.h" namespace onnxruntime { namespace cuda { -class ScatterND final : public CudaKernel { +/** + * This implementation assumes there is common indices and + * reduction is not needed. The code does not check that condition. + * However in that case, the same output element could be accessed + * from different threads at the same time and the final value + * is unlikely to be correct. + */ +class ScatterNDDisjointAndNoReduction final : public CudaKernel { public: - explicit ScatterND(const OpKernelInfo& info) : CudaKernel(info) {} + explicit ScatterNDDisjointAndNoReduction(const OpKernelInfo& info) : CudaKernel(info) {} Status ComputeInternal(OpKernelContext* context) const override; }; +/** + * This is an implementation derived from the first one. + * It does atomic operation to handle conflicts. + * The result is unlikely to be correct if the reduction is none + * as there is no guarantee that the final value will be the one + * corresponding to the highest visited index. + * TODO: change the implementation of avoid conflicts. + */ +class ScatterNDWithAtomicReduction final : public CudaKernel { + public: + explicit ScatterNDWithAtomicReduction(const OpKernelInfo& info) : CudaKernel(info) { + std::string reduction; + + if (info.GetAttr("reduction", &reduction).IsOK()) { + if (reduction == "add") { + reduction_ = ScatterNDReduction::Add; + } else if (reduction == "mul") { + reduction_ = ScatterNDReduction::Mul; + } else if (reduction == "min") { + reduction_ = ScatterNDReduction::Min; + } else if (reduction == "max") { + reduction_ = ScatterNDReduction::Max; + } else if (reduction == "none") { + LOGS_DEFAULT(WARNING) << "ScatterND with reduction=='none' only guarantees " + << "to be correct if indices are not duplicated."; + } else { + ORT_THROW("Reduction '", reduction, "' is not supported on CUDA and opset >= 13."); + } + } + } + Status ComputeInternal(OpKernelContext* context) const override; + + private: + ScatterNDReduction reduction_{ScatterNDReduction::None}; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h new file mode 100644 index 0000000000000..9f1465590c5e4 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace cuda { + +struct ElementCountsAndInputDimsSpanOrGpu { + int64_t stack_ptr[12]; + int64_t* gpu_ptr; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu index e9199b5e1b15b..47e7d103ce27b 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu @@ -14,7 +14,7 @@ __global__ void _ScatterNDKernel( const size_t num_indices, const int64_t* indices_data, const int64_t last_index_dimension, - const int64_t* element_counts_and_input_dims, + ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims, const T* updates_data, const size_t num_updates_elements) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices); @@ -27,8 +27,12 @@ __global__ void _ScatterNDKernel( for (size_t i = indices_start; i < indices_end; ++i) { int64_t index = indices_data[i]; - int64_t element_count_dim = element_counts_and_input_dims[i - indices_start]; - int64_t dim_value = element_counts_and_input_dims[i - indices_start + last_index_dimension]; + int64_t element_count_dim = element_counts_and_input_dims.gpu_ptr == nullptr + ? element_counts_and_input_dims.stack_ptr[i - indices_start] + : element_counts_and_input_dims.gpu_ptr[i - indices_start]; + int64_t dim_value = element_counts_and_input_dims.gpu_ptr == nullptr + ? element_counts_and_input_dims.stack_ptr[i - indices_start + last_index_dimension] + : element_counts_and_input_dims.gpu_ptr[i - indices_start + last_index_dimension]; // Clamp the index if out of range // This would have been an error in the CPU kernel, but throwing in the CUDA EP @@ -66,7 +70,7 @@ Status ScatterNDImpl( const size_t num_indices, const int64_t* indices_data, const int64_t last_index_dimension, - const int64_t* element_counts_and_input_dims, + const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, const void* updates_data, const size_t num_updates_elements) { if (num_indices == 0) @@ -128,5 +132,197 @@ Status ScatterNDImpl( return Status::OK(); } +template +struct FuncAdd { + __device__ __inline__ void operator()(T* start_addr, T value) const { + atomic_add(start_addr, value); + } +}; + +template +struct FuncMul { + __device__ __inline__ void operator()(T* start_addr, T value) const { + atomic_mul(start_addr, value); + } +}; + +template +struct FuncMax { + __device__ __inline__ void operator()(T* start_addr, T value) const { + atomic_max(start_addr, value); + } +}; + +template +struct FuncMin { + __device__ __inline__ void operator()(T* start_addr, T value) const { + atomic_min(start_addr, value); + } +}; + +template +__global__ void _ScatterNDKernelReduction( + T* output_data, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims, + const T* updates_data, + const size_t num_updates_elements, + const TFunc func) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices); + + // Compute the base offset into the output data + int64_t data_offset = 0; + + size_t indices_start = last_index_dimension * id; + size_t indices_end = indices_start + last_index_dimension; + for (size_t i = indices_start; i < indices_end; ++i) { + int64_t index = indices_data[i]; + + int64_t element_count_dim = element_counts_and_input_dims.gpu_ptr == nullptr + ? element_counts_and_input_dims.stack_ptr[i - indices_start] + : element_counts_and_input_dims.gpu_ptr[i - indices_start]; + int64_t dim_value = element_counts_and_input_dims.gpu_ptr == nullptr + ? element_counts_and_input_dims.stack_ptr[i - indices_start + last_index_dimension] + : element_counts_and_input_dims.gpu_ptr[i - indices_start + last_index_dimension]; + + // Clamp the index if out of range + // This would have been an error in the CPU kernel, but throwing in the CUDA EP + // is hard. This is the approach taken by other frameworks for out of bound indices + // in their corresponding GPU backends as well. + // index >= -dim_value && index < dim_value + + if (index >= 0) { + if (index >= dim_value) { + index = dim_value - 1; + } + } else { + if (index < -dim_value) { + index = 0; + } else { + index += dim_value; + } + } + + data_offset += (index * element_count_dim); + } + + const T* updates_data_base = updates_data + num_updates_elements * id; + T* output_data_base = output_data + data_offset; + + for (size_t i = 0; i < num_updates_elements; ++i) { + func(output_data_base + i, updates_data_base[i]); + } +} + +template +Status _ScatterNDType( + cudaStream_t stream, + T* output_data, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, + const T* updates_data, + const size_t num_updates_elements, + ScatterNDReduction reduction) { + // Parallelize on number of indices + int blocksPerGrid = static_cast(ceil(static_cast(num_indices) / GridDim::maxThreadsPerBlock)); + + switch (reduction) { + case ScatterNDReduction::Add: + _ScatterNDKernelReduction<<>>( + output_data, + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + updates_data, + num_updates_elements, + FuncAdd()); + break; + case ScatterNDReduction::Mul: + _ScatterNDKernelReduction<<>>( + output_data, + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + updates_data, + num_updates_elements, + FuncMul()); + break; + case ScatterNDReduction::Min: + _ScatterNDKernelReduction<<>>( + output_data, + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + updates_data, + num_updates_elements, + FuncMin()); + break; + case ScatterNDReduction::Max: + _ScatterNDKernelReduction<<>>( + output_data, + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + updates_data, + num_updates_elements, + FuncMax()); + break; + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Reduction ", static_cast(reduction), " not implemented for ScatterND operator."); + } + + return Status::OK(); +} + +Status ScatterNDImplReduction( + cudaStream_t stream, + void* output_data, + const int32_t element_type, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, + const void* updates_data, + const size_t num_updates_elements, + ScatterNDReduction reduction) { + if (num_indices == 0) + return Status::OK(); + + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return _ScatterNDType( + stream, + reinterpret_cast(output_data), + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + reinterpret_cast(updates_data), + num_updates_elements, + reduction); + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return _ScatterNDType( + stream, + reinterpret_cast(output_data), + num_indices, + indices_data, + last_index_dimension, + element_counts_and_input_dims, + reinterpret_cast(updates_data), + num_updates_elements, + reduction); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "element_type ", static_cast(element_type), " not implemented for ScatterND operator."); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h index 874d275f94776..a3c8aab460043 100644 --- a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h @@ -4,6 +4,8 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include "core/providers/cuda/tensor/scatter_nd_kind.h" +#include "core/providers/cuda/tensor/scatter_nd_common.h" namespace onnxruntime { namespace cuda { @@ -15,9 +17,21 @@ Status ScatterNDImpl( const size_t num_indices, const int64_t* indices_data, const int64_t last_index_dimension, - const int64_t* element_counts_and_input_dims, + const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, const void* updates_data, const size_t num_updates_elements); +Status ScatterNDImplReduction( + cudaStream_t stream, + void* output_data, + const int32_t element_type, + const size_t num_indices, + const int64_t* indices_data, + const int64_t last_index_dimension, + const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims, + const void* updates_data, + const size_t num_updates_elements, + ScatterNDReduction reduction); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h new file mode 100644 index 0000000000000..d766cdd920955 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace cuda { + +enum class ScatterNDReduction : int { + None = 0, + Add = 1, + Mul = 2, + Min = 3, + Max = 4, +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 87daaeea969ac..4b0fd783deeac 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1161,7 +1161,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterND); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); @@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1308,6 +1309,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); @@ -2115,7 +2117,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2249,6 +2251,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, @@ -2262,6 +2265,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}); + test1.AddInput("indices", {3, 1}, {0, 1, 0}); + // The linter complains if the line is split into multiple lines. + test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f}); + test1.AddOutput("output", {2, 2, 3}, {8194.1f, 16388.1f, 32776.10f, 65552.10f, 131104.1f, 262208.1f, 128.1f, 256.1f, 512.1f, 1024.1f, 2048.1f, 4096.1f}); + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterNDOpTest, ScatterND_18_mul) { + OpTester test1("ScatterND", 18); + test1.AddAttribute("reduction", "mul"); + test1.AddInput("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}); + test1.AddInput("indices", {3, 1}, {0, 1, 0}); + // The linter complains if the line is split into multiple lines. + test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f}); + test1.AddOutput("output", {2, 2, 3}, {1638.4f, 6553.6f, 26214.4f, 104857.6f, 419430.4f, 1677721.625f, 12.8f, 25.6f, 51.2f, 102.4f, 204.8f, 409.6f}); + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterNDOpTest, ScatterND_18_mul_long_shape) { + OpTester test1("ScatterND", 18); + test1.AddAttribute("reduction", "mul"); + test1.AddInput("data", {2, 2, 3, 1, 1, 1, 1}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}); + test1.AddInput("indices", {3, 1}, {0, 1, 0}); + // The linter complains if the line is split into multiple lines. + test1.AddInput("updates", {3, 2, 3, 1, 1, 1, 1}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f}); + test1.AddOutput("output", {2, 2, 3, 1, 1, 1, 1}, {1638.4f, 6553.6f, 26214.4f, 104857.6f, 419430.4f, 1677721.625f, 12.8f, 25.6f, 51.2f, 102.4f, 204.8f, 409.6f}); + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterNDOpTest, ScatterND_18_min) { + OpTester test1("ScatterND", 18); + test1.AddAttribute("reduction", "min"); + test1.AddInput("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}); + test1.AddInput("indices", {3, 1}, {0, 1, 0}); + // The linter complains if the line is split into multiple lines. + test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f}); + test1.AddOutput("output", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}); + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +TEST(ScatterNDOpTest, ScatterND_18_max) { + OpTester test1("ScatterND", 18); + test1.AddAttribute("reduction", "max"); + test1.AddInput("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f}); + test1.AddInput("indices", {3, 1}, {0, 1, 0}); + // The linter complains if the line is split into multiple lines. + test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f}); + test1.AddOutput("output", {2, 2, 3}, {8192.0, 16384.0, 32768.0, 65536.0, 131072.0, 262144.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0}); + test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 6b587be7d74eb..2a7a7158b5f62 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -308,13 +308,30 @@ TEST(ScatterElements, AddReduction) { test.AddAttribute("axis", 0); test.AddAttribute("reduction", "add"); - test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); + test.AddInput("data", {3, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + test.AddInput("indices", {2, 3}, {1, 0, 2, 0, 2, 1}); + test.AddInput("updates", {2, 3}, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f}); + test.AddOutput("y", {3, 3}, {3.0f, 1.1f, 0.0f, 1.0f, 0.0f, 2.2f, 0.0f, 2.1f, 1.2f}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +} + +#if defined(CUDA_VERSION) +// Operation on float16 (MLFloat16) is not implemented on CPU. +TEST(ScatterElements, AddReduction_MLFloat16) { + OpTester test("ScatterElements", 18); + test.AddAttribute("axis", 0); + test.AddAttribute("reduction", "add"); + + test.AddInput("data", {2, 3}, ToFloat16(std::vector({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}))); test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); - test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); + test.AddInput("updates", {4, 3}, ToFloat16(std::vector({1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}))); + test.AddOutput("y", {2, 3}, ToFloat16(std::vector({-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}))); + // exclude CPU Execution Provider as MLFloat16 is not supported in CPU test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } +#endif TEST(ScatterElements, AddReductionAxis1) { OpTester test("ScatterElements", 18); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 395315b2a2b0c..6eebc996fde9c 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -89,14 +89,16 @@ def apply_filters(filters, category): def load_jsonc(basename: str): """Returns a deserialized object from the JSONC file in testdata/.""" - filename = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "testdata", - basename, - ) - if not os.path.exists(filename): - raise FileNotFoundError(f"File not found {filename!r}.") + filenames = [ + os.path.join(os.path.dirname(os.path.realpath(__file__)), "testdata", basename), + os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..", "test", "testdata", basename)), + ] + + filtered = [f for f in filenames if os.path.exists(f)] + if not filtered: + raise FileNotFoundError(f"No file found in {filenames!r}.") + filename = filtered[0] with open(filename, encoding="utf-8") as f: # pylint: disable=invalid-name lines = f.readlines() lines = [x.split("//")[0] for x in lines] diff --git a/onnxruntime/test/python/onnxruntime_test_scatternd.py b/onnxruntime/test/python/onnxruntime_test_scatternd.py new file mode 100644 index 0000000000000..2a5555bba37de --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_scatternd.py @@ -0,0 +1,329 @@ +import itertools +import json +import os +import typing +import unittest +import warnings + +import numpy as np +import onnx.helper as oh +from onnx import TensorProto, load +from onnx.numpy_helper import from_array +from onnx.reference import ReferenceEvaluator + +import onnxruntime + + +def has_cuda(): + available_providers = [provider for provider in onnxruntime.get_available_providers()] + return "CUDAExecutionProvider" in available_providers + + +def ignore_warnings(warns: typing.List[Warning]) -> typing.Callable: + def wrapper(fct): + if warns is None: + raise AssertionError(f"warns cannot be None for '{fct}'.") + + def call_f(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", warns) + return fct(self) + + return call_f + + return wrapper + + +class TestScatterPerProvider(unittest.TestCase): + def assert_exists(self, filename: str): + assert os.path.exists(filename), f"Unable to find {filename!r}." + + def common_scatter(self, opset, providers, dtype, reduction, expected_names): + from onnxruntime import InferenceSession, SessionOptions + + op_type = "ScatterElements" if "ScatterElements" in expected_names else "ScatterND" + ndim = 2 if op_type == "ScatterElements" else 3 + + assert dtype in (np.float16, np.float32) + itype = TensorProto.FLOAT if dtype == np.float32 else TensorProto.FLOAT16 + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("CastLike", ["X", "I"], ["data"]), + oh.make_node( + op_type, + inputs=["data", "indices", "updates"], + outputs=["sy"], + # axis=0, + reduction=reduction, + ), + oh.make_node("Sub", ["sy", "I"], ["Y"]), + ], + "name", + [ + oh.make_tensor_value_info("X", TensorProto.FLOAT, [None] * ndim), + oh.make_tensor_value_info("indices", TensorProto.INT64, [None, None]), + oh.make_tensor_value_info("updates", itype, [None] * ndim), + ], + [oh.make_tensor_value_info("Y", itype, [None] * ndim)], + [from_array(np.array([0], dtype=dtype), name="I")], + ), + opset_imports=[oh.make_opsetid("", opset)], + ir_version=8 if opset <= 18 else 9, + ) + + if not os.path.exists("temp_dump"): + os.mkdir("temp_dump") + for name in os.listdir("temp_dump"): + os.remove(os.path.join("temp_dump", name)) + + filename = f"temp_dump/{op_type}_{providers[0]}_{itype}.onnx" + opts = SessionOptions() + opts.optimized_model_filepath = filename + sess = InferenceSession(model.SerializeToString(), opts, providers=providers) + self.assertTrue(sess is not None) + self.assert_exists(filename) + onx = load(filename) + names = [n.op_type for n in onx.graph.node] + self.assertEqual(expected_names, names) + + sonx = str(onx).replace(" ", "").replace("\n", "|") + sexp = 'op_type:"Cast"|attribute{|name:"to"|type:INT|i:%d|}' % itype + sexp2 = 'op_type:"Cast"|attribute{|name:"to"|i:%d|type:INT|}' % itype + assert sexp in sonx or sexp2 in sonx, f"Unable to find a substring in {sonx!r}" + if providers == ["CPUExecutionProvider"]: + return + + if op_type == "ScatterElements": + data = np.zeros((3, 3), dtype=np.float32) + data[0, 0] = 1 + indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64) + updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=dtype) + else: + data = np.array( + [ + [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], + [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], + [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], + [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], + ], + dtype=np.float32, + ) + indices = np.array([[0], [2]], dtype=np.int64) + updates = np.array( + [ + [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], + ], + dtype=dtype, + ) + opts = SessionOptions() + opts.enable_profiling = True + opts.optimized_model_filepath = filename + sess = InferenceSession(model.SerializeToString(), opts, providers=providers) + got = sess.run(None, {"X": data, "indices": indices, "updates": updates})[0] + self.assertEqual(got.dtype, updates.dtype) + prof = sess.end_profiling() + + with open(prof, "r") as f: # noqa: UP015 + content = f.read() + js = json.loads(content) + + exe_providers = [] + suffixes = ["_kernel_time", "_fence_before", "_fence_after"] + rows = [] + for row in js: + if "args" in row and isinstance(row["args"], dict): + for k, v in row["args"].items(): + row[f"args_{k}"] = v + del row["args"] + name = row["name"] + for suf in suffixes: + if name.endswith(suf): + changed = name[: -len(suf)] + row["op_name"] = changed + break + rows.append(row) + exe_providers.append((row.get("args_provider", None), row.get("args_op_name", None))) + short_list = [(a, b) for a, b in exe_providers if a is not None and b is not None] + self.assertEqual(short_list, [("CUDAExecutionProvider", o) for o in expected_names]) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + @ignore_warnings(DeprecationWarning) + def test_scatterels_cuda(self): + default_value = [ + "Cast", + "ScatterElements", + "Sub", + ] + expected = { + (np.float32, "none"): default_value, + (np.float16, "none"): default_value, + (np.float32, "add"): default_value, + (np.float16, "add"): default_value, + (np.float32, "mul"): default_value, + (np.float16, "mul"): default_value, + (np.float32, "min"): default_value, + (np.float16, "min"): default_value, + (np.float32, "max"): default_value, + (np.float16, "max"): default_value, + } + for opset, dtype, reduction in itertools.product( + [16, 18], [np.float32, np.float16], ["none", "add", "mul", "min", "max"] + ): + with self.subTest(dtype=dtype, reduction=reduction, opset=opset): + self.common_scatter( + opset, + ["CUDAExecutionProvider"], + dtype, + reduction, + expected[dtype, reduction], + ) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + @ignore_warnings(DeprecationWarning) + def test_scatternd_cuda(self): + default_value = [ + "Cast", + "ScatterND", + "Sub", + ] + expected = { + (np.float32, "none"): default_value, + (np.float16, "none"): default_value, + (np.float32, "add"): default_value, + (np.float16, "add"): default_value, + (np.float32, "mul"): default_value, + (np.float16, "mul"): default_value, + (np.float32, "min"): default_value, + (np.float16, "min"): default_value, + (np.float32, "max"): default_value, + (np.float16, "max"): default_value, + } + for opset, dtype, reduction in itertools.product( + [16, 18], [np.float32, np.float16], ["none", "add", "mul", "min", "max"] + ): + with self.subTest(dtype=dtype, reduction=reduction, opset=opset): + self.common_scatter( + opset, + ["CUDAExecutionProvider"], + dtype, + reduction, + expected[dtype, reduction], + ) + + @ignore_warnings(DeprecationWarning) + def test_scatterels_cpu(self): + default_value = [ + "Cast", + "ScatterElements", + "Sub", + ] + expected = { + (np.float32, "none"): default_value, + (np.float16, "none"): default_value, + (np.float32, "add"): default_value, + (np.float16, "add"): default_value, + (np.float32, "mul"): default_value, + (np.float16, "mul"): default_value, + (np.float32, "min"): default_value, + (np.float16, "min"): default_value, + (np.float32, "max"): default_value, + (np.float16, "max"): default_value, + } + for opset, dtype, reduction in itertools.product([16, 18], [np.float32], ["none", "add", "mul", "min", "max"]): + with self.subTest(dtype=dtype, reduction=reduction, opset=opset): + self.common_scatter( + opset, + ["CPUExecutionProvider"], + dtype, + reduction, + expected[dtype, reduction], + ) + + @ignore_warnings(DeprecationWarning) + def test_scatternd_cpu(self): + default_value = [ + "Cast", + "ScatterND", + "Sub", + ] + expected = { + (np.float32, "none"): default_value, + (np.float16, "none"): default_value, + (np.float32, "add"): default_value, + (np.float16, "add"): default_value, + (np.float32, "mul"): default_value, + (np.float16, "mul"): default_value, + (np.float32, "min"): default_value, + (np.float16, "min"): default_value, + (np.float32, "max"): default_value, + (np.float16, "max"): default_value, + } + for opset, dtype, reduction in itertools.product([16, 18], [np.float32], ["none", "add", "mul", "min", "max"]): + with self.subTest(dtype=dtype, reduction=reduction, opset=opset): + self.common_scatter( + opset, + ["CPUExecutionProvider"], + dtype, + reduction, + expected[dtype, reduction], + ) + + def _scatternd_standalone_cuda(self, reduction, line): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node( + "ScatterND", + inputs=["data", "indices", "updates"], + outputs=["y"], + reduction=reduction, + ) + ], + "nd", + [ + oh.make_tensor_value_info("data", TensorProto.FLOAT, [None, None, None]), + oh.make_tensor_value_info("indices", TensorProto.INT64, [None, None]), + oh.make_tensor_value_info("updates", TensorProto.FLOAT, [None, None, None]), + ], + [oh.make_tensor_value_info("y", TensorProto.FLOAT, [None, None, None])], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + + data = np.full((2, 2, 3), 0.1, dtype=np.float32) + indices = np.array([[line], [1 - line], [line]], dtype=np.int64) + updates = (2 ** (np.arange(18) + 1).astype(np.float32).reshape((3, 2, 3))).astype(np.float32) + + feeds = dict(data=data, indices=indices, updates=updates) + ref = ReferenceEvaluator(model) + expected = ref.run(None, feeds)[0] + + providers = ( + [ + ["CUDAExecutionProvider"], + ["CPUExecutionProvider"], + ] + if has_cuda() + else [["CPUExecutionProvider"]] + ) + for provider in providers: + sess = onnxruntime.InferenceSession(model.SerializeToString(), providers=provider) + got = sess.run(None, feeds)[0] + self.assertEqual(expected.tolist(), got.tolist()) + + def test_scatternd_standalone_cuda(self): + self._scatternd_standalone_cuda("add", 0) + self._scatternd_standalone_cuda("add", 1) + self._scatternd_standalone_cuda("mul", 0) + self._scatternd_standalone_cuda("mul", 1) + self._scatternd_standalone_cuda("min", 0) + self._scatternd_standalone_cuda("min", 1) + self._scatternd_standalone_cuda("max", 0) + self._scatternd_standalone_cuda("max", 1) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py index c425bf956f976..fbf95767b3fdf 100644 --- a/onnxruntime/test/python/quantization/test_subgraph.py +++ b/onnxruntime/test/python/quantization/test_subgraph.py @@ -19,9 +19,13 @@ def test_dynamic_quantization_subgraph(self): with tempfile.TemporaryDirectory() as tmpdir: onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx") quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx") - urllib.request.urlretrieve( - "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path - ) + url = "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx" + try: + urllib.request.urlretrieve(url, onnx_path) + except urllib.request.HTTPError as e: + # The unit test should not fail for this kind of issue. + # TODO: use another way to retrieve the model. + raise unittest.SkipTest(f"Unable to fetch {url!r} due to {e}") # noqa: B904 quantize_dynamic( model_input=onnx_path, @@ -62,3 +66,7 @@ def test_dynamic_quantization_subgraph(self): if attr.type == onnx.AttributeProto.GRAPH: for initializer in attr.g.initializer: self.assertTrue("shared.weight" not in initializer.name) + + +if __name__ == "__main__": + unittest.main(verbosity=2)