Replace deprecated cuDNN RNN APIs with the new cuDNN v8 APIs and re-enable RNN tests which had been broken before.
mtavenrath committed Feb 5, 2024
1 parent 435e199 commit 301e345
Showing 6 changed files with 153 additions and 139 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/cudnn_common.h
@@ -24,12 +24,12 @@ class CudnnTensor final {

   operator cudnnTensorDescriptor_t() const { return tensor_; }

+  Status CreateTensorIfNeeded();
+
   template <typename T>
   static cudnnDataType_t GetDataType();

  private:
-  Status CreateTensorIfNeeded();
-
   cudnnTensorDescriptor_t tensor_;
 };
206 changes: 107 additions & 99 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc

Large diffs are not rendered by default.
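Since the diff for this file is not rendered, a brief orientation: in cuDNN v8 the deprecated per-variant entry points such as cudnnRNNForwardInference were superseded by the single cudnnRNNForward call, so a migration like this one would plausibly funnel the forward pass through something like the sketch below. The wrapper name and parameter names are illustrative assumptions, not the file's actual contents; only the cuDNN API itself is real.

// Minimal sketch (assumed names): running the forward pass through the
// consolidated cuDNN v8 entry point. The scratch-buffer sizes would come
// from cudnnGetRNNTempSpaceSizes() with CUDNN_FWD_MODE_INFERENCE.
cudnnStatus_t RunRnnForwardInference(cudnnHandle_t handle,
                                     cudnnRNNDescriptor_t rnn_desc,
                                     const int32_t* dev_seq_lengths,  // device array of per-batch lengths
                                     cudnnRNNDataDescriptor_t x_desc, const void* x,
                                     cudnnRNNDataDescriptor_t y_desc, void* y,
                                     cudnnTensorDescriptor_t h_desc, const void* hx, void* hy,
                                     cudnnTensorDescriptor_t c_desc, const void* cx, void* cy,
                                     size_t weight_space_size, const void* weight_space,
                                     size_t work_space_size, void* work_space) {
  // In inference mode no reserve space is needed, so pass 0 / nullptr.
  return cudnnRNNForward(handle, rnn_desc, CUDNN_FWD_MODE_INFERENCE,
                         dev_seq_lengths,
                         x_desc, x, y_desc, y,
                         h_desc, hx, hy,
                         c_desc, cx, cy,
                         weight_space_size, weight_space,
                         work_space_size, work_space,
                         /*reserveSpaceSize=*/0, /*reserveSpace=*/nullptr);
}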

42 changes: 28 additions & 14 deletions onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h
@@ -7,9 +7,12 @@

 #include <cudnn.h>

+#if (CUDNN_MAJOR >= 8)
+
 #include "core/providers/cuda/cuda_kernel.h"
 #include "core/providers/cuda/cudnn_common.h"
+

 namespace onnxruntime {
 namespace cuda {
@@ -38,26 +41,29 @@ class CudnnRNN {
     }
   }

-  Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers,
+  Status Set(const cudnnHandle_t& cudnnHandle, int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers,
              cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model,
-             cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) {
+             cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType, const cudaDeviceProp& prop) {
     if (!cudnn_rnn_desc_)
       CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_));

-    CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle,
-                                                   cudnn_rnn_desc_,
+    CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_,
+                                                   CUDNN_RNN_ALGO_STANDARD,  // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
+                                                   rnn_mode,
+                                                   has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS,
+                                                   cudnn_direction_model,
+                                                   CUDNN_LINEAR_INPUT,
+                                                   dataType,
+                                                   dataType,
+                                                   dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH,
+                                                   gsl::narrow_cast<int>(input_size),
                                                    gsl::narrow_cast<int>(hidden_size),
+                                                   gsl::narrow_cast<int>(proj_size),  // projected size
                                                    num_layers,
                                                    cudnn_dropout_desc,
-                                                   CUDNN_LINEAR_INPUT,  // We can also skip the input matrix transformation
-                                                   cudnn_direction_model,
-                                                   rnn_mode,
-                                                   CUDNN_RNN_ALGO_STANDARD,  // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC
-                                                   dataType));
-
-    if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) {
-      cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH);
-    }
+                                                   // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences
+                                                   CUDNN_RNN_PADDED_IO_ENABLED));

     return Status::OK();
   }
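Because CUDNN_RNN_PADDED_IO_ENABLED is set above, the matching input/output tensors would be described through cuDNN's RNN data descriptors. A hedged sketch of that step, with the batch size, sequence lengths, and feature width as made-up example values rather than anything from this commit:

// Sketch: describing a padded, sequence-major batch for the padded-IO path.
// cuDNN fills the slots past each sequence's valid length with padding_fill.
cudnnRNNDataDescriptor_t x_data_desc;
CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDataDescriptor(&x_data_desc));

const int batch_size = 3;
const int max_seq_length = 5;
const int input_size = 2;                  // feature width per time step (example)
int seq_lengths[batch_size] = {5, 3, 2};   // valid length of each batch entry
float padding_fill = 0.0f;                 // value written into the padded positions

CUDNN_RETURN_IF_ERROR(cudnnSetRNNDataDescriptor(
    x_data_desc,
    CUDNN_DATA_FLOAT,
    CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,  // the layout named in the comment above
    max_seq_length,
    batch_size,
    input_size,
    seq_lengths,
    &padding_fill));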
@@ -121,24 +127,27 @@ class CudnnRnnBase : public CudaKernel {
                               const cudnnRNNDescriptor_t rnn_desc,
                               const cudnnTensorDescriptor_t x_desc,
                               const cudnnFilterDescriptor_t w_desc,
+                              size_t w_data_size,
                               void* w_data,
                               const T* W_data,
                               const T* R_data,
                               const T* B_data,
                               cudaStream_t cuda_stream) const;

   Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B,
+                           size_t& target_w_data_size_in_bytes,
                            IAllocatorUniquePtr<void>& target_w_data,
                            CudnnFilterDescriptor& target_w_desc,
                            CudnnRNN& rnn_desc,
                            onnxruntime::Stream* ort_stream) const;

-  void SetWeightBias(const cudnnHandle_t handle,
+  Status SetWeightBias(const cudnnHandle_t handle,
                      const cudnnRNNDescriptor_t rnn_desc,
                      const int pseudo_layer,
                      const cudnnTensorDescriptor_t x_desc,
                      const cudnnFilterDescriptor_t w_desc,
                      const cudnnFilterDescriptor_t filter_desc,
+                     size_t w_data_size,
                      const void* w_data,
                      const int lin_layer_id,
                      const T* pos,
@@ -162,11 +171,14 @@ class CudnnRnnBase : public CudaKernel {
   cudnnDirectionMode_t cudnn_direction_mode_;
   bool reverse_;
   int64_t num_directions_;
+  // input_size_ from attribute
+  int64_t input_size_;
   // hidden_size_ from attribute
   int64_t hidden_size_;
   cudnnRNNMode_t rnn_mode_;
   // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input
   CudnnFilterDescriptor w_desc_cache_;
+  size_t w_data_cache_size_in_bytes_;
   IAllocatorUniquePtr<void> w_data_cache_;
   bool weight_cached_;
   int64_t layout_;
@@ -184,3 +196,5 @@ class CudnnRnnBase : public CudaKernel {

 }  // namespace cuda
 }  // namespace onnxruntime
+
+#endif  // CUDNN_MAJOR
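The new size_t w_data_size parameters threaded through the weight-handling helpers above line up with cuDNN v8's packed weight-space model, where one opaque buffer is sized up front and then addressed per (pseudo-layer, linear-layer id). A sketch of that query pattern under assumed variable names; the cuDNN calls are real, everything else is illustrative:

// Sketch: sizing and addressing the packed cuDNN v8 weight space.
size_t weight_space_size = 0;
CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightSpaceSize(handle, rnn_desc, &weight_space_size));
// ... allocate w_data with weight_space_size bytes on the device ...

cudnnTensorDescriptor_t m_desc, b_desc;
CUDNN_RETURN_IF_ERROR(cudnnCreateTensorDescriptor(&m_desc));
CUDNN_RETURN_IF_ERROR(cudnnCreateTensorDescriptor(&b_desc));

void* matrix_addr = nullptr;  // device pointer into w_data for one weight matrix
void* bias_addr = nullptr;    // device pointer into w_data for its bias
CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams(handle, rnn_desc,
                                              /*pseudoLayer=*/0,
                                              weight_space_size, w_data,
                                              /*linLayerID=*/0,
                                              m_desc, &matrix_addr,
                                              b_desc, &bias_addr));
// matrix_addr / bias_addr can then be filled (e.g. with cudaMemcpyAsync) from
// the reorganized ONNX W/R/B tensors.

This pair replaces the deprecated cudnnGetRNNLinLayerMatrixParams / cudnnGetRNNLinLayerBiasParams calls from the pre-v8 API.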
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cuda/rnn/rnn.cc
@@ -1,8 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#include "core/providers/shared_library/provider_api.h"
 #include "rnn.h"

+#if CUDNN_MAJOR >= 8
+
+#include "core/providers/shared_library/provider_api.h"
 #include "rnn_impl.h"
 #include "core/providers/cuda/cudnn_common.h"
@@ -46,3 +49,5 @@ REGISTER_KERNEL_TYPED(MLFloat16);

 }  // namespace cuda
 }  // namespace onnxruntime
+
+#endif
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cuda/rnn/rnn.h
@@ -4,6 +4,9 @@
 #pragma once

 #include "cudnn_rnn_base.h"
+
+#if CUDNN_MAJOR >= 8
+
 #include "core/providers/cuda/cuda_common.h"
 #include <cudnn.h>
@@ -30,9 +33,11 @@ class RNN final : public CudnnRnnBase<T> {
     // ONNX B layout is Wb, Rb, mapping to RNNLinLayerMatrixParams
     // the linLayerID is 0, 1, we can reuse it from W_lin_layer_id & R_lin_layer_id

-    ORT_THROW_IF_ERROR(CudnnRnnBase<T>::CacheCudnnRnnWeights(info));
+    //ORT_THROW_IF_ERROR(CudnnRnnBase<T>::CacheCudnnRnnWeights(info));
   }
 };

 }  // namespace cuda
 }  // namespace onnxruntime
+
+#endif
26 changes: 4 additions & 22 deletions onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
@@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) {
   test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);

   // TensorRT failed on RNN tests
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }

-// Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) {
-#else
 TEST(RNNTest, RNN_bidirectional_zigged_batch) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5;
@@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) {
   std::vector<float> Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F});
   test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);

-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }

-// Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) {
-#else
 TEST(RNNTest, RNN_forward_direction_zigged_batch) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5;

@@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }

-// Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_bidirectional_0) {
-#else
 TEST(RNNTest, RNN_bidirectional_0) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5;

@@ -422,14 +409,9 @@ TEST(RNNTest, RNN_bidirectional_0) {
   test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);

   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }

-// Doesn't work with CUDA 11.4 on Windows. Need investigation.
-#if defined(USE_CUDA) && defined(_WIN32)
-TEST(RNNTest, DISABLED_RNN_bidirectional_1) {
-#else
 TEST(RNNTest, RNN_bidirectional_1) {
-#endif
   OpTester test("RNN");
   int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1;

@@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) {
   }
 }

-TEST(RNNTest, DISABLED_RNN_reverse_direction) {
+TEST(RNNTest, RNN_reverse_direction) {
   int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5;
Expand Down
