Add implementation for ScatterND (microsoft#19540)
### Description
onnxruntime falls back to the CPU execution provider for ScatterND at opsets above 13. This change extends the CUDA implementation to cover the higher opsets, including the reduction modes they introduce.
xadupre authored and Ted Themistokleous committed May 7, 2024
1 parent cd11058 commit d27a29f
Showing 15 changed files with 868 additions and 41 deletions.
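
For context, the semantics that the new CUDA kernels implement can be sketched on the CPU as follows. This is a minimal reference illustration of ScatterND with the `reduction` attribute (added with `add`/`mul` at opset 16 and extended with `min`/`max` at opset 18), not the CUDA kernel from this commit; the function name and the flat-vector layout are chosen for the example, and negative-index normalization is omitted.

```cpp
// Minimal CPU reference for ScatterND semantics (illustrative; not the CUDA kernel).
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// `output` starts as a copy of `data`. `indices` holds num_rows rows of length k,
// where k = indices_shape.back(); each row selects a slice of
// prod(data_shape[k:]) elements that is overwritten or reduced with the
// matching slice of `updates`.
std::vector<float> ScatterNDReference(std::vector<float> output,
                                      const std::vector<int64_t>& data_shape,
                                      const std::vector<int64_t>& indices, int64_t k,
                                      const std::vector<float>& updates,
                                      const std::string& reduction) {
  // Row-major strides of `data`.
  std::vector<int64_t> strides(data_shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(data_shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * data_shape[i + 1];

  int64_t slice_size = 1;
  for (size_t d = static_cast<size_t>(k); d < data_shape.size(); ++d) slice_size *= data_shape[d];

  const int64_t num_rows = static_cast<int64_t>(indices.size()) / k;
  for (int64_t row = 0; row < num_rows; ++row) {
    int64_t offset = 0;  // element offset of the start of the target slice
    for (int64_t d = 0; d < k; ++d) offset += indices[row * k + d] * strides[d];
    for (int64_t j = 0; j < slice_size; ++j) {
      float& dst = output[static_cast<size_t>(offset + j)];
      const float src = updates[static_cast<size_t>(row * slice_size + j)];
      if (reduction == "add")      dst += src;
      else if (reduction == "mul") dst *= src;
      else if (reduction == "min") dst = std::min(dst, src);
      else if (reduction == "max") dst = std::max(dst, src);
      else                         dst = src;  // "none": plain overwrite
    }
  }
  return output;
}
```

With data shape {2, 3}, k = 1, a single index row {1}, and reduction "add", for instance, the three update values are accumulated into the second row of the output. With reduction "none", duplicate indices make the result depend on write order; the class names in this commit suggest the older path assumes disjoint indices, while the new one relies on atomic operations for the reduction modes.
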
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -774,7 +774,9 @@ Do not modify directly.*
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|SequenceAt|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1157,7 +1157,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
@@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);

// Opset 17
@@ -1312,6 +1313,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
@@ -2071,7 +2073,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
@@ -2202,6 +2204,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,

// Opset 17
@@ -2225,6 +2228,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
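
The registrations above bind each ScatterND kernel to an opset range: the previously open-ended opset-13 registration becomes [13, 15], and new entries cover [16, 17] and 18+. As a simplified illustration of this opset-ranged dispatch — the names below are invented for the sketch; the real registry is assembled by the `ONNX_OPERATOR_*_KERNEL_CLASS_NAME` and `BuildKernelCreateInfo` machinery shown above — the lookup logic might look like:

```cpp
// Simplified sketch of opset-ranged kernel dispatch (illustrative only).
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct KernelEntry {
  int since_version;             // first opset this kernel supports
  int end_version;               // last opset (open-ended entries use a large sentinel)
  std::function<void()> create;  // stand-in for the kernel factory
};

class TinyRegistry {
 public:
  void Register(const std::string& op, int since, int end, std::function<void()> create) {
    kernels_[op].push_back({since, end, std::move(create)});
  }
  // Pick the kernel whose [since, end] range contains the model's opset.
  const KernelEntry* Lookup(const std::string& op, int model_opset) const {
    auto it = kernels_.find(op);
    if (it == kernels_.end()) return nullptr;
    for (const auto& k : it->second)
      if (model_opset >= k.since_version && model_opset <= k.end_version) return &k;
    return nullptr;  // no CUDA kernel for this opset -> node falls back to another provider
  }

 private:
  std::map<std::string, std::vector<KernelEntry>> kernels_;
};

// Mirrors the ScatterND ranges registered in this commit.
void RegisterScatterND(TinyRegistry& r) {
  r.Register("ScatterND", 13, 15, [] { /* ScatterNDWithAtomicReduction */ });
  r.Register("ScatterND", 16, 17, [] { /* ScatterNDWithAtomicReduction */ });
  r.Register("ScatterND", 18, INT32_MAX, [] { /* ScatterNDWithAtomicReduction */ });
}
```

A lookup for a model importing opset 14 then lands on the [13, 15] entry, while an opset-19 model resolves to the open-ended 18+ kernel; when no range matches, the node is assigned elsewhere, which is the CPU fallback this commit removes the need for.
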
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
@@ -133,7 +133,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
} else if (reduction_ == "max") {
args.operation = GatherScatterElementsArgs::Operation::MAX;
} else {
ORT_THROW("Unsupported reduction type");
ORT_THROW("Unsupported reduction type for ScatterElements.");
}

// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
148 changes: 133 additions & 15 deletions onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
@@ -3,6 +3,7 @@

#include "core/providers/cuda/tensor/scatter_nd.h"
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
#include "core/providers/cuda/tensor/scatter_nd_common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cpu/tensor/utils.h"

@@ -16,18 +17,61 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterND);
ScatterNDDisjointAndNoReduction);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
kOnnxDomain,
13, 15,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterNDWithAtomicReduction);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
kOnnxDomain,
16, 17,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterNDWithAtomicReduction);

ONNX_OPERATOR_KERNEL_EX(ScatterND,
kOnnxDomain,
13,
18,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterND);
ScatterNDWithAtomicReduction);

Status ScatterND::ComputeInternal(OpKernelContext* context) const {
static Status InitiliazeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape,
ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
CudaKernel::CudaAsyncBuffer<int64_t>& element_counts_and_input_dims_gpu,
onnxruntime::OpKernelContext* context) {
TensorPitches input_strides(input_shape);

if (last_index_dimension < 6) {
element_counts_and_input_dims.gpu_ptr = nullptr;
for (int64_t i = 0; i < last_index_dimension; ++i) {
element_counts_and_input_dims.stack_ptr[i] = input_strides[i];
element_counts_and_input_dims.stack_ptr[i + last_index_dimension] = input_shape[i];
}
} else {
element_counts_and_input_dims_gpu.AllocCpuPtr(last_index_dimension * 2);
memset(element_counts_and_input_dims_gpu.CpuPtr(), 0, sizeof(int64_t) * last_index_dimension * 2);
for (int64_t i = 0; i < last_index_dimension; ++i) {
element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i];
element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i];
}
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr();
}
return Status::OK();
}

Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto* indices_tensor = context->Input<Tensor>(1);
const auto* updates_tensor = context->Input<Tensor>(2);
@@ -44,8 +88,6 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
const void* input_data = input_tensor->DataRaw();
void* output_data = output_tensor->MutableDataRaw();

size_t element_size = input_tensor->DataType()->Size();

if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
CUDA_RETURN_IF_ERROR(
@@ -58,18 +100,17 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
}

auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
size_t element_size = input_tensor->DataType()->Size();

// We need element counts for each dimension and the input dim value for each dimension
// for the range [0, last_index_dimension).
// To avoid multiple GPU data transfers, we combine this into one array and send it through
TensorPitches input_strides(input_shape);
std::vector<int64_t> element_counts_and_input_dims(last_index_dimension * 2, 0LL);
for (int64_t i = 0; i < last_index_dimension; ++i) {
element_counts_and_input_dims[i] = input_strides[i];
element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
}
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this, element_counts_and_input_dims);
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
element_counts_and_input_dims,
element_counts_and_input_dims_gpu,
context));

ORT_RETURN_IF_ERROR(ScatterNDImpl(
Stream(context),
@@ -78,12 +119,89 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
element_counts_and_input_dims_gpu.GpuPtr(),
element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension)));

return Status::OK();
}

Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto* indices_tensor = context->Input<Tensor>(1);
const auto* updates_tensor = context->Input<Tensor>(2);

const auto& input_shape = input_tensor->Shape();
const auto& indices_shape = indices_tensor->Shape();
const auto& updates_shape = updates_tensor->Shape();

// Validate input shapes
ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape));

auto* output_tensor = context->Output(0, input_shape);

const void* input_data = input_tensor->DataRaw();
void* output_data = output_tensor->MutableDataRaw();

if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will
// be faster than invoking cudaMemcpy ?
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream(context)));
}

// Bail out early
if (indices_shape.Size() == 0) {
return Status::OK();
}

auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
element_counts_and_input_dims,
element_counts_and_input_dims_gpu,
context));

switch (reduction_) {
case ScatterNDReduction::None: {
size_t element_size = input_tensor->DataType()->Size();
ORT_RETURN_IF_ERROR(ScatterNDImpl(
Stream(context),
output_data,
element_size,
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension)));
} break;
case ScatterNDReduction::Add:
case ScatterNDReduction::Min:
case ScatterNDReduction::Max:
case ScatterNDReduction::Mul: {
auto element_type = input_tensor->DataType()->AsPrimitiveDataType()->GetDataType();
ORT_RETURN_IF_ERROR(ScatterNDImplReduction(
Stream(context),
output_data,
element_type,
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension),
reduction_));
} break;
default:
ORT_THROW("ScatterND not supported for other reduction than Add, None.");
break;
}

return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
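
The `InitiliazeElementCountsAndInputDimsSpanOrGpu` helper above packs, for the first `last_index_dimension` dimensions, the per-dimension element counts (strides) and the input dimensions into one buffer: on the stack when the index rank is small, otherwise in a single GPU transfer. A host-side sketch of that packing follows; the struct below is a stand-in for `ElementCountsAndInputDimsSpanOrGpu` from `scatter_nd_common.h` (its field names and the size of the stack array are assumptions here), and the device buffer is simulated with a plain vector instead of `CudaAsyncBuffer`.

```cpp
// Host-side sketch of the packed [strides | dims] layout used by the ScatterND kernels.
#include <cstdint>
#include <vector>

struct PackedCountsAndDims {            // stand-in for ElementCountsAndInputDimsSpanOrGpu
  int64_t stack_storage[12] = {};       // [0, k):  element counts (strides)
                                        // [k, 2k): input dimensions
  const int64_t* gpu_ptr = nullptr;     // used instead of the stack when k >= 6
};

// strides[i] = number of elements covered by one step along dimension i (row-major).
static std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Packs counts and dims for the first k dimensions. Only the large-rank case
// needs a device buffer (simulated here with `device_scratch`); the common
// small-rank case stays on the stack and needs no extra transfer.
PackedCountsAndDims Pack(const std::vector<int64_t>& input_shape, int64_t k,
                         std::vector<int64_t>& device_scratch /* stand-in for GPU memory */) {
  PackedCountsAndDims packed;
  const auto strides = RowMajorStrides(input_shape);
  if (k < 6) {
    for (int64_t i = 0; i < k; ++i) {
      packed.stack_storage[i] = strides[i];
      packed.stack_storage[i + k] = input_shape[i];
    }
  } else {
    device_scratch.assign(static_cast<size_t>(2 * k), 0);
    for (int64_t i = 0; i < k; ++i) {
      device_scratch[static_cast<size_t>(i)] = strides[i];
      device_scratch[static_cast<size_t>(i + k)] = input_shape[i];
    }
    packed.gpu_ptr = device_scratch.data();  // in the real code: CopyToGpu() then GpuPtr()
  }
  return packed;
}
```

Keeping the small-rank case on the stack avoids the per-call host-to-device copy that the previous implementation always performed via `CudaAsyncBuffer::CopyToGpu`.
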
