Add implementation for ScatterND #19540

Merged
46 commits merged on Apr 24, 2024
Changes from 40 commits
Commits
46 commits
e59597a
add test to check scatter
xadupre Feb 16, 2024
7e7daad
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Feb 16, 2024
8a3417c
add scatternd
xadupre Feb 16, 2024
0dc7948
add scatter nd on cuda
xadupre Feb 16, 2024
f3c9532
lint
xadupre Feb 16, 2024
54bebb5
lint
xadupre Feb 16, 2024
9b65acf
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Feb 21, 2024
01866c2
fix lint, disable a test on cpu
xadupre Feb 21, 2024
89ebc2a
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Feb 22, 2024
880d0b0
fix the condition to disable a test
xadupre Feb 22, 2024
7a5cf58
fix implementation of ScatterND
xadupre Feb 23, 2024
195c794
better comment
xadupre Feb 23, 2024
faa285c
add other reduction
xadupre Feb 23, 2024
b24ef8b
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Feb 23, 2024
aa294bf
fix compilation
xadupre Feb 23, 2024
3ed53b2
fix cu
xadupre Feb 27, 2024
1ed1625
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Feb 27, 2024
e2f7bcc
better implementation, lint
xadupre Feb 28, 2024
fff3220
update unit test
xadupre Feb 29, 2024
59a8461
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Feb 29, 2024
3e16bc6
lint and fix unit test
xadupre Feb 29, 2024
fb01f14
fix compiling issue on windows
xadupre Feb 29, 2024
508cbec
fix compilation
xadupre Feb 29, 2024
89d2e5e
fix merge conflicts
xadupre Mar 1, 2024
76fbf8b
fix md doc
xadupre Mar 4, 2024
0498a4f
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Mar 4, 2024
6fd84ff
exclude openvino
xadupre Mar 5, 2024
f61f7fd
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Mar 5, 2024
0f85b22
fix merge conflicts
xadupre Mar 15, 2024
4d03e68
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Mar 19, 2024
b58b811
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Apr 9, 2024
85c23c1
remove define
xadupre Apr 9, 2024
c93ff81
fix misspelling
xadupre Apr 9, 2024
bbdce4a
lint
xadupre Apr 9, 2024
53837ca
avoid allocating memory on gpu
xadupre Apr 15, 2024
a8288da
fix merge conflict
xadupre Apr 15, 2024
7284858
fix style
xadupre Apr 15, 2024
25360f0
remove unnecessary fetch of an attribute value
xadupre Apr 22, 2024
2fd9628
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Apr 22, 2024
ac6c36e
do not fail for a connectivity error
xadupre Apr 22, 2024
3bdfbfd
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Apr 22, 2024
9397cea
lint
xadupre Apr 22, 2024
dde60a0
fix lint
xadupre Apr 22, 2024
6bbf978
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Apr 23, 2024
1c0dceb
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Apr 23, 2024
7eda59e
Merge branch 'main' of https://github.com/microsoft/onnxruntime into f16
xadupre Apr 24, 2024
4 changes: 3 additions & 1 deletion docs/OperatorKernels.md
@@ -773,7 +773,9 @@ Do not modify directly.*
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|SequenceAt|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
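For reference, the ScatterND rows above all share the semantics defined by the ONNX specification; a minimal numpy sketch of the base (no-reduction) behaviour, with illustrative shapes not taken from this PR, is:

```python
import numpy as np

def scatter_nd_reference(data, indices, updates):
    # ONNX ScatterND: each row of `indices` (along its last axis) picks a
    # slice of `data`, which is overwritten by the matching slice of `updates`.
    output = np.copy(data)
    for idx in np.ndindex(indices.shape[:-1]):
        output[tuple(indices[idx])] = updates[idx]
    return output

data = np.zeros((4, 3), dtype=np.float16)
indices = np.array([[1], [3]], dtype=np.int64)
updates = np.ones((2, 3), dtype=np.float16)
print(scatter_nd_reference(data, indices, updates))  # rows 1 and 3 become ones
```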
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1157,7 +1157,7 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
@@ -1295,6 +1295,7 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);

// Opset 17
@@ -1312,6 +1313,7 @@
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
@@ -2071,7 +2073,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,

[cpplint] cuda_execution_provider.cc:2076: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
@@ -2202,6 +2204,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,

[cpplint] cuda_execution_provider.cc:2207: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,

// Opset 17
@@ -2225,6 +2228,7 @@
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
@@ -133,7 +133,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
} else if (reduction_ == "max") {
args.operation = GatherScatterElementsArgs::Operation::MAX;
} else {
ORT_THROW("Unsupported reduction type");
ORT_THROW("Unsupported reduction type for ScatterElements.");
}

// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
148 changes: 133 additions & 15 deletions onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
@@ -3,6 +3,7 @@

#include "core/providers/cuda/tensor/scatter_nd.h"
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
#include "core/providers/cuda/tensor/scatter_nd_common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cpu/tensor/utils.h"

@@ -16,18 +17,61 @@
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterND);
ScatterNDDisjointAndNoReduction);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
kOnnxDomain,
13, 15,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterNDWithAtomicReduction);

ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
kOnnxDomain,
16, 17,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterNDWithAtomicReduction);

ONNX_OPERATOR_KERNEL_EX(ScatterND,
kOnnxDomain,
13,
18,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
ScatterND);
ScatterNDWithAtomicReduction);

Status ScatterND::ComputeInternal(OpKernelContext* context) const {
static Status InitiliazeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape,
ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
CudaKernel::CudaAsyncBuffer<int64_t>& element_counts_and_input_dims_gpu,
onnxruntime::OpKernelContext* context) {
TensorPitches input_strides(input_shape);

if (last_index_dimension < 6) {
element_counts_and_input_dims.gpu_ptr = nullptr;
for (int64_t i = 0; i < last_index_dimension; ++i) {
element_counts_and_input_dims.stack_ptr[i] = input_strides[i];
element_counts_and_input_dims.stack_ptr[i + last_index_dimension] = input_shape[i];
}
} else {
element_counts_and_input_dims_gpu.AllocCpuPtr(last_index_dimension * 2);
memset(element_counts_and_input_dims_gpu.CpuPtr(), 0, sizeof(int64_t) * last_index_dimension * 2);
for (int64_t i = 0; i < last_index_dimension; ++i) {
element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i];
element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i];
}
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr();
}
return Status::OK();
}
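In other words, the helper packs, for every dimension covered by the index tuples, the element count (stride) of `data` into the first half of the buffer and the dimension size into the second half, so the kernel can turn one row of `indices` into a flat element offset. A hypothetical Python sketch of how that packed buffer would be consumed (names are illustrative, the real logic lives in the ScatterNDImpl CUDA code, and the negative-index wrap-around is an assumption):

```python
def flat_offset(index_row, counts_and_dims, last_index_dimension):
    # counts_and_dims[0:k]   -> strides of `data` (elements per step in dim i)
    # counts_and_dims[k:2*k] -> sizes of dims 0..k-1, where k = last_index_dimension
    offset = 0
    for i in range(last_index_dimension):
        stride = counts_and_dims[i]
        dim = counts_and_dims[i + last_index_dimension]
        idx = index_row[i] + (dim if index_row[i] < 0 else 0)  # assumed wrap-around for negative indices
        offset += idx * stride
    return offset
```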

Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto* indices_tensor = context->Input<Tensor>(1);
const auto* updates_tensor = context->Input<Tensor>(2);
@@ -44,8 +88,6 @@
const void* input_data = input_tensor->DataRaw();
void* output_data = output_tensor->MutableDataRaw();

size_t element_size = input_tensor->DataType()->Size();

if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
CUDA_RETURN_IF_ERROR(
@@ -58,18 +100,17 @@
}

auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
size_t element_size = input_tensor->DataType()->Size();

// We need element counts for each dimension and the input dim value for each dimension
// for the range [0, last_index_dimension).
// To avoid multiple GPU data transfers, we combine this into one array and send it through
TensorPitches input_strides(input_shape);
std::vector<int64_t> element_counts_and_input_dims(last_index_dimension * 2, 0LL);
for (int64_t i = 0; i < last_index_dimension; ++i) {
element_counts_and_input_dims[i] = input_strides[i];
element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
}
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this, element_counts_and_input_dims);
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
element_counts_and_input_dims,
element_counts_and_input_dims_gpu,
context));

ORT_RETURN_IF_ERROR(ScatterNDImpl(
Stream(context),
@@ -78,12 +119,89 @@
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
element_counts_and_input_dims_gpu.GpuPtr(),
element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension)));

return Status::OK();
}

Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) const {
const auto* input_tensor = context->Input<Tensor>(0);
const auto* indices_tensor = context->Input<Tensor>(1);
const auto* updates_tensor = context->Input<Tensor>(2);

const auto& input_shape = input_tensor->Shape();
const auto& indices_shape = indices_tensor->Shape();
const auto& updates_shape = updates_tensor->Shape();

// Validate input shapes
ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape));

auto* output_tensor = context->Output(0, input_shape);

const void* input_data = input_tensor->DataRaw();
void* output_data = output_tensor->MutableDataRaw();

if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will
// be faster than invoking cudaMemcpy ?
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream(context)));
}

// Bail out early
if (indices_shape.Size() == 0) {
return Status::OK();
}

auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
element_counts_and_input_dims,
element_counts_and_input_dims_gpu,
context));

switch (reduction_) {
case ScatterNDReduction::None: {
size_t element_size = input_tensor->DataType()->Size();
ORT_RETURN_IF_ERROR(ScatterNDImpl(
Stream(context),
output_data,
element_size,
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension)));
} break;
case ScatterNDReduction::Add:
case ScatterNDReduction::Min:
case ScatterNDReduction::Max:
case ScatterNDReduction::Mul: {
auto element_type = input_tensor->DataType()->AsPrimitiveDataType()->GetDataType();
ORT_RETURN_IF_ERROR(ScatterNDImplReduction(
Stream(context),
output_data,
element_type,
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension),
reduction_));
} break;
default:
ORT_THROW("ScatterND not supported for other reduction than Add, None.");
break;
}

return Status::OK();
}
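For the reduction cases above, ScatterNDImplReduction has to reproduce the following sequential semantics while processing updates in parallel; because two index rows may target the same slice of the output, the per-element writes are performed with CUDA atomics. A numpy sketch of those semantics (illustrative only, not the PR's kernel):

```python
import numpy as np

def scatter_nd_with_reduction(data, indices, updates, reduction="none"):
    ops = {
        "none": lambda a, b: b,   # plain overwrite
        "add": np.add,
        "mul": np.multiply,
        "min": np.minimum,
        "max": np.maximum,
    }
    combine = ops[reduction]
    output = np.copy(data)
    for idx in np.ndindex(indices.shape[:-1]):
        target = tuple(indices[idx])
        output[target] = combine(output[target], updates[idx])
    return output
```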

} // namespace cuda
} // namespace onnxruntime
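As an end-to-end usage sketch (the shapes, values and the opset-18 pin are illustrative and not taken from this PR), a float16 ScatterND with reduction="add" can be exercised against the CUDA execution provider like this:

```python
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper

node = helper.make_node("ScatterND", ["data", "indices", "updates"], ["out"], reduction="add")
graph = helper.make_graph(
    [node], "scatter_nd_add",
    [helper.make_tensor_value_info("data", TensorProto.FLOAT16, [4, 3]),
     helper.make_tensor_value_info("indices", TensorProto.INT64, [2, 1]),
     helper.make_tensor_value_info("updates", TensorProto.FLOAT16, [2, 3])],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT16, [4, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
feeds = {
    "data": np.zeros((4, 3), dtype=np.float16),
    "indices": np.array([[1], [1]], dtype=np.int64),  # duplicate index exercises the reduction
    "updates": np.ones((2, 3), dtype=np.float16),
}
print(sess.run(None, feeds)[0])  # row 1 accumulates to 2.0
```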