From 3d0db47928becaa563e42098fbc948b37c57eb6c Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 26 Oct 2023 10:39:05 -0700 Subject: [PATCH] Expose Expand --- .../core/providers/cuda/tensor/expand.cc | 80 +++++++++++++++++++ .../core/providers/cuda/tensor/expand.h | 13 +++ 2 files changed, 93 insertions(+) diff --git a/onnxruntime/core/providers/cuda/tensor/expand.cc b/onnxruntime/core/providers/cuda/tensor/expand.cc index e9634df205842..368c167f58641 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.cc +++ b/onnxruntime/core/providers/cuda/tensor/expand.cc @@ -142,6 +142,86 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const { input_strides); } +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor) { + + TensorShape output_shape = output_tensor->Shape(); + +#ifdef ENABLE_STRIDED_TENSORS + // Strided output. + if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) { + gsl::span input_strides = input_data_tensor->Strides(); + TensorShapeVector output_strides = + ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape); + output_tensor->SetShapeAndStrides(output_shape, output_strides); + return Status::OK(); + } +#endif + + auto output_dims = output_shape.AsShapeVector(); + auto input_dims = input_data_tensor->Shape().AsShapeVector(); + + CalcEffectiveDims(input_dims, output_dims); + int rank = gsl::narrow_cast(output_dims.size()); + + TensorPitches original_input_strides(input_dims); + TensorPitches original_output_strides(output_dims); + + TArray input_strides(rank); + for (auto i = 0; i < rank; i++) { + input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i]; + } + + TArray output_strides(rank); + for (auto i = 0; i < rank; i++) { + output_strides[i] = fast_divmod(static_cast(original_output_strides[i])); + } + + return ExpandImpl( + cuda_kernel->Stream(ctx), + input_data_tensor->DataType()->Size(), + gsl::narrow_cast(output_shape.Size()), + gsl::narrow_cast(input_data_tensor->Shape().Size()), + input_data_tensor->DataRaw(), + output_tensor->MutableDataRaw(), + output_strides, + input_strides); +} + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor) { + // new shape to be expanded to + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + + ORT_ENFORCE( + ComputeOutputShape( + cuda_kernel->Node().Name(), + input_data_tensor->Shape(), + output_dims, output_shape).IsOK()); + + // Pre-allocate output. + AllocatorPtr alloc; + ORT_ENFORCE(ctx->GetTempSpaceAllocator(&alloc).IsOK()); + auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc); + + // Only assign output values when output tensor is non-empty + // because empty tensor doesn't own any data. + if (output_shape.Size() > 0) { + ORT_ENFORCE(FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get()).IsOK()); + } + + return output_tensor; +} + #ifdef ENABLE_STRIDED_TENSORS #define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0) #else diff --git a/onnxruntime/core/providers/cuda/tensor/expand.h b/onnxruntime/core/providers/cuda/tensor/expand.h index 4cf4c14e61058..a0b12790017f6 100644 --- a/onnxruntime/core/providers/cuda/tensor/expand.h +++ b/onnxruntime/core/providers/cuda/tensor/expand.h @@ -20,5 +20,18 @@ Status ComputeOutputShape( const TensorShape& rhs_shape, TensorShape& out_shape); +Status FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* /*input_shape_tensor*/, + Tensor* output_tensor); + +std::unique_ptr FuncExpand( + const CudaKernel* cuda_kernel, + OpKernelContext* ctx, + const Tensor* input_data_tensor, + const Tensor* input_shape_tensor); + } // namespace cuda } // namespace onnxruntime