Skip to content

Commit

Permalink
cuda 5D cases
Browse files Browse the repository at this point in the history
Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu committed Jan 5, 2024
1 parent ac94b28 commit 7c0ae44
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 55 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/tensor/grid_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ T GridSample<T>::PixelAtGrid3D(const T* image, int64_t d, int64_t h, int64_t w,
//
template <typename T>
Status GridSample<T>::Compute(OpKernelContext* context) const {
const auto* input = context->Input<Tensor>(0);
const auto* X = context->Input<Tensor>(0);
const auto* grid = context->Input<Tensor>(1);
const auto& input_dims = input->Shape();
const auto& input_dims = X->Shape();
const auto& grid_dims = grid->Shape();

int64_t data_dims = input_dims.NumDimensions() - 2;
Expand Down Expand Up @@ -209,7 +209,7 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
concurrency::ThreadPool::TrySimpleParallelFor(
tp, onnxruntime::narrow<std::ptrdiff_t>(C),
[&](std::ptrdiff_t c) {
const T* X_data = input->Data<T>() + (n * C + c) * (H_in * W_in);
const T* X_data = X->Data<T>() + (n * C + c) * (H_in * W_in);
T* Y_data = Y.MutableData<T>() + (n * C + c) * (H_out * W_out);

for (int64_t oy = 0; oy < H_out; oy++) {
Expand Down Expand Up @@ -298,7 +298,7 @@ Status GridSample<T>::Compute(OpKernelContext* context) const {
concurrency::ThreadPool::TrySimpleParallelFor(
tp, onnxruntime::narrow<std::ptrdiff_t>(C),
[&](std::ptrdiff_t c) {
const T* X_data = input->Data<T>() + (n * C + c) * (D_in * H_in * W_in);
const T* X_data = X->Data<T>() + (n * C + c) * (D_in * H_in * W_in);
T* Y_data = Y.MutableData<T>() + (n * C + c) * (D_out * H_out * W_out);

for (int64_t oz = 0; oz < D_out; oz++) {
Expand Down
95 changes: 65 additions & 30 deletions onnxruntime/core/providers/cuda/tensor/grid_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,41 +35,76 @@ GridSample<T>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {

// CUDA GridSample entry point: validates the input/grid shapes, allocates the output and
// dispatches to GridSampleImpl (4-D input) or GridSampleImpl3D (5-D input).
// NOTE(review): this listing appears to be a rendered diff - statements from the earlier
// 4-D-only version (dims_input / dims_grid / Grid / dims_output) are interleaved with the
// statements of the new 4-D/5-D version, without +/- markers.
template <typename T>
Status GridSample<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input<Tensor>(0);
const auto& dims_input = X->Shape().GetDims();
const Tensor* Grid = context->Input<Tensor>(1);
const auto& dims_grid = Grid->Shape().GetDims();
const auto* X = context->Input<Tensor>(0);
const auto* grid = context->Input<Tensor>(1);
const auto& input_dims = X->Shape();
const auto& grid_dims = grid->Shape();

if (dims_input.size() != 4 || dims_grid.size() != 4) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported");
}
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
// data_dims is the input rank minus the two leading dimensions (read below as N and C).
// The grid must have the same rank as the input.
// NOTE(review): the message has no space before "for", so it reads e.g. "... must be 4for input ...".
int64_t data_dims = input_dims.NumDimensions() - 2;
ORT_ENFORCE(static_cast<int64_t>(grid_dims.NumDimensions()) == data_dims + 2,
"grid dimensions must be ", data_dims + 2, "for input dimension of ", data_dims);

// The last grid dimension must hold one coordinate per data dimension.
ORT_ENFORCE(grid_dims[grid_dims.NumDimensions() - 1] == data_dims,
"Last dimension of grid: ", grid_dims[grid_dims.NumDimensions() - 1], ", expect ", data_dims);

// NOTE(review): this rank check runs only after grid_dims has already been indexed above;
// consider moving it before the first use of data_dims.
ORT_ENFORCE(input_dims.NumDimensions() == 4 || input_dims.NumDimensions() == 5, "Only 4-D or 5-D tensor is supported");

Check warning on line 50 in onnxruntime/core/providers/cuda/tensor/grid_sample.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/cuda/tensor/grid_sample.cc#L50

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/cuda/tensor/grid_sample.cc:50:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

auto N = input_dims[0];
auto C = input_dims[1];
ORT_ENFORCE(grid_dims[0] == N, "Grid batch size ", grid_dims[0], " does not match input batch size ", N);

TensorShapeVector dims_output(4);
dims_output[0] = dims_input[0];
dims_output[1] = dims_input[1];
dims_output[2] = dims_grid[1];
dims_output[3] = dims_grid[2];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
return Status::OK();
// Cubic interpolation is rejected for 5-D input.
// NOTE(review): confirm that 3 is the value mode_i_ takes for cubic mode. _GridSampleKernel3D
// only has branches for mode 0 and mode 1; any other value that passes this check leaves the
// output tensor unwritten.
if (input_dims.NumDimensions() == 5) {
ORT_ENFORCE(mode_i_ != 3, "Only support GridSample Cubic mode in 4-D cases.");
}

typedef typename ToCudaType<T>::MappedType CudaT;
CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
GridSampleImpl<CudaT>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(Grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
dims_input.data(),
dims_grid[1],
dims_grid[2],
Y_data);
// 4-D input: the output shape is {N, C, H_out, W_out}, with H_out/W_out taken from grid dims 1 and 2.
if (data_dims == 2) {
auto H_out = grid_dims[1];
auto W_out = grid_dims[2];
TensorShape Y_shape = {N, C, H_out, W_out};
Tensor* Y = context->Output(0, Y_shape);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
return Status::OK();
}

CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
GridSampleImpl<CudaT>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
input_dims.GetDims().data(),
grid_dims[1],
grid_dims[2],
Y_data);
// 5-D input: the output shape is {N, C, D_out, H_out, W_out}, taken from grid dims 1, 2 and 3.
} else if (data_dims == 3) {
auto D_out = grid_dims[1];
auto H_out = grid_dims[2];
auto W_out = grid_dims[3];
TensorShape Y_shape = {N, C, D_out, H_out, W_out};
Tensor* Y = context->Output(0, Y_shape);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
return Status::OK();
}

CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
GridSampleImpl3D<CudaT>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
input_dims.GetDims().data(),
grid_dims[1],
grid_dims[2],
grid_dims[3],
Y_data);
}
return Status::OK();
}
} // namespace cuda
Expand Down
144 changes: 144 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_
return pixel;
}

// Reads one element of a 5-D (N, C, D, H, W) input at integer coordinates (z, y, x) for the
// given batch/channel, resolving out-of-range coordinates according to padding_mode:
//   0      - zeros: any coordinate outside the volume yields 0
//   1      - border: each coordinate is clamped to [0, size - 1]
//   other  - reflection: each coordinate is mapped through GsReflect using
//            border = {x_min, y_min, z_min, x_max, y_max, z_max}
template <typename T>
__device__ T PixelAtGrid3D(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t z, int64_t y, int64_t x,
                           int64_t padding_mode, int64_t N, int64_t C, int64_t D, int64_t H, int64_t W,
                           float border[6]) {
  if (padding_mode == 0) {  // zeros
    const bool inside_volume = x >= 0 && x < W && y >= 0 && y < H && z >= 0 && z < D;
    if (!inside_volume) {
      T zero = 0.0f;
      return zero;
    }
  } else if (padding_mode == 1) {  // border
    x = std::clamp<int64_t>(x, 0, W - 1);
    y = std::clamp<int64_t>(y, 0, H - 1);
    z = std::clamp<int64_t>(z, 0, D - 1);
  } else {  // Reflection
    x = static_cast<int64_t>(GsReflect<T>(x, border[0], border[3]));
    y = static_cast<int64_t>(GsReflect<T>(y, border[1], border[4]));
    z = static_cast<int64_t>(GsReflect<T>(z, border[2], border[5]));
  }

  // All three padding modes end in the same row-major load once the coordinates are resolved.
  const int64_t plane_size = H * W;
  const int64_t volume_size = D * plane_size;
  return input_data[(bIdx * C + cIdx) * volume_size + z * plane_size + y * W + x];
}

__device__ void GsGetCubicCoeffs(float x, float coeffs[4])
{
float cubic_alpha = -0.75f;
Expand Down Expand Up @@ -216,5 +238,127 @@ void GridSampleImpl(

SPECIALIZED_IMPL(float)

// Kernel for 5-D GridSample: each thread produces one element of the (N, C, D_out, H_out, W_out)
// output.
//
// The grid is indexed as (N, D_out, H_out, W_out, 3); the three values of the last dimension
// are read as the x, y and z sampling coordinates and passed through GsDenormalize (defined
// elsewhere in this file) together with W_in, H_in and D_in respectively.
//
//   mode          0 = trilinear, 1 = nearest. No other value is handled here: the thread
//                 returns without writing output_data, so callers must reject other modes.
//   padding_mode  forwarded to PixelAtGrid3D (0 = zeros, 1 = border, otherwise reflection).
//   align_corners non-zero selects [0, size - 1] as the reflection border instead of
//                 [-0.5, size - 0.5].
template <typename T>
__global__ void _GridSampleKernel3D(
    const T* input_data,
    const T* grid_data,
    const int64_t mode,
    const int64_t padding_mode,
    const int64_t align_corners,
    const int64_t N,
    const int64_t C,
    const int64_t D_in,
    const int64_t H_in,
    const int64_t W_in,
    const int64_t D_out,
    const int64_t H_out,
    const int64_t W_out,
    T* output_data) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * D_out * H_out * W_out);
  // extract batch index, channel index, z index, y index, x index for current thread
  int BIdx = idx / (C * D_out * H_out * W_out);
  int tmpBCnt = BIdx * (C * D_out * H_out * W_out);

  int cIdx = (idx - tmpBCnt) / (D_out * H_out * W_out);
  int tmpCCnt = tmpBCnt + cIdx * (D_out * H_out * W_out);

  int zIdx = (idx - tmpCCnt) / (H_out * W_out);
  int tmpDCnt = tmpCCnt + zIdx * H_out * W_out;

  int yIdx = (idx - tmpDCnt) / W_out;
  int tmpHCnt = tmpDCnt + yIdx * W_out;

  int xIdx = (idx - tmpHCnt);

  // Keep the grid offset in 64 bits: it is multiplied by 3 below, which can exceed the range of
  // a 32-bit int for large grids even when the output element index itself still fits.
  const int64_t grid_idx = BIdx * D_out * H_out * W_out + zIdx * H_out * W_out + yIdx * W_out + xIdx;
  T grid_X = grid_data[grid_idx * 3 + 0];
  T grid_Y = grid_data[grid_idx * 3 + 1];
  T grid_Z = grid_data[grid_idx * 3 + 2];
  int outIdx = idx;

  float x_min = -0.5f;
  float x_max = W_in - 0.5f;
  float y_min = -0.5f;
  float y_max = H_in - 0.5f;
  float z_min = -0.5f;
  float z_max = D_in - 0.5f;

  if (align_corners) {
    x_min = 0.0f;
    x_max = W_in - 1.0f;  // float literal, consistent with y_max / z_max (avoids double arithmetic)
    y_min = 0.0f;
    y_max = H_in - 1.0f;
    z_min = 0.0f;
    z_max = D_in - 1.0f;
  }
  float border[] = {x_min, y_min, z_min, x_max, y_max, z_max};  // l-t-n-r-b-f

  T grid_x_imgSpace = GsDenormalize(grid_X, W_in, align_corners == 1);
  T grid_y_imgSpace = GsDenormalize(grid_Y, H_in, align_corners == 1);
  T grid_z_imgSpace = GsDenormalize(grid_Z, D_in, align_corners == 1);

  if (mode == 0) {  // trilinear
    int x1 = floor(grid_x_imgSpace);
    int y1 = floor(grid_y_imgSpace);
    int z1 = floor(grid_z_imgSpace);
    int x2 = x1 + 1;
    int y2 = y1 + 1;
    int z2 = z1 + 1;
    // d?1 is the distance to the lower neighbour and d?2 the distance to the upper neighbour,
    // so each neighbour is weighted by the distance to the opposite one.
    T dx2 = static_cast<T>(x2) - grid_x_imgSpace;
    T dx1 = grid_x_imgSpace - static_cast<T>(x1);
    T dy2 = static_cast<T>(y2) - grid_y_imgSpace;
    T dy1 = grid_y_imgSpace - static_cast<T>(y1);
    T dz2 = static_cast<T>(z2) - grid_z_imgSpace;
    T dz1 = grid_z_imgSpace - static_cast<T>(z1);

    // Bilinear interpolation in the z1 plane.
    T p111 = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border);
    T p112 = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border);
    T p121 = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border);
    T p122 = PixelAtGrid3D(input_data, BIdx, cIdx, z1, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border);
    T Y_gridpoint_z1 = dy2 * (dx2 * p111 + dx1 * p112) + dy1 * (dx2 * p121 + dx1 * p122);

    // Bilinear interpolation in the z2 plane, then blend the two planes along z.
    T p211 = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x1, padding_mode, N, C, D_in, H_in, W_in, border);
    T p212 = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y1, x2, padding_mode, N, C, D_in, H_in, W_in, border);
    T p221 = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x1, padding_mode, N, C, D_in, H_in, W_in, border);
    T p222 = PixelAtGrid3D(input_data, BIdx, cIdx, z2, y2, x2, padding_mode, N, C, D_in, H_in, W_in, border);
    T Y_gridpoint_z2 = dy2 * (dx2 * p211 + dx1 * p212) + dy1 * (dx2 * p221 + dx1 * p222);
    output_data[outIdx] = dz2 * Y_gridpoint_z1 + dz1 * Y_gridpoint_z2;
    return;
  }
  if (mode == 1) {  // nearest
    T x = static_cast<T>(std::nearbyint(static_cast<T>(grid_x_imgSpace)));
    T y = static_cast<T>(std::nearbyint(static_cast<T>(grid_y_imgSpace)));
    T z = static_cast<T>(std::nearbyint(static_cast<T>(grid_z_imgSpace)));
    output_data[outIdx] = PixelAtGrid3D(input_data, BIdx, cIdx,
                                        static_cast<int64_t>(z), static_cast<int64_t>(y), static_cast<int64_t>(x),
                                        padding_mode, N, C, D_in, H_in, W_in, border);
    return;
  }
}

// Host-side launcher for _GridSampleKernel3D.
//
//   stream         CUDA stream the kernel is launched on.
//   input_data     input tensor data; dims holds its five extents, forwarded to the kernel
//                  as N, C, D_in, H_in, W_in.
//   grid_data      sampling grid data.
//   mode, padding_mode, align_corners
//                  forwarded unchanged to the kernel.
//   D_out, H_out, W_out
//                  spatial extents of the output.
//   output_data    output buffer with dims[0] * dims[1] * D_out * H_out * W_out elements.
//
// One thread is launched per output element.
template <typename T>
void GridSampleImpl3D(
    cudaStream_t stream,
    const T* input_data,
    const T* grid_data,
    const int64_t mode,
    const int64_t padding_mode,
    const int64_t align_corners,
    const int64_t dims[5],
    const int64_t D_out,
    const int64_t H_out,
    const int64_t W_out,
    T* output_data) {
  // Integer ceil-division. Going through T (float) loses exactness once the element count
  // exceeds 2^24, which can round the block count down and leave trailing outputs uncomputed.
  const int64_t output_size = dims[0] * dims[1] * D_out * H_out * W_out;
  const int64_t threads_per_block = GridDim::maxThreadsPerBlock;
  const int blocksPerGrid = static_cast<int>((output_size + threads_per_block - 1) / threads_per_block);
  _GridSampleKernel3D<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
      input_data, grid_data, mode, padding_mode, align_corners,
      dims[0], dims[1], dims[2], dims[3], dims[4],
      D_out, H_out, W_out, output_data);
}

// Emits an explicit instantiation of GridSampleImpl3D for element type T.
#define SPECIALIZED_IMPL_3D(T)           \
  template void GridSampleImpl3D<T>(     \
      cudaStream_t stream,               \
      const T* input_data,               \
      const T* grid_data,                \
      const int64_t mode,                \
      const int64_t padding_mode,        \
      const int64_t align_corners,       \
      const int64_t[5],                  \
      const int64_t D_out,               \
      const int64_t H_out,               \
      const int64_t W_out,               \
      T* output_data);

Check warning on line 359 in onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu#L359

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/cuda/tensor/grid_sample_impl.cu:359:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

SPECIALIZED_IMPL_3D(float)

} // namespace cuda
} // namespace onnxruntime
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/grid_sample_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,18 @@ void GridSampleImpl(
const int64_t W_out,
T* output_data);

// Launches the 3-D (5-D tensor) GridSample CUDA kernel on `stream`.
//
//   input_data    input tensor data; dims_input holds its five extents, which the
//                 implementation forwards to the kernel as N, C, D_in, H_in, W_in.
//   grid_data     sampling grid data; three coordinates are read per output location.
//   mode          0 = trilinear, 1 = nearest (the kernel handles no other value).
//   padding_mode  0 = zeros, 1 = border, any other value = reflection.
//   align_corners non-zero switches the reflection border to [0, size - 1].
//   D_out/H_out/W_out  spatial extents of the output.
//   output_data   receives dims_input[0] * dims_input[1] * D_out * H_out * W_out elements.
//
// The definition in grid_sample_impl.cu explicitly instantiates this template for float.
template <typename T>
void GridSampleImpl3D(
cudaStream_t stream,
const T* input_data,
const T* grid_data,
const int64_t mode,
const int64_t padding_mode,
const int64_t align_corners,
const int64_t dims_input[5],
const int64_t D_out,
const int64_t H_out,
const int64_t W_out,
T* output_data);
} // namespace cuda
} // namespace onnxruntime
Loading

0 comments on commit 7c0ae44

Please sign in to comment.