Skip to content

Commit

Permalink
Expose Expand
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Oct 26, 2023
1 parent 47b3312 commit 3d0db47
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
80 changes: 80 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,86 @@ Status Expand::ComputeInternal(OpKernelContext* ctx) const {
input_strides);
}

// Broadcasts `input_data_tensor` into the caller-provided `output_tensor`.
//
// The target shape is taken from `output_tensor->Shape()`; the shape-tensor
// parameter is unused (its name is commented out in the signature).
//
// Parameters:
//   cuda_kernel       - kernel whose `Stream(ctx)` supplies the stream passed to ExpandImpl.
//   ctx               - kernel context forwarded to `Stream()`.
//   input_data_tensor - source tensor to expand.
//   output_tensor     - destination tensor; its shape is used as the expanded shape.
//
// Returns Status::OK() when no copy is required (strided alias or empty
// output); otherwise returns whatever ExpandImpl returns.
Status FuncExpand(
    const CudaKernel* cuda_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* /*input_shape_tensor*/,
    Tensor* output_tensor) {
  // Copy (not reference) the shape: SetShapeAndStrides below mutates the tensor.
  const TensorShape output_shape = output_tensor->Shape();

#ifdef ENABLE_STRIDED_TENSORS
  // Strided output: when input and output share the same buffer, only the
  // output's shape/stride metadata is rewritten and no data is moved.
  if (input_data_tensor->DataRaw() == output_tensor->DataRaw()) {
    gsl::span<const int64_t> input_strides = input_data_tensor->Strides();
    TensorShapeVector output_strides =
        ComputeOutputStrides(input_data_tensor->Shape(), input_strides, output_shape);
    output_tensor->SetShapeAndStrides(output_shape, output_strides);
    return Status::OK();
  }
#endif

  // An empty output has no elements to write, so skip the launch entirely.
  // The allocating overload already avoids calling this function for empty
  // outputs; guarding here as well keeps direct callers safe.
  if (output_shape.Size() == 0) {
    return Status::OK();
  }

  auto output_dims = output_shape.AsShapeVector();
  auto input_dims = input_data_tensor->Shape().AsShapeVector();

  // Adjusts both dimension vectors in place before pitches are computed.
  CalcEffectiveDims(input_dims, output_dims);
  const int rank = gsl::narrow_cast<int>(output_dims.size());

  const TensorPitches original_input_strides(input_dims);
  const TensorPitches original_output_strides(output_dims);

  TArray<int64_t> input_strides(rank);
  TArray<fast_divmod> output_strides(rank);
  for (int i = 0; i < rank; ++i) {
    // A size-1 input dimension is broadcast: stride 0 re-reads the same element.
    input_strides[i] = input_dims[i] == 1 ? 0 : original_input_strides[i];
    output_strides[i] = fast_divmod(static_cast<int>(original_output_strides[i]));
  }

  // NOTE(review): element counts are narrowed to int without a range check
  // (narrow_cast truncates silently) - confirm callers never exceed INT_MAX elements.
  return ExpandImpl(
      cuda_kernel->Stream(ctx),
      input_data_tensor->DataType()->Size(),
      gsl::narrow_cast<int>(output_shape.Size()),
      gsl::narrow_cast<int>(input_data_tensor->Shape().Size()),
      input_data_tensor->DataRaw(),
      output_tensor->MutableDataRaw(),
      output_strides,
      input_strides);
}

// Allocating variant of FuncExpand: computes the expanded shape, allocates the
// output tensor, fills it, and returns ownership to the caller.
//
// Parameters:
//   cuda_kernel        - kernel providing the node name (for shape-error
//                        reporting) and forwarded to the filling overload.
//   ctx                - kernel context; supplies the temp-space allocator.
//   input_data_tensor  - source tensor to expand.
//   input_shape_tensor - int64 tensor holding the requested target shape.
//
// Returns the newly created output tensor. Any failing Status (shape
// computation, allocator lookup, or the expand itself) trips ORT_ENFORCE,
// with the Status message attached so the cause is visible to the caller.
std::unique_ptr<Tensor> FuncExpand(
    const CudaKernel* cuda_kernel,
    OpKernelContext* ctx,
    const Tensor* input_data_tensor,
    const Tensor* input_shape_tensor) {
  // Requested shape to expand to.
  // NOTE(review): the pointer is dereferenced on the host below, so the shape
  // tensor presumably has to live in CPU-accessible memory - confirm at call sites.
  const auto* p_shape = input_shape_tensor->Data<int64_t>();
  TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()};
  TensorShape output_shape(output_dims);

  // Resolve the final output shape; it is written into output_shape.
  const auto shape_status = ComputeOutputShape(
      cuda_kernel->Node().Name(),
      input_data_tensor->Shape(),
      output_dims, output_shape);
  ORT_ENFORCE(shape_status.IsOK(), shape_status.ErrorMessage());

  // Pre-allocate output.
  AllocatorPtr alloc;
  const auto alloc_status = ctx->GetTempSpaceAllocator(&alloc);
  ORT_ENFORCE(alloc_status.IsOK(), alloc_status.ErrorMessage());
  auto output_tensor = Tensor::Create(input_data_tensor->DataType(), output_shape, alloc);

  // Only assign output values when output tensor is non-empty
  // because empty tensor doesn't own any data.
  if (output_shape.Size() > 0) {
    const auto expand_status =
        FuncExpand(cuda_kernel, ctx, input_data_tensor, input_shape_tensor, output_tensor.get());
    ORT_ENFORCE(expand_status.IsOK(), expand_status.ErrorMessage());
  }

  return output_tensor;
}

#ifdef ENABLE_STRIDED_TENSORS
#define CREATE_EXPAND_KERNEL_DEF (*KernelDefBuilder::Create()).MayStridedOutput(0, 0)
#else
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,18 @@ Status ComputeOutputShape(
const TensorShape& rhs_shape,
TensorShape& out_shape);

// Expands `input_data_tensor` into the caller-provided `output_tensor`.
// The expanded shape is taken from `output_tensor`; the shape-tensor argument
// is unnamed here because this overload does not read it.
// Returns the Status of the expand operation.
Status FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
const Tensor* input_data_tensor,
const Tensor* /*input_shape_tensor*/,
Tensor* output_tensor);

// Allocating overload: reads the requested shape from `input_shape_tensor`
// (int64 values), allocates the output with the context's temp-space
// allocator, expands `input_data_tensor` into it, and returns the new tensor.
// Failures are enforced with ORT_ENFORCE rather than returned as a Status.
std::unique_ptr<Tensor> FuncExpand(
const CudaKernel* cuda_kernel,
OpKernelContext* ctx,
const Tensor* input_data_tensor,
const Tensor* input_shape_tensor);

} // namespace cuda
} // namespace onnxruntime

0 comments on commit 3d0db47

Please sign in to comment.