From 3a4599792dc3217b1b689344cc75e41804d06597 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Wed, 28 Feb 2024 10:20:01 -0800 Subject: [PATCH] Cherry pick for 1.17.2 source code release (#19679) ### Description As title. ### Motivation and Context Co-authored-by: Markus Tavenrath --- .../core/providers/cuda/cudnn_common.h | 4 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 350 +++++++++--------- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 55 +-- onnxruntime/core/providers/cuda/rnn/rnn.cc | 3 +- onnxruntime/core/providers/cuda/rnn/rnn.h | 1 + .../core/providers/cuda/rnn/rnn_impl.cu | 91 +---- .../core/providers/cuda/rnn/rnn_impl.h | 14 +- .../test/providers/cpu/rnn/rnn_op_test.cc | 24 +- 8 files changed, 240 insertions(+), 302 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index 8a94a334ee688..1043e56f3a8a7 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -24,12 +24,12 @@ class CudnnTensor final { operator cudnnTensorDescriptor_t() const { return tensor_; } + Status CreateTensorIfNeeded(); + template static cudnnDataType_t GetDataType(); private: - Status CreateTensorIfNeeded(); - cudnnTensorDescriptor_t tensor_; }; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 99c1f48e21c74..b61b104790fe5 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -9,40 +9,49 @@ namespace onnxruntime { namespace cuda { template -void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* reorganized_w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const { +Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t reorganized_w_data_size, + const void* reorganized_w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const { int numDims; - std::vector matDims(3); + std::array matDims; + std::array strideA; cudnnDataType_t dt; - cudnnTensorFormat_t tf; T* mem_offset; - if (is_matrix) { - cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } else { - cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } + CudnnTensor tensor_desc_matrix, tensor_desc_bias; + ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded()); + ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded()); - cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data()); + T *mem_offset_matrix, *mem_offset_bias; + CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams( + handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data, + lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias)); + CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor( + is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data())); + + mem_offset = is_matrix ? mem_offset_matrix : mem_offset_bias; int count = matDims[0] * matDims[1] * matDims[2]; + + if (strideA[0] != count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed"); + } CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); + offset += count; + + return Status::OK(); } template Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t reorganized_w_data_size, void* reorganized_w_data, const T* W_data, const T* R_data, @@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, int w_offset = 0; int r_offset = 0; int bias_offset = 0; - CudnnFilterDescriptor filter_desc; for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias( + cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } } @@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, template Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { @@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons TensorShapeVector dims_w({w_size, 1, 1}); ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType())); - TensorShapeVector fake_dims_x({1, input_size, 1}); - CudnnTensor fake_x_desc; - ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType())); - // Prepare the weight data - reorganized_w_data = GetScratchBuffer(w_size * sizeof(T), ort_stream); + reorganized_w_data_size_in_bytes = w_size * sizeof(T); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); // In many cases, this allocation is bigger than needed, leaving part of - // the buffer unintialized. non-zero garbage data leads to wrong result + // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); @@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons auto* ort_cuda_stream = dynamic_cast(ort_stream); cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle(); - ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc, - reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); + ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, + reorganized_w_data_size_in_bytes, reorganized_w_data.get(), + W_data, R_data, B_data, cuda_stream)); return Status::OK(); } @@ -128,22 +140,31 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R); bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); + bool has_bias = B != nullptr; + if (get_W && get_R) { CudnnRNN tmp_rnn_desc; - ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(), + auto proj_size = hidden_size_; + ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); if (get_B) { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } else { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } cudaStreamSynchronize(nullptr); + weight_cached_ = true; } @@ -158,17 +179,72 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(nullptr != X); // optional inputs - const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); // [batch_size] - const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_] + // [batch_size] + const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); + // initial hidden. [num_directions_, batch_size, hidden_size_] + const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); const Tensor* initial_c(nullptr); if (rnn_mode_ == CUDNN_LSTM) { - initial_c = ctx->Input(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] + // initial cell. [num_directions_, batch_size, hidden_size_] + initial_c = ctx->Input(RNN_Input_Index::initial_c); } + size_t proj_size = hidden_size_; int64_t seq_length = X->Shape()[0]; int64_t batch_size = X->Shape()[1]; int64_t input_size = X->Shape()[2]; + // we thread a single input as sequence_lens of length 1, require to expand to [batch_size]? + std::vector sequence_lengths_temp; + if (!sequence_lens) { + sequence_lengths_temp.resize(batch_size, gsl::narrow_cast(seq_length)); + } + + const int32_t* sequence_lens_data = (sequence_lens == nullptr) + ? sequence_lengths_temp.data() + : sequence_lens->Data(); + + // cuDNN doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + int64_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + + CudaAsyncBuffer sequence_lens_buffer(this, batch_size); + int32_t* seq_len_array = sequence_lens_buffer.CpuPtr(); + + // 0-len sequences are not supported by cuDNN. + // Replace them by sequences of len 1 and mask them out with SetZeroSequences + for (int i = 0; i < batch_size; ++i) { + if (0 == sequence_lens_data[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } else { + seq_len_array[i] = sequence_lens_data[i]; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache.resize(zero_seq_count * num_directions_); + for (int64_t i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[static_cast(zero_seq_count) + i] = + static_cast(batch_size + zero_seq_index_cache[i]); + } + zero_seq_count *= num_directions_; + } + + // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must + // be copied to the GPU always. + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must + // be copied to the GPU only for the ReverseBySequence kernels. + // if (reverse_) { + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // } + // optional outputs TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_}); TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_}); @@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy); Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc); - std::vector dims_x({batch_size, input_size, 1}); - std::vector dims_y({batch_size, hidden_size_ * num_directions_, 1}); - - CudnnTensor x_desc_temp; - ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType())); - CudnnTensor y_desc_temp; - ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType())); - std::vector x_desc(seq_length, x_desc_temp); - std::vector y_desc(seq_length, y_desc_temp); - - CudnnTensor hx_desc; - CudnnTensor cx_desc; - CudnnTensor y_h_desc; - CudnnTensor y_c_desc; - ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - IAllocatorUniquePtr x_reversed_data; const T* x_data = X->Data(); if (reverse_) { @@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(input_size), reinterpret_cast(x_data), @@ -226,115 +284,82 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { y_data = y_alloc_data.get(); } - const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data(); + const Tensor* B = ctx->Input(RNN_Input_Index::B); + bool has_bias = B != nullptr; CudnnRNN rnn_desc; - ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx), + ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size, hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); // Prepare the weight data + size_t w_data_size_in_bytes = 0; IAllocatorUniquePtr w_data; CudnnFilterDescriptor w_desc; if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); const Tensor* B = ctx->Input(RNN_Input_Index::B); - ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, + rnn_desc, ctx->GetComputeStream())); } - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences - CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED)); + CudnnDataTensor x_desc1; + ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + input_size, seq_len_array)); + CudnnDataTensor y_desc1; + ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + ((rnn_mode_ == CUDNN_LSTM) ? proj_size : hidden_size_) * num_directions_, + seq_len_array)); - size_t workspace_bytes; - CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - int64_t zero_seq_count = 0; - std::vector zero_seq_index_cache(batch_size, 0); - int64_t zero_seq_index_cache_size = 0; - - if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(GetCudnnHandle(ctx), - rnn_desc, - gsl::narrow_cast(seq_length), - x_desc.data(), - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc.data(), - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - workspace_cuda.get(), - workspace_bytes)); - } else { - // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 - // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence - std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); - for (int i = 0; i < batch_size; ++i) { - if (0 == seq_len_array[i]) { - seq_len_array[i] = 1; - zero_seq_index_cache[zero_seq_count] = i; - ++zero_seq_count; - } - } + CudnnTensor cx_desc; + ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - // Calculate the zero position cache for reverse direction if it's bidirectional - // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since - // we hacked the 0 sequence to 1 - if (zero_seq_count && num_directions_ > 1) { - zero_seq_index_cache_size = zero_seq_count * num_directions_; - zero_seq_index_cache.resize(zero_seq_index_cache_size); - for (int64_t i = 0; i < zero_seq_count; ++i) { - zero_seq_index_cache[static_cast(zero_seq_count) + i] = static_cast(batch_size + zero_seq_index_cache[i]); - } - } + CudnnTensor hx_desc; + ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); + + // reserveSpaceSize is not required cudnnRNNForward, but returned by cudnnGetRNNTempSpaceSizes + size_t workspace_bytes, reservespace_bytes; - CudnnDataTensor x_desc1; - ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data())); - CudnnDataTensor y_desc1; - ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data())); - - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(GetCudnnHandle(ctx), - rnn_desc, - x_desc1, - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc1, - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_cuda.get(), - workspace_bytes)); - - // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. - if (nullptr == Y) { + CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, + x_desc1, &workspace_bytes, &reservespace_bytes)); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), + rnn_desc, + CUDNN_FWD_MODE_INFERENCE, + sequence_lens_buffer.GpuPtr(), // should be zero starting with cudnn 8.9.1 + x_desc1, + x_data_input, + y_desc1, + y_data, // output + hx_desc, + hx_data, // input + y_h_data, // output + cx_desc, cx_data, y_c_data, + weight_cached_ ? w_data_cache_size_in_bytes_ : w_data_size_in_bytes, + weight_cached_ ? w_data_cache_.get() : w_data.get(), + workspace_bytes, + workspace_cuda.get(), + reservespace_bytes, + reservespace_cuda.get())); + + // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, + // no need the following code to retrieve Y_h from Y data. + if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); - } - return Status::OK(); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } + return Status::OK(); } IAllocatorUniquePtr y_reorganized_data; @@ -345,6 +370,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // reverse output data ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), @@ -361,8 +387,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } if (Y != nullptr) { - // User specified this optional output, so need to copy the reversed data to orignial place - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); + // User specified this optional output, so need to copy the reversed data to original place + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), + cudaMemcpyDeviceToDevice, Stream(ctx))); } else { y_data = y_reorganized_data.get(); } @@ -370,23 +397,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } - if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - CudaAsyncBuffer sequence_lens_buffer(this, batch_size); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); - RnnMaskImpl(Stream(ctx), - gsl::narrow_cast(num_directions_), - gsl::narrow_cast(seq_length), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), - sequence_lens_buffer.GpuPtr(), - reinterpret_cast(y_data), - reinterpret_cast(y_h_data), - output_size); - } return Status::OK(); } @@ -399,7 +412,8 @@ void CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, onnxruntime::Stream* ort_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); - memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), + zero_seq_index_cache_size * sizeof(int32_t)); ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; MaskZeroSequences(cuda_stream, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 1c9483b2afd38..0fa01d3486e99 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -38,26 +38,28 @@ class CudnnRNN { } } - Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers, + Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, - cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) { + cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) { if (!cudnn_rnn_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); - CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle, - cudnn_rnn_desc_, + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_, + CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC + rnn_mode, + has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS, + cudnn_direction_model, + CUDNN_LINEAR_INPUT, + dataType, + dataType, + dataType == CUDNN_DATA_HALF ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH, + gsl::narrow_cast(input_size), gsl::narrow_cast(hidden_size), + gsl::narrow_cast(proj_size), // projected size num_layers, cudnn_dropout_desc, - CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation - cudnn_direction_model, - rnn_mode, - CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC - dataType)); - - if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) { - cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH); - } + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + CUDNN_RNN_PADDED_IO_ENABLED)); return Status::OK(); } @@ -119,8 +121,7 @@ class CudnnRnnBase : public CudaKernel { private: Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t w_data_size, void* w_data, const T* W_data, const T* R_data, @@ -128,23 +129,22 @@ class CudnnRnnBase : public CudaKernel { cudaStream_t cuda_stream) const; Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& target_w_data_size_in_bytes, IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const; - void SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const; + Status SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t w_data_size, + const void* w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const; void SetZeroSequences(const int64_t zero_seq_index_cache_size, const std::vector zero_seq_index_cache, @@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel { cudnnRNNMode_t rnn_mode_; // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input CudnnFilterDescriptor w_desc_cache_; + size_t w_data_cache_size_in_bytes_; IAllocatorUniquePtr w_data_cache_; bool weight_cached_; int64_t layout_; diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index 4bd22340ef2bb..ed8be63679707 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared_library/provider_api.h" #include "rnn.h" + +#include "core/providers/shared_library/provider_api.h" #include "rnn_impl.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h index e4e50046b3725..6221afb003b22 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn.h @@ -4,6 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" + #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index d485855ddb417..94c8036be6cdf 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -8,22 +8,32 @@ namespace onnxruntime { namespace cuda { template -__global__ void _ReverseBySequenceKernel(const int32_t seq_length, +__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t block_size, const fast_divmod div_batch_block, + const fast_divmod div_input_or_hidden_size, const T* data, T* reversed_data, const CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); int seq_id, offset; div_batch_block.divmod(id, seq_id, offset); - int org_id = (seq_length - seq_id - 1) * block_size + offset; - reversed_data[id] = data[org_id]; + int batch, batch_offset; + div_input_or_hidden_size.divmod(offset, batch, batch_offset); + int seq_id_org = seq_lengths[batch] - seq_id - 1; + if (seq_id_org >= 0) { + int org_id = seq_id_org * block_size + offset; + reversed_data[id] = data[org_id]; + } else { + reversed_data[id] = T{}; + } } template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t *seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream, // kerneral int32_t block_size = batch_size * input_or_hidden_size; fast_divmod div_batch_block(block_size); + fast_divmod div_input_or_hidden_size(input_or_hidden_size); int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); _ReverseBySequenceKernel<<>>( - seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N); + max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N); } template @@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, data, reordered_data, (CUDA_LONG)N); } -template -__global__ void _RnnMaskKernel(const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - const fast_divmod div_seq_block, - const fast_divmod div_dir_block, - const fast_divmod div_batch_block, - T* y_output_data, - T* y_h_output_data, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - - int seq_id, direction_id, batch_id, offset; - div_seq_block.divmod(id, seq_id, offset); - div_dir_block.divmod(offset, direction_id, offset); - div_batch_block.divmod(offset, batch_id, offset); - int32_t batch_seq_length = sequence_lens[batch_id]; - - if (batch_id >= batch_size || batch_seq_length == seq_length) { - return; - } - - if (seq_id >= batch_seq_length) { - y_output_data[id] = 0; - return; - } - - if ((y_h_output_data != nullptr) && - ((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) { - int hy_idx = direction_id * batch_size * hidden_size + batch_id * hidden_size + offset; - y_h_output_data[hy_idx] = y_output_data[id]; - } -} - -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N) { - fast_divmod div_seq_block(batch_size * hidden_size * num_directions); - fast_divmod div_dir_block(batch_size * hidden_size); - fast_divmod div_batch_block(hidden_size); - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _RnnMaskKernel<<>>( - seq_length, batch_size, hidden_size, sequence_lens, div_seq_block, - div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); -} - template __global__ void _MaskZeroSequences(const int32_t hidden_size, T* y_output_data, @@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream, } #define SPECIALIZED_RNN_IMPL(T) \ - template void RnnMaskImpl(cudaStream_t stream, \ - const int32_t num_directions, \ - const int32_t seq_length, \ - const int32_t batch_size, \ - const int32_t hidden_size, \ - const int32_t* sequence_lens, \ - T* y_output_data, \ - T* y_h_output_data, \ - const size_t N); \ - template void ReverseBySequence(cudaStream_t stream, \ - const int32_t seq_length, \ + template void ReverseBySequence(cudaStream_t stream, \ + const int32_t max_seq_length, \ + const int32_t* seq_lengths, \ const int32_t batch_size, \ const int32_t hidden_size, \ const T* data, \ @@ -203,7 +152,7 @@ void MaskZeroSequences(cudaStream_t stream, const T* data, \ T* reordered_data, \ const size_t N); \ -template void MaskZeroSequences(cudaStream_t stream, \ +template void MaskZeroSequences(cudaStream_t stream, \ const int32_t hidden_size, \ T* y_output_data, \ T* y_h_output_data, \ diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index 9844e04ff6ec5..ba876011f6b67 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -10,7 +10,8 @@ namespace cuda { template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, T* reordered_data, const size_t N); -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N); - template void MaskZeroSequences(cudaStream_t stream, const int32_t hidden_size, diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index b9875b9553a55..1a31743e2f7e7 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // TensorRT failed on RNN tests - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) { -#else TEST(RNNTest, RNN_bidirectional_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5; @@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) { std::vector Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F}); test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) { -#else TEST(RNNTest, RNN_forward_direction_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5; @@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_0) { -#else TEST(RNNTest, RNN_bidirectional_0) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; @@ -424,12 +411,7 @@ TEST(RNNTest, RNN_bidirectional_0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_1) { -#else TEST(RNNTest, RNN_bidirectional_1) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1; @@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) { } } -TEST(RNNTest, DISABLED_RNN_reverse_direction) { +TEST(RNNTest, RNN_reverse_direction) { int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; // In case of useDefault, attributes, inputs or outputs are not set.