From dd9933d0c11494fc5184a30dcf9e5c9a2cf84ec3 Mon Sep 17 00:00:00 2001
From: kailums <109063327+kailums@users.noreply.github.com>
Date: Fri, 17 Nov 2023 20:38:15 +0800
Subject: [PATCH] rope support 4D input tensor (#18454)

### Description
Change the RotaryEmbedding op implementation to support a 4D input
tensor with shape [batch, num_heads, seq_len, head_size].

### Motivation and Context
The current RotaryEmbedding op only supports a 3D input tensor with
shape [batch, seq_len, hidden_size].

For the LLaMA-2 model, when FusionRotaryEmbeddings is used to fuse only
the RotaryEmbedding op, a transpose is applied to query and key first,
so the input tensor of RotaryEmbedding becomes 4D with shape
[batch, num_heads, seq_len, head_size]. The current implementation
cannot handle this layout, so 4D input support is added.
---
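For context, a minimal PyTorch sketch of the scenario (illustrative only; `apply_rope_4d` and the dimension values are hypothetical, not ORT code):

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) on the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_4d(x_bnsh, cos, sin, position_ids):
    # x_bnsh: [batch, num_heads, seq_len, head_size], the transposed layout
    # cos/sin: [max_seq_len, head_size // 2]; position_ids: [batch, seq_len]
    cos = cos[position_ids].repeat(1, 1, 2).unsqueeze(1)  # [batch, 1, seq_len, head_size]
    sin = sin[position_ids].repeat(1, 1, 2).unsqueeze(1)
    return x_bnsh * cos + rotate_half(x_bnsh) * sin

# After attention fusion, query/key are viewed and transposed before RoPE,
# so RotaryEmbedding sees a 4D tensor instead of [batch, seq_len, hidden_size]:
batch, num_heads, seq_len, head_size = 2, 4, 8, 16
q = torch.randn(batch, seq_len, num_heads * head_size)                 # 3D layout
q_bnsh = q.view(batch, seq_len, num_heads, head_size).transpose(1, 2)  # 4D layout
```

Only the memory layout differs between the two cases, which is why the changes below switch strides rather than transposing the data back.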
 docs/ContribOperators.md                      |  4 +-
 .../contrib_ops/cpu/bert/rotary_embedding.cc  | 17 +++++--
 .../cpu/bert/rotary_embedding_helper.h        | 16 +++++--
 .../contrib_ops/cuda/bert/rotary_embedding.cc |  3 +-
 .../cuda/bert/rotary_embedding_impl.cu        | 35 ++++++++++----
 .../cuda/bert/rotary_embedding_impl.h         |  3 +-
 .../core/graph/contrib_ops/bert_defs.cc       |  4 +-
 .../test_parity_rotary_embedding.py           | 47 +++++++++++++++++--
 8 files changed, 103 insertions(+), 26 deletions(-)

diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index da900e5c59405..8565ffbb6c379 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5023,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
 <dl>
 <dt><tt>input</tt> : T</dt>
-<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D tensor with shape (batch_size, num_heads, sequence_length, head_size)</dd>
 <dt><tt>position_ids</tt> : M</dt>
 <dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
 <dt><tt>cos_cache</tt> : T</dt>
@@ -5036,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 
 <dl>
 <dt><tt>output</tt> : T</dt>
-<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
+<dd>Tensor with the same shape as the input.</dd>
 </dl>
 
 #### Type Constraints
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 4a266af789250..47f462d75fcc4 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -63,6 +63,16 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
   const int head_size = parameters.head_size;
   const int position_ids_format = parameters.position_ids_format;
   const int half_head_size = head_size / 2;
+  // Default input tensor shape is [batch, seq_len, hidden_size]
+  int head_stride = head_size;
+  int seq_stride = num_heads * head_stride;
+  int batch_stride = sequence_length * seq_stride;
+  if (parameters.transposed) {
+    // Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
+    seq_stride = head_size;
+    head_stride = sequence_length * seq_stride;
+    batch_stride = num_heads * head_stride;
+  }
 
   AllocatorPtr allocator;
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@@ -76,11 +86,10 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
       const int s = static_cast<int>((ptr / num_heads) % sequence_length);
       const int n = static_cast<int>(ptr % num_heads);
 
-      const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
-      const int data_offset = block_offset * head_size;
+      const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
 
-      const T* input_data = input_src + data_offset;
-      T* output_data = output_dest + data_offset;
+      const T* input_data = input_src + block_offset;
+      T* output_data = output_dest + block_offset;
 
       // Cache is (M, H/2)
       const int position_id = (position_ids_format == 0)
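The stride selection above can be sanity-checked outside ORT. A NumPy sketch (the `block_offset` helper is hypothetical, mirroring the arithmetic in Compute()):

```python
import numpy as np

batch, num_heads, seq_len, head_size = 2, 3, 4, 5

def block_offset(b, s, n, transposed):
    # Pick strides per layout, then locate the head vector for
    # (batch b, position s, head n), exactly as the kernel does.
    if transposed:  # [batch, num_heads, seq_len, head_size]
        seq_stride = head_size
        head_stride = seq_len * seq_stride
        batch_stride = num_heads * head_stride
    else:           # [batch, seq_len, num_heads * head_size]
        head_stride = head_size
        seq_stride = num_heads * head_stride
        batch_stride = seq_len * seq_stride
    return b * batch_stride + s * seq_stride + n * head_stride

flat = np.arange(batch * num_heads * seq_len * head_size)
bsnh = flat.reshape(batch, seq_len, num_heads, head_size)
bnsh = flat.reshape(batch, num_heads, seq_len, head_size)
assert bsnh[1, 2, 0, 0] == flat[block_offset(1, 2, 0, transposed=False)]
assert bnsh[1, 0, 2, 0] == flat[block_offset(1, 2, 0, transposed=True)]
```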
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
index cf8080800e072..7b2e8289f7b06 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -18,6 +18,7 @@ struct RotaryParameters {
   int num_heads;            // num_heads = hidden_size / head_size
   int max_sequence_length;  // Sequence length used by cos/sin cache
   int position_ids_format;  // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+  bool transposed;          // Whether the input tensor has been transposed into (batch, num_heads, seq_len, head_size)
 };
 template <typename T>
 Status CheckInputs(const T* input,
@@ -33,8 +34,8 @@ Status CheckInputs(const T* input,
 
   // Check input
   const auto& input_dims = input->Shape().GetDims();
-  if (input_dims.size() != 3) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ",
+  if (input_dims.size() != 3 && input_dims.size() != 4) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
                            input_dims.size());
   }
   // Check position_ids
@@ -63,6 +64,14 @@ Status CheckInputs(const T* input,
   int batch_size = static_cast<int>(input_dims[0]);
   int sequence_length = static_cast<int>(input_dims[1]);
   int hidden_size = static_cast<int>(input_dims[2]);
+
+  bool transposed = false;
+  if (input_dims.size() == 4) {
+    // input is [batch, num_heads, seq, head_size]
+    sequence_length = static_cast<int>(input_dims[2]);
+    hidden_size = static_cast<int>(input_dims[1]) * static_cast<int>(input_dims[3]);
+    transposed = true;
+  }
   int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
   int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
   int num_heads = hidden_size / head_size;
@@ -111,6 +120,7 @@ Status CheckInputs(const T* input,
     output_parameters->num_heads = num_heads;
     output_parameters->max_sequence_length = max_sequence_length;
     output_parameters->position_ids_format = position_ids_format;
+    output_parameters->transposed = transposed;
   }
 
   return Status::OK();
@@ -118,4 +128,4 @@ Status CheckInputs(const T* input,
 
 }  // namespace rotary_embedding_helper
 }  // namespace contrib
-}  // namespace onnxruntime
\ No newline at end of file
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index b4b5dac1fbe19..2d12e975d88d7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -74,7 +74,8 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
       parameters.max_sequence_length,
       parameters.position_ids_format,
       interleaved,
-      device_prop.maxThreadsPerBlock);
+      device_prop.maxThreadsPerBlock,
+      parameters.transposed);
 
   return Status::OK();
 }
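For reference, the per-vector rotation itself is unchanged by this patch; only the addressing differs. A NumPy sketch of the non-interleaved case (illustrative, not the kernel code):

```python
import numpy as np

def rotate_one_head(x, cos_m, sin_m):
    # x: one head vector [head_size]; cos_m/sin_m: cache row [head_size // 2]
    # Non-interleaved: element i pairs with element i + head_size // 2.
    half = x.shape[0] // 2
    out = np.empty_like(x)
    out[:half] = x[:half] * cos_m - x[half:] * sin_m
    out[half:] = x[half:] * cos_m + x[:half] * sin_m
    return out
```

Each thread block in the kernel below handles one (batch, position, head) vector (the grid is (num_heads, sequence_length, batch_size)), so only the block-offset computation needed to change.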
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
index c54e72dcfce13..e1b83bd8caf54 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output,  // BxSxNxH
                                     const int num_heads,
                                     const int head_size,
                                     const int position_ids_format,
-                                    const bool interleaved) {
+                                    const bool interleaved,
+                                    const int batch_stride,
+                                    const int seq_stride,
+                                    const int head_stride) {
   // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
   // Use .x in innermost loop to access global memory efficiently
 
@@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output,  // BxSxNxH
 
   const int i = threadIdx.x;
 
-  const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
-  const int data_offset = block_offset * head_size;
+  const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
 
-  const T* input_data = input + data_offset;
-  T* output_data = output + data_offset;
+  const T* input_data = input + block_offset;
+  T* output_data = output + block_offset;
 
   // Cache is (M, H/2)
   const int half_head_size = head_size / 2;
@@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel(
     const int max_sequence_length,
     const int position_ids_format,
     const bool interleaved,
-    const int max_threads_per_block) {
+    const int max_threads_per_block,
+    const bool transposed) {
 
   constexpr int smem_size = 0;
   const dim3 grid(num_heads, sequence_length, batch_size);
@@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel(
   // and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
   // instead. This will require kernel changes to support.
 
+  // Default input tensor shape is [batch, seq, hidden_size]
+  int head_stride = head_size;
+  int seq_stride = num_heads * head_stride;
+  int batch_stride = sequence_length * seq_stride;
+  if (transposed) {
+    // When transposed, input tensor shape is [batch, num_heads, seq, head_size]
+    seq_stride = head_size;
+    head_stride = sequence_length * seq_stride;
+    batch_stride = num_heads * head_stride;
+  }
+
   assert(head_size <= max_threads_per_block);
   RotaryEmbeddingBSNH<<<grid, block, smem_size, stream>>>(
       output, input, cos_cache, sin_cache, position_ids,
-      sequence_length, num_heads, head_size, position_ids_format, interleaved
+      sequence_length, num_heads, head_size, position_ids_format, interleaved,
+      batch_stride, seq_stride, head_stride
   );
 
   return CUDA_CALL(cudaGetLastError());
@@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel<float>(
     const int max_sequence_length,
     const int position_ids_format,
     const bool interleaved,
-    const int max_threads_per_block);
+    const int max_threads_per_block,
+    const bool transposed);
 
 template Status LaunchRotaryEmbeddingKernel<half>(
     cudaStream_t stream,
@@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel<half>(
     const int max_sequence_length,
     const int position_ids_format,
     const bool interleaved,
-    const int max_threads_per_block);
+    const int max_threads_per_block,
+    const bool transposed);
 
 }  // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
index 29ff48a8ad0fb..ee1ccc43dcbff 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel(
     const int max_sequence_length,
     const int position_ids_format,
     const bool interleaved,
-    const int max_threads_per_block);
+    const int max_threads_per_block,
+    const bool transposed);
 
 }  // namespace cuda
 }  // namespace contrib
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index a99bb36984538..b97fb0d2899fc 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1144,7 +1144,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 OPTIONAL_VALUE)
         .Input(0,
                "input",
-               "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+               "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D tensor with shape (batch_size, num_heads, sequence_length, head_size)",
                "T")
         .Input(1,
                "position_ids",
@@ -1160,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                "T")
         .Output(0,
                 "output",
-                "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+                "Tensor with the same shape as the input.",
                 "T")
         .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
         .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
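With the schema change above, a graph can feed the transposed layout directly. A sketch using the onnx helper API (shapes, names, and opset versions are illustrative and may need adjusting):

```python
from onnx import TensorProto, helper

batch, num_heads, seq_len, head_size = 2, 4, 8, 16
half = head_size // 2

# One RotaryEmbedding node from the com.microsoft contrib domain,
# fed the 4D (batch, num_heads, seq_len, head_size) layout directly.
node = helper.make_node(
    "RotaryEmbedding",
    inputs=["input", "position_ids", "cos_cache", "sin_cache"],
    outputs=["output"],
    domain="com.microsoft",
    interleaved=0,
)
graph = helper.make_graph(
    [node],
    "rope_4d",
    [
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, num_heads, seq_len, head_size]),
        helper.make_tensor_value_info("position_ids", TensorProto.INT64, [1]),
        helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, [seq_len, half]),
        helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, [seq_len, half]),
    ],
    # Output keeps the input's 4D shape, per the schema change above.
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, num_heads, seq_len, head_size])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
```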
diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
index b17ae5f69aff5..cf8128e0eebcf 100644
--- a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
+++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
@@ -261,14 +261,15 @@ def get_eps(self):
         eps = ["CPUExecutionProvider", "CUDAExecutionProvider"]
         return list(filter(lambda ep: ep in ort.get_available_providers(), eps))
 
-    def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh):
+    def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh, transposed=False):
         eps = self.get_eps()
         for ep in eps:
             sess = ort.InferenceSession(onnx_graph, providers=[ep])
             output_ort = sess.run(None, inputs_ort)[0]
-            output_ort = output_ort.reshape(
-                (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size)
-            )
+            if not transposed:
+                output_ort = output_ort.reshape(
+                    (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size)
+                )
 
             # Compare outputs as BxSxNxH
             self.assertTrue(np.allclose(expected_output_bsnh, output_ort))
@@ -445,6 +446,44 @@ def test_hf_token_rotary_one_pos_id(self):
         # Compare outputs as BxSxNxH
         self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy())
 
+    # Bonus test: Prompt step, interleaved = false, pos ids shape = (1), transposed
+    def test_hf_prompt_rotary_one_pos_id_transposed(self):
+        x_bnsh = torch.randn(
+            self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size
+        )
+        cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length)
+        pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)])
+        output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf)  # output is BxNxSxH
+
+        cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+        pos_ms = torch.tensor([0])
+        onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
+        inputs_ort = {
+            "input": x_bnsh.detach().cpu().numpy(),
+            "position_ids": pos_ms.detach().cpu().numpy(),
+        }
+
+        # Compare outputs as BxNxSxH
+        self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True)
+
+    # Bonus test: Token generation step, interleaved = false, pos ids shape = (1), transposed
+    def test_hf_token_rotary_one_pos_id_transposed(self):
+        x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size)
+        cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length)
+        pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)])
+        output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids)  # output is BxNxSxH
+
+        cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+        pos_ms = torch.tensor([2])
+        onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
+        inputs_ort = {
+            "input": x_bnsh.detach().cpu().numpy(),
+            "position_ids": pos_ms.detach().cpu().numpy(),
+        }
+
+        # Set transposed=True to compare outputs as BxNxSxH
+        self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True)
+
 
 if __name__ == "__main__":
     unittest.main()
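The new path can also be exercised end to end in the spirit of these tests. A hedged sketch that reuses `model` from the previous example and `apply_rope_4d` from the first one (cache construction and tolerance are illustrative; this assumes the op's non-interleaved convention matches the reference):

```python
import numpy as np
import onnxruntime as ort
import torch

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])

x = np.random.randn(batch, num_heads, seq_len, head_size).astype(np.float32)
inv_freq = 1.0 / (10000.0 ** (np.arange(0, head_size, 2) / head_size))
angles = np.outer(np.arange(seq_len), inv_freq)  # [seq_len, head_size // 2]
cos_cache = np.cos(angles).astype(np.float32)
sin_cache = np.sin(angles).astype(np.float32)

# position_ids of shape (1): prompt step, positions are 0..seq_len-1
ort_out = sess.run(None, {
    "input": x,
    "position_ids": np.array([0], dtype=np.int64),
    "cos_cache": cos_cache,
    "sin_cache": sin_cache,
})[0]

pos_ids = torch.arange(seq_len).repeat(batch, 1)  # [batch, seq_len]
ref = apply_rope_4d(torch.from_numpy(x), torch.from_numpy(cos_cache),
                    torch.from_numpy(sin_cache), pos_ids)
assert np.allclose(ort_out, ref.numpy(), atol=1e-5)
```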