From efbe2b84556c195e7d7f3353321eb3f410a1e645 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 23 Feb 2024 17:45:17 +0100 Subject: [PATCH] Fix cuDNN v9 build by replacing removed cuDNN v6 RNN API usage by cuDNN v8 RNN API and reenable RNN tests for CUDA EP (#19419) Replace deprecated cuDNN RNN based API by cuDNN v8 RNN API and re-enable RNN tests for the CUDA EP. ### Motivation and Context The deprecated cuDNN RNN API might vanish soon and in addition for the current CUDA EP RNN implementation all RNN tests are disabled due to failures. With this change the deprecated API has been removed and the new updated implemented doesn't fail the tests anymore. --- .../core/providers/cuda/cudnn_common.h | 4 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 350 +++++++++--------- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 55 +-- onnxruntime/core/providers/cuda/rnn/rnn.cc | 3 +- onnxruntime/core/providers/cuda/rnn/rnn.h | 1 + .../core/providers/cuda/rnn/rnn_impl.cu | 91 +---- .../core/providers/cuda/rnn/rnn_impl.h | 14 +- .../test/providers/cpu/rnn/rnn_op_test.cc | 24 +- 8 files changed, 240 insertions(+), 302 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index fdd14dedad47e..2cbeb13696270 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -24,12 +24,12 @@ class CudnnTensor final { operator cudnnTensorDescriptor_t() const { return tensor_; } + Status CreateTensorIfNeeded(); + template static cudnnDataType_t GetDataType(); private: - Status CreateTensorIfNeeded(); - cudnnTensorDescriptor_t tensor_; }; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 99c1f48e21c74..b61b104790fe5 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -9,40 +9,49 @@ namespace onnxruntime { namespace cuda { template -void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* reorganized_w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const { +Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t reorganized_w_data_size, + const void* reorganized_w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const { int numDims; - std::vector matDims(3); + std::array matDims; + std::array strideA; cudnnDataType_t dt; - cudnnTensorFormat_t tf; T* mem_offset; - if (is_matrix) { - cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } else { - cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } + CudnnTensor tensor_desc_matrix, tensor_desc_bias; + ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded()); + ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded()); - cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data()); + T *mem_offset_matrix, *mem_offset_bias; + CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams( + handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data, + lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias)); + CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor( + is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data())); + + mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias; int count = matDims[0] * matDims[1] * matDims[2]; + + if (strideA[0] != count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed"); + } CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); + offset += count; + + return Status::OK(); } template Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t reorganized_w_data_size, void* reorganized_w_data, const T* W_data, const T* R_data, @@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, int w_offset = 0; int r_offset = 0; int bias_offset = 0; - CudnnFilterDescriptor filter_desc; for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias( + cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } } @@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, template Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { @@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons TensorShapeVector dims_w({w_size, 1, 1}); ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType())); - TensorShapeVector fake_dims_x({1, input_size, 1}); - CudnnTensor fake_x_desc; - ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType())); - // Prepare the weight data - reorganized_w_data = GetScratchBuffer(w_size * sizeof(T), ort_stream); + reorganized_w_data_size_in_bytes = w_size * sizeof(T); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); // In many cases, this allocation is bigger than needed, leaving part of - // the buffer unintialized. non-zero garbage data leads to wrong result + // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); @@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons auto* ort_cuda_stream = dynamic_cast(ort_stream); cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle(); - ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc, - reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); + ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, + reorganized_w_data_size_in_bytes, reorganized_w_data.get(), + W_data, R_data, B_data, cuda_stream)); return Status::OK(); } @@ -128,22 +140,31 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R); bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); + bool has_bias = B != nullptr; + if (get_W && get_R) { CudnnRNN tmp_rnn_desc; - ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(), + auto proj_size = hidden_size_; + ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); if (get_B) { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } else { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } cudaStreamSynchronize(nullptr); + weight_cached_ = true; } @@ -158,17 +179,72 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(nullptr != X); // optional inputs - const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); // [batch_size] - const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_] + // [batch_size] + const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); + // initial hidden. [num_directions_, batch_size, hidden_size_] + const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); const Tensor* initial_c(nullptr); if (rnn_mode_ == CUDNN_LSTM) { - initial_c = ctx->Input(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] + // initial cell. [num_directions_, batch_size, hidden_size_] + initial_c = ctx->Input(RNN_Input_Index::initial_c); } + size_t proj_size = hidden_size_; int64_t seq_length = X->Shape()[0]; int64_t batch_size = X->Shape()[1]; int64_t input_size = X->Shape()[2]; + // we thread a single input as sequence_lens of length 1, require to expand to [batch_size]? + std::vector sequence_lengths_temp; + if (!sequence_lens) { + sequence_lengths_temp.resize(batch_size, gsl::narrow_cast(seq_length)); + } + + const int32_t* sequence_lens_data = (sequence_lens == nullptr) + ? sequence_lengths_temp.data() + : sequence_lens->Data(); + + // cuDNN doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + int64_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + + CudaAsyncBuffer sequence_lens_buffer(this, batch_size); + int32_t* seq_len_array = sequence_lens_buffer.CpuPtr(); + + // 0-len sequences are not supported by cuDNN. + // Replace them by sequences of len 1 and mask them out with SetZeroSequences + for (int i = 0; i < batch_size; ++i) { + if (0 == sequence_lens_data[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } else { + seq_len_array[i] = sequence_lens_data[i]; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache.resize(zero_seq_count * num_directions_); + for (int64_t i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[static_cast(zero_seq_count) + i] = + static_cast(batch_size + zero_seq_index_cache[i]); + } + zero_seq_count *= num_directions_; + } + + // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must + // be copied to the GPU always. + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must + // be copied to the GPU only for the ReverseBySequence kernels. + // if (reverse_) { + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // } + // optional outputs TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_}); TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_}); @@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy); Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc); - std::vector dims_x({batch_size, input_size, 1}); - std::vector dims_y({batch_size, hidden_size_ * num_directions_, 1}); - - CudnnTensor x_desc_temp; - ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType())); - CudnnTensor y_desc_temp; - ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType())); - std::vector x_desc(seq_length, x_desc_temp); - std::vector y_desc(seq_length, y_desc_temp); - - CudnnTensor hx_desc; - CudnnTensor cx_desc; - CudnnTensor y_h_desc; - CudnnTensor y_c_desc; - ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - IAllocatorUniquePtr x_reversed_data; const T* x_data = X->Data(); if (reverse_) { @@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(input_size), reinterpret_cast(x_data), @@ -226,115 +284,82 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { y_data = y_alloc_data.get(); } - const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data(); + const Tensor* B = ctx->Input(RNN_Input_Index::B); + bool has_bias = B != nullptr; CudnnRNN rnn_desc; - ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx), + ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size, hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); // Prepare the weight data + size_t w_data_size_in_bytes = 0; IAllocatorUniquePtr w_data; CudnnFilterDescriptor w_desc; if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); const Tensor* B = ctx->Input(RNN_Input_Index::B); - ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, + rnn_desc, ctx->GetComputeStream())); } - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences - CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED)); + CudnnDataTensor x_desc1; + ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + input_size, seq_len_array)); + CudnnDataTensor y_desc1; + ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + ((rnn_mode_ == CUDNN_LSTM) ? proj_size : hidden_size_) * num_directions_, + seq_len_array)); - size_t workspace_bytes; - CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - int64_t zero_seq_count = 0; - std::vector zero_seq_index_cache(batch_size, 0); - int64_t zero_seq_index_cache_size = 0; - - if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(GetCudnnHandle(ctx), - rnn_desc, - gsl::narrow_cast(seq_length), - x_desc.data(), - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc.data(), - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - workspace_cuda.get(), - workspace_bytes)); - } else { - // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 - // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence - std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); - for (int i = 0; i < batch_size; ++i) { - if (0 == seq_len_array[i]) { - seq_len_array[i] = 1; - zero_seq_index_cache[zero_seq_count] = i; - ++zero_seq_count; - } - } + CudnnTensor cx_desc; + ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - // Calculate the zero position cache for reverse direction if it's bidirectional - // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since - // we hacked the 0 sequence to 1 - if (zero_seq_count && num_directions_ > 1) { - zero_seq_index_cache_size = zero_seq_count * num_directions_; - zero_seq_index_cache.resize(zero_seq_index_cache_size); - for (int64_t i = 0; i < zero_seq_count; ++i) { - zero_seq_index_cache[static_cast(zero_seq_count) + i] = static_cast(batch_size + zero_seq_index_cache[i]); - } - } + CudnnTensor hx_desc; + ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); + + // reserveSpaceSize is not required cudnnRNNForward, but returned by cudnnGetRNNTempSpaceSizes + size_t workspace_bytes, reservespace_bytes; - CudnnDataTensor x_desc1; - ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data())); - CudnnDataTensor y_desc1; - ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data())); - - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(GetCudnnHandle(ctx), - rnn_desc, - x_desc1, - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc1, - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_cuda.get(), - workspace_bytes)); - - // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. - if (nullptr == Y) { + CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, + x_desc1, &workspace_bytes, &reservespace_bytes)); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), + rnn_desc, + CUDNN_FWD_MODE_INFERENCE, + sequence_lens_buffer.GpuPtr(), // should be zero starting with cudnn 8.9.1 + x_desc1, + x_data_input, + y_desc1, + y_data, // output + hx_desc, + hx_data, // input + y_h_data, // output + cx_desc, cx_data, y_c_data, + weight_cached_ ? w_data_cache_size_in_bytes_ : w_data_size_in_bytes, + weight_cached_ ? w_data_cache_.get() : w_data.get(), + workspace_bytes, + workspace_cuda.get(), + reservespace_bytes, + reservespace_cuda.get())); + + // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, + // no need the following code to retrieve Y_h from Y data. + if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); - } - return Status::OK(); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } + return Status::OK(); } IAllocatorUniquePtr y_reorganized_data; @@ -345,6 +370,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // reverse output data ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), @@ -361,8 +387,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } if (Y != nullptr) { - // User specified this optional output, so need to copy the reversed data to orignial place - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); + // User specified this optional output, so need to copy the reversed data to original place + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), + cudaMemcpyDeviceToDevice, Stream(ctx))); } else { y_data = y_reorganized_data.get(); } @@ -370,23 +397,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } - if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - CudaAsyncBuffer sequence_lens_buffer(this, batch_size); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); - RnnMaskImpl(Stream(ctx), - gsl::narrow_cast(num_directions_), - gsl::narrow_cast(seq_length), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), - sequence_lens_buffer.GpuPtr(), - reinterpret_cast(y_data), - reinterpret_cast(y_h_data), - output_size); - } return Status::OK(); } @@ -399,7 +412,8 @@ void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, onnxruntime::Stream* ort_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); - memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), + zero_seq_index_cache_size * sizeof(int32_t)); ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; MaskZeroSequences(cuda_stream, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 1c9483b2afd38..0fa01d3486e99 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -38,26 +38,28 @@ class CudnnRNN { } } - Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers, + Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, - cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) { + cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) { if (!cudnn_rnn_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); - CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle, - cudnn_rnn_desc_, + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_, + CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC + rnn_mode, + has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS, + cudnn_direction_model, + CUDNN_LINEAR_INPUT, + dataType, + dataType, + dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH, + gsl::narrow_cast(input_size), gsl::narrow_cast(hidden_size), + gsl::narrow_cast(proj_size), // projected size num_layers, cudnn_dropout_desc, - CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation - cudnn_direction_model, - rnn_mode, - CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC - dataType)); - - if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) { - cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH); - } + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + CUDNN_RNN_PADDED_IO_ENABLED)); return Status::OK(); } @@ -119,8 +121,7 @@ class CudnnRnnBase : public CudaKernel { private: Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t w_data_size, void* w_data, const T* W_data, const T* R_data, @@ -128,23 +129,22 @@ class CudnnRnnBase : public CudaKernel { cudaStream_t cuda_stream) const; Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& target_w_data_size_in_bytes, IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const; - void SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const; + Status SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t w_data_size, + const void* w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const; void SetZeroSequences(const int64_t zero_seq_index_cache_size, const std::vector zero_seq_index_cache, @@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel { cudnnRNNMode_t rnn_mode_; // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input CudnnFilterDescriptor w_desc_cache_; + size_t w_data_cache_size_in_bytes_; IAllocatorUniquePtr w_data_cache_; bool weight_cached_; int64_t layout_; diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index 4bd22340ef2bb..ed8be63679707 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared_library/provider_api.h" #include "rnn.h" + +#include "core/providers/shared_library/provider_api.h" #include "rnn_impl.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h index e4e50046b3725..6221afb003b22 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn.h @@ -4,6 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" + #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index d485855ddb417..94c8036be6cdf 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -8,22 +8,32 @@ namespace onnxruntime { namespace cuda { template -__global__ void _ReverseBySequenceKernel(const int32_t seq_length, +__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t block_size, const fast_divmod div_batch_block, + const fast_divmod div_input_or_hidden_size, const T* data, T* reversed_data, const CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); int seq_id, offset; div_batch_block.divmod(id, seq_id, offset); - int org_id = (seq_length - seq_id - 1) * block_size + offset; - reversed_data[id] = data[org_id]; + int batch, batch_offset; + div_input_or_hidden_size.divmod(offset, batch, batch_offset); + int seq_id_org = seq_lengths[batch] - seq_id - 1; + if (seq_id_org >= 0) { + int org_id = seq_id_org * block_size + offset; + reversed_data[id] = data[org_id]; + } else { + reversed_data[id] = T{}; + } } template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t *seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream, // kerneral int32_t block_size = batch_size * input_or_hidden_size; fast_divmod div_batch_block(block_size); + fast_divmod div_input_or_hidden_size(input_or_hidden_size); int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); _ReverseBySequenceKernel<<>>( - seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N); + max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N); } template @@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, data, reordered_data, (CUDA_LONG)N); } -template -__global__ void _RnnMaskKernel(const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - const fast_divmod div_seq_block, - const fast_divmod div_dir_block, - const fast_divmod div_batch_block, - T* y_output_data, - T* y_h_output_data, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - - int seq_id, direction_id, batch_id, offset; - div_seq_block.divmod(id, seq_id, offset); - div_dir_block.divmod(offset, direction_id, offset); - div_batch_block.divmod(offset, batch_id, offset); - int32_t batch_seq_length = sequence_lens[batch_id]; - - if (batch_id >= batch_size || batch_seq_length == seq_length) { - return; - } - - if (seq_id >= batch_seq_length) { - y_output_data[id] = 0; - return; - } - - if ((y_h_output_data != nullptr) && - ((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) { - int hy_idx = direction_id * batch_size * hidden_size + batch_id * hidden_size + offset; - y_h_output_data[hy_idx] = y_output_data[id]; - } -} - -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N) { - fast_divmod div_seq_block(batch_size * hidden_size * num_directions); - fast_divmod div_dir_block(batch_size * hidden_size); - fast_divmod div_batch_block(hidden_size); - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _RnnMaskKernel<<>>( - seq_length, batch_size, hidden_size, sequence_lens, div_seq_block, - div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); -} - template __global__ void _MaskZeroSequences(const int32_t hidden_size, T* y_output_data, @@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream, } #define SPECIALIZED_RNN_IMPL(T) \ - template void RnnMaskImpl(cudaStream_t stream, \ - const int32_t num_directions, \ - const int32_t seq_length, \ - const int32_t batch_size, \ - const int32_t hidden_size, \ - const int32_t* sequence_lens, \ - T* y_output_data, \ - T* y_h_output_data, \ - const size_t N); \ - template void ReverseBySequence(cudaStream_t stream, \ - const int32_t seq_length, \ + template void ReverseBySequence(cudaStream_t stream, \ + const int32_t max_seq_length, \ + const int32_t* seq_lengths, \ const int32_t batch_size, \ const int32_t hidden_size, \ const T* data, \ @@ -203,7 +152,7 @@ void MaskZeroSequences(cudaStream_t stream, const T* data, \ T* reordered_data, \ const size_t N); \ -template void MaskZeroSequences(cudaStream_t stream, \ +template void MaskZeroSequences(cudaStream_t stream, \ const int32_t hidden_size, \ T* y_output_data, \ T* y_h_output_data, \ diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index 9844e04ff6ec5..ba876011f6b67 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -10,7 +10,8 @@ namespace cuda { template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, T* reordered_data, const size_t N); -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N); - template void MaskZeroSequences(cudaStream_t stream, const int32_t hidden_size, diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index b9875b9553a55..1a31743e2f7e7 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // TensorRT failed on RNN tests - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) { -#else TEST(RNNTest, RNN_bidirectional_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5; @@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) { std::vector Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F}); test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) { -#else TEST(RNNTest, RNN_forward_direction_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5; @@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_0) { -#else TEST(RNNTest, RNN_bidirectional_0) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; @@ -424,12 +411,7 @@ TEST(RNNTest, RNN_bidirectional_0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_1) { -#else TEST(RNNTest, RNN_bidirectional_1) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1; @@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) { } } -TEST(RNNTest, DISABLED_RNN_reverse_direction) { +TEST(RNNTest, RNN_reverse_direction) { int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; // In case of useDefault, attributes, inputs or outputs are not set.