Fix cuDNN v9 build by replacing removed cuDNN v6 RNN API usage by cuDNN v8 RNN API and reenable RNN tests for CUDA EP (#19419)

Replace the deprecated cuDNN RNN API with the cuDNN v8 RNN API and re-enable the
RNN tests for the CUDA EP.

### Motivation and Context
The deprecated cuDNN RNN API may be removed soon, and all RNN tests for the
current CUDA EP RNN implementation are disabled due to failures. With this
change the deprecated API usage has been removed and the updated implementation
no longer fails the tests.
mtavenrath authored Feb 23, 2024
1 parent f430600 commit efbe2b8
Showing 8 changed files with 240 additions and 302 deletions.
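
Before the per-file diffs, a compact sketch of the API change at the heart of the commit may help: the per-handle cudnnSetRNNDescriptor_v6 call (removed in cuDNN v9) is replaced by cudnnSetRNNDescriptor_v8, which takes the math type, bias mode, input/hidden/projection sizes and the padded-IO flag directly. The snippet below is an illustration with placeholder sizes and descriptors, not code copied from the repository.

```cpp
#include <cudnn.h>

// Sketch only: contrasts the removed cuDNN v6 call with the v8 call this commit adopts.
// All sizes and the dropout descriptor are illustrative placeholders.
cudnnStatus_t SetRnnDescriptorV8Sketch(cudnnRNNDescriptor_t rnn_desc,
                                       cudnnDropoutDescriptor_t dropout_desc,
                                       int input_size, int hidden_size, int num_layers) {
  // cuDNN v6 style (removed in cuDNN v9), shown for comparison:
  //   cudnnSetRNNDescriptor_v6(handle, rnn_desc, hidden_size, num_layers, dropout_desc,
  //                            CUDNN_LINEAR_INPUT, CUDNN_UNIDIRECTIONAL, CUDNN_RNN_RELU,
  //                            CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT);
  //   cudnnSetRNNMatrixMathType(rnn_desc, CUDNN_TENSOR_OP_MATH);  // separate call

  // cuDNN v8 style: one call carries algo, cell mode, bias mode, direction, input mode,
  // data/math types, sizes, dropout and the padded-IO flag.
  return cudnnSetRNNDescriptor_v8(rnn_desc,
                                  CUDNN_RNN_ALGO_STANDARD,
                                  CUDNN_RNN_RELU,              // cell mode
                                  CUDNN_RNN_DOUBLE_BIAS,       // bias mode
                                  CUDNN_UNIDIRECTIONAL,
                                  CUDNN_LINEAR_INPUT,
                                  CUDNN_DATA_FLOAT,            // data type
                                  CUDNN_DATA_FLOAT,            // math precision
                                  CUDNN_DEFAULT_MATH,          // math type, set inline
                                  input_size,
                                  hidden_size,
                                  /*proj_size=*/hidden_size,   // no projection
                                  num_layers,
                                  dropout_desc,
                                  CUDNN_RNN_PADDED_IO_ENABLED);
}
```

The padded-IO flag matters for the rest of the change: with it enabled, cuDNN zero-fills time steps beyond each sequence's length, which is presumably why the hand-written masking kernel (RnnMaskImpl) can be deleted from rnn_impl.cu below.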
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/cudnn_common.h
@@ -24,12 +24,12 @@ class CudnnTensor final {

operator cudnnTensorDescriptor_t() const { return tensor_; }

Status CreateTensorIfNeeded();

template <typename T>
static cudnnDataType_t GetDataType();

private:
Status CreateTensorIfNeeded();

cudnnTensorDescriptor_t tensor_;
};

350 changes: 182 additions & 168 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc

Large diffs are not rendered by default.
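
The bulk of the rewrite is in this file, and its diff is not rendered here. As a rough, hedged orientation only (an assumption about the approach, not an excerpt from cudnn_rnn_base.cc): the cuDNN v8 API exposes RNN weights as one flat weight space whose per-matrix addresses are obtained with cudnnGetRNNWeightParams, replacing the v6-era cudnnGetRNNLinLayerMatrixParams/cudnnGetRNNLinLayerBiasParams calls. The changed SetCudnnRnnWeightBias/SetWeightBias signatures in cudnn_rnn_base.h below, which now take a w_data_size and a raw w_data pointer, point in the same direction. The helper name in this sketch is invented.

```cpp
#include <cudnn.h>
#include <cuda_runtime.h>

// Hedged sketch of cuDNN v8 weight-space handling; names and structure are
// illustrative, not copied from cudnn_rnn_base.cc.
cudnnStatus_t CopyOneWeightMatrixSketch(cudnnHandle_t handle,
                                        cudnnRNNDescriptor_t rnn_desc,
                                        int pseudo_layer, int lin_layer_id,
                                        const void* src, size_t src_bytes,
                                        void* weight_space, size_t weight_space_bytes,
                                        cudaStream_t stream) {
  // The total weight-space size is queried once per descriptor, e.g.
  //   cudnnGetRNNWeightSpaceSize(handle, rnn_desc, &weight_space_bytes);
  cudnnTensorDescriptor_t m_desc = nullptr, b_desc = nullptr;
  cudnnCreateTensorDescriptor(&m_desc);
  cudnnCreateTensorDescriptor(&b_desc);

  // Locate the matrix (and bias) for this pseudo layer / linear-layer id
  // inside the flat weight space.
  void* m_addr = nullptr;
  void* b_addr = nullptr;
  cudnnStatus_t status = cudnnGetRNNWeightParams(handle, rnn_desc, pseudo_layer,
                                                 weight_space_bytes, weight_space,
                                                 lin_layer_id,
                                                 m_desc, &m_addr,
                                                 b_desc, &b_addr);
  if (status == CUDNN_STATUS_SUCCESS && m_addr != nullptr) {
    // Copy the framework-provided weights into the location cuDNN expects.
    cudaMemcpyAsync(m_addr, src, src_bytes, cudaMemcpyDeviceToDevice, stream);
  }
  cudnnDestroyTensorDescriptor(m_desc);
  cudnnDestroyTensorDescriptor(b_desc);
  return status;
}
```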

55 changes: 28 additions & 27 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -38,26 +38,28 @@ class CudnnRNN {
}
}

Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers,
Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers,
cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) {
cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) {
if (!cudnn_rnn_desc_)
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));

CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle,
cudnn_rnn_desc_,
CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_,
CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
rnn_mode,
has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS,
cudnn_direction_model,
CUDNN_LINEAR_INPUT,
dataType,
dataType,
dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH,
gsl::narrow_cast<int>(input_size),
gsl::narrow_cast<int>(hidden_size),
gsl::narrow_cast<int>(proj_size), // projected size
num_layers,
cudnn_dropout_desc,
CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation
cudnn_direction_model,
rnn_mode,
CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
dataType));

if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) {
cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH);
}
// CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
CUDNN_RNN_PADDED_IO_ENABLED));

return Status::OK();
}
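
As the comment above notes, CUDNN_RNN_PADDED_IO_ENABLED pairs with the CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED layout so that cuDNN itself zero-fills time steps beyond each sequence's length. A hedged sketch of how such a data descriptor and the v8 forward call fit together follows; the buffer and size names are placeholders, not the EP's actual variables, and error checking is elided.

```cpp
#include <cudnn.h>
#include <vector>

// Hedged sketch: a padded, sequence-major RNN data descriptor plus the v8
// forward call. Sizes and buffers are placeholders, not ONNX Runtime code.
cudnnStatus_t RunRnnForwardSketch(cudnnHandle_t handle, cudnnRNNDescriptor_t rnn_desc,
                                  int max_seq_len, int batch_size,
                                  int input_size, int hidden_size,
                                  const std::vector<int32_t>& seq_lengths,  // host copy
                                  const int32_t* dev_seq_lengths,           // device copy
                                  const void* x, void* y,
                                  cudnnTensorDescriptor_t h_desc,
                                  const void* hx, void* hy,
                                  const void* weight_space, size_t weight_space_bytes,
                                  void* workspace, size_t workspace_bytes) {
  cudnnRNNDataDescriptor_t x_desc = nullptr, y_desc = nullptr;
  cudnnCreateRNNDataDescriptor(&x_desc);
  cudnnCreateRNNDataDescriptor(&y_desc);

  float padding_fill = 0.0f;  // shorter sequences are padded with zeros by cuDNN
  cudnnSetRNNDataDescriptor(x_desc, CUDNN_DATA_FLOAT,
                            CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
                            max_seq_len, batch_size, input_size,
                            seq_lengths.data(), &padding_fill);
  cudnnSetRNNDataDescriptor(y_desc, CUDNN_DATA_FLOAT,
                            CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
                            max_seq_len, batch_size,
                            hidden_size,  // times num_directions for bidirectional
                            seq_lengths.data(), &padding_fill);

  // workspace_bytes typically comes from cudnnGetRNNTempSpaceSize(handle, rnn_desc,
  // CUDNN_FWD_MODE_INFERENCE, x_desc, &workspace_bytes, &reserve_bytes).
  // Inference-only: no reserve space is needed, and cDesc/cx/cy apply to LSTM only.
  cudnnStatus_t status = cudnnRNNForward(handle, rnn_desc, CUDNN_FWD_MODE_INFERENCE,
                                         dev_seq_lengths,
                                         x_desc, x, y_desc, y,
                                         h_desc, hx, hy,
                                         /*cDesc=*/nullptr, /*cx=*/nullptr, /*cy=*/nullptr,
                                         weight_space_bytes, weight_space,
                                         workspace_bytes, workspace,
                                         /*reserveSpaceSize=*/0, /*reserveSpace=*/nullptr);

  cudnnDestroyRNNDataDescriptor(x_desc);
  cudnnDestroyRNNDataDescriptor(y_desc);
  return status;
}
```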
@@ -119,32 +121,30 @@ class CudnnRnnBase : public CudaKernel {
private:
Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle,
const cudnnRNNDescriptor_t rnn_desc,
const cudnnTensorDescriptor_t x_desc,
const cudnnFilterDescriptor_t w_desc,
size_t w_data_size,
void* w_data,
const T* W_data,
const T* R_data,
const T* B_data,
cudaStream_t cuda_stream) const;

Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
size_t& target_w_data_size_in_bytes,
IAllocatorUniquePtr<void>& target_w_data,
CudnnFilterDescriptor& target_w_desc,
CudnnRNN& rnn_desc,
onnxruntime::Stream* ort_stream) const;

void SetWeightBias(const cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnn_desc,
const int pseudo_layer,
const cudnnTensorDescriptor_t x_desc,
const cudnnFilterDescriptor_t w_desc,
const cudnnFilterDescriptor_t filter_desc,
const void* w_data,
const int lin_layer_id,
const T* pos,
int& offset,
bool is_matrix,
cudaStream_t cuda_stream) const;
Status SetWeightBias(const cudnnHandle_t handle,
const cudnnRNNDescriptor_t rnn_desc,
const int pseudo_layer,
size_t w_data_size,
const void* w_data,
const int lin_layer_id,
const T* pos,
int& offset,
bool is_matrix,
cudaStream_t cuda_stream) const;

void SetZeroSequences(const int64_t zero_seq_index_cache_size,
const std::vector<int32_t> zero_seq_index_cache,
@@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel {
cudnnRNNMode_t rnn_mode_;
// w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input
CudnnFilterDescriptor w_desc_cache_;
size_t w_data_cache_size_in_bytes_;
IAllocatorUniquePtr<void> w_data_cache_;
bool weight_cached_;
int64_t layout_;
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/rnn/rnn.cc
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "rnn.h"

#include "core/providers/shared_library/provider_api.h"
#include "rnn_impl.h"
#include "core/providers/cuda/cudnn_common.h"

1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/rnn/rnn.h
@@ -4,6 +4,7 @@
#pragma once

#include "cudnn_rnn_base.h"

#include "core/providers/cuda/cuda_common.h"
#include <cudnn.h>

91 changes: 20 additions & 71 deletions onnxruntime/core/providers/cuda/rnn/rnn_impl.cu
@@ -8,22 +8,32 @@ namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void _ReverseBySequenceKernel(const int32_t seq_length,
__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length,
const int32_t* seq_lengths,
const int32_t block_size,
const fast_divmod div_batch_block,
const fast_divmod div_input_or_hidden_size,
const T* data,
T* reversed_data,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
int seq_id, offset;
div_batch_block.divmod(id, seq_id, offset);
int org_id = (seq_length - seq_id - 1) * block_size + offset;
reversed_data[id] = data[org_id];
int batch, batch_offset;
div_input_or_hidden_size.divmod(offset, batch, batch_offset);
int seq_id_org = seq_lengths[batch] - seq_id - 1;
if (seq_id_org >= 0) {
int org_id = seq_id_org * block_size + offset;
reversed_data[id] = data[org_id];
} else {
reversed_data[id] = T{};
}
}

template <typename T>
void ReverseBySequence(cudaStream_t stream,
const int32_t seq_length,
const int32_t max_seq_length,
const int32_t *seq_lengths,
const int32_t batch_size,
const int32_t input_or_hidden_size,
const T* data,
@@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream,
// kerneral
int32_t block_size = batch_size * input_or_hidden_size;
fast_divmod div_batch_block(block_size);
fast_divmod div_input_or_hidden_size(input_or_hidden_size);
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
_ReverseBySequenceKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N);
max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N);
}
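
The reworked kernel reverses each batch entry over its own sequence length (seq_lengths) instead of the padded maximum length, and writes zeros for positions past the end of a sequence. A small host-side reference of the same index arithmetic, assuming a sequence-major [max_seq_length, batch_size, feature_size] layout, is sketched below for illustration; it is not part of the change.

```cpp
#include <cstdint>
#include <vector>

// Host reference for the reverse-by-sequence mapping, assuming a
// sequence-major [max_seq_length, batch_size, feature_size] layout.
// Positions at or beyond a batch entry's own length stay zero, mirroring
// the zero-fill branch in the kernel.
template <typename T>
void ReverseBySequenceReference(int32_t max_seq_length, int32_t batch_size,
                                int32_t feature_size,
                                const std::vector<int32_t>& seq_lengths,
                                const std::vector<T>& data, std::vector<T>& reversed) {
  const int32_t block_size = batch_size * feature_size;  // one time step
  reversed.assign(data.size(), T{});
  for (int32_t seq_id = 0; seq_id < max_seq_length; ++seq_id) {
    for (int32_t batch = 0; batch < batch_size; ++batch) {
      for (int32_t f = 0; f < feature_size; ++f) {
        const int32_t offset = batch * feature_size + f;      // position within a time step
        const int32_t id = seq_id * block_size + offset;      // destination index
        const int32_t seq_id_org = seq_lengths[batch] - seq_id - 1;
        if (seq_id_org >= 0) {
          reversed[id] = data[seq_id_org * block_size + offset];
        }  // else: stays zero, like the kernel's else branch
      }
    }
  }
}
```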

template <typename T>
@@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream,
data, reordered_data, (CUDA_LONG)N);
}

template <typename T>
__global__ void _RnnMaskKernel(const int32_t seq_length,
const int32_t batch_size,
const int32_t hidden_size,
const int32_t* sequence_lens,
const fast_divmod div_seq_block,
const fast_divmod div_dir_block,
const fast_divmod div_batch_block,
T* y_output_data,
T* y_h_output_data,
const CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

int seq_id, direction_id, batch_id, offset;
div_seq_block.divmod(id, seq_id, offset);
div_dir_block.divmod(offset, direction_id, offset);
div_batch_block.divmod(offset, batch_id, offset);
int32_t batch_seq_length = sequence_lens[batch_id];

if (batch_id >= batch_size || batch_seq_length == seq_length) {
return;
}

if (seq_id >= batch_seq_length) {
y_output_data[id] = 0;
return;
}

if ((y_h_output_data != nullptr) &&
((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) {
int hy_idx = direction_id * batch_size * hidden_size + batch_id * hidden_size + offset;
y_h_output_data[hy_idx] = y_output_data[id];
}
}

template <typename T>
void RnnMaskImpl(cudaStream_t stream,
const int32_t num_directions,
const int32_t seq_length,
const int32_t batch_size,
const int32_t hidden_size,
const int32_t* sequence_lens,
T* y_output_data,
T* y_h_output_data,
const size_t N) {
fast_divmod div_seq_block(batch_size * hidden_size * num_directions);
fast_divmod div_dir_block(batch_size * hidden_size);
fast_divmod div_batch_block(hidden_size);
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
_RnnMaskKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
seq_length, batch_size, hidden_size, sequence_lens, div_seq_block,
div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N);
}

template <typename T>
__global__ void _MaskZeroSequences(const int32_t hidden_size,
T* y_output_data,
@@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream,
}

#define SPECIALIZED_RNN_IMPL(T) \
template void RnnMaskImpl<T>(cudaStream_t stream, \
const int32_t num_directions, \
const int32_t seq_length, \
const int32_t batch_size, \
const int32_t hidden_size, \
const int32_t* sequence_lens, \
T* y_output_data, \
T* y_h_output_data, \
const size_t N); \
template void ReverseBySequence<T>(cudaStream_t stream, \
const int32_t seq_length, \
template void ReverseBySequence<T>(cudaStream_t stream, \
const int32_t max_seq_length, \
const int32_t* seq_lengths, \
const int32_t batch_size, \
const int32_t hidden_size, \
const T* data, \
Expand All @@ -203,7 +152,7 @@ void MaskZeroSequences(cudaStream_t stream,
const T* data, \
T* reordered_data, \
const size_t N); \
template void MaskZeroSequences<T>(cudaStream_t stream, \
template void MaskZeroSequences<T>(cudaStream_t stream, \
const int32_t hidden_size, \
T* y_output_data, \
T* y_h_output_data, \
14 changes: 2 additions & 12 deletions onnxruntime/core/providers/cuda/rnn/rnn_impl.h
@@ -10,7 +10,8 @@ namespace cuda {

template <typename T>
void ReverseBySequence(cudaStream_t stream,
const int32_t seq_length,
const int32_t max_seq_length,
const int32_t* seq_lengths,
const int32_t batch_size,
const int32_t input_or_hidden_size,
const T* data,
@@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream,
T* reordered_data,
const size_t N);

template <typename T>
void RnnMaskImpl(cudaStream_t stream,
const int32_t num_directions,
const int32_t seq_length,
const int32_t batch_size,
const int32_t hidden_size,
const int32_t* sequence_lens,
T* y_output_data,
T* y_h_output_data,
const size_t N);

template <typename T>
void MaskZeroSequences(cudaStream_t stream,
const int32_t hidden_size,
24 changes: 3 additions & 21 deletions onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
@@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) {
test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);

// TensorRT failed on RNN tests
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

// Doesn't work with CUDA 11.4 on Windows. Need investigation.
#if defined(USE_CUDA) && defined(_WIN32)
TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) {
#else
TEST(RNNTest, RNN_bidirectional_zigged_batch) {
#endif
OpTester test("RNN");
int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5;

@@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) {
std::vector<float> Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F});
test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);

test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

// Doesn't work with CUDA 11.4 on Windows. Need investigation.
#if defined(USE_CUDA) && defined(_WIN32)
TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) {
#else
TEST(RNNTest, RNN_forward_direction_zigged_batch) {
#endif
OpTester test("RNN");
int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5;

@@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

// Doesn't work with CUDA 11.4 on Windows. Need investigation.
#if defined(USE_CUDA) && defined(_WIN32)
TEST(RNNTest, DISABLED_RNN_bidirectional_0) {
#else
TEST(RNNTest, RNN_bidirectional_0) {
#endif
OpTester test("RNN");
int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5;

@@ -424,12 +411,7 @@ TEST(RNNTest, RNN_bidirectional_0) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}

// Doesn't work with CUDA 11.4 on Windows. Need investigation.
#if defined(USE_CUDA) && defined(_WIN32)
TEST(RNNTest, DISABLED_RNN_bidirectional_1) {
#else
TEST(RNNTest, RNN_bidirectional_1) {
#endif
OpTester test("RNN");
int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1;

@@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) {
}
}

TEST(RNNTest, DISABLED_RNN_reverse_direction) {
TEST(RNNTest, RNN_reverse_direction) {
int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5;

// In case of useDefault, attributes, inputs or outputs are not set.