From 1a2946091968fad57e52dd632967a870e0265b06 Mon Sep 17 00:00:00 2001
From: kailums <109063327+kailums@users.noreply.github.com>
Date: Fri, 17 Nov 2023 20:38:15 +0800
Subject: [PATCH] rope support 4D input tensor (#18454)
### Description
change RotaryEmbeddings op implementation, add support for 4D input
tensor that is with shape of [batch, num_heads, seq_len, head_size].
### Motivation and Context
Current RotaryEmbedding op only support 3d input tensor with shape
[batch, seq_len, hidden_size]
For llamav2 model, when using FusionRotaryEmbeddings to only fuse
RotaryEmbeddings op, there will be a transpose operation for query and
key, and then the input tensor of RotaryEmbeddings becomes 4D [batch,
num_heads, seq_len, head_size].
This scenario can't be supported by current RotaryEmbeddings
implementation. So it needs to support 4D input tensor.
---
docs/ContribOperators.md | 4 +-
.../contrib_ops/cpu/bert/rotary_embedding.cc | 17 +++++--
.../cpu/bert/rotary_embedding_helper.h | 16 +++++--
.../contrib_ops/cuda/bert/rotary_embedding.cc | 3 +-
.../cuda/bert/rotary_embedding_impl.cu | 35 ++++++++++----
.../cuda/bert/rotary_embedding_impl.h | 3 +-
.../core/graph/contrib_ops/bert_defs.cc | 4 +-
.../test_parity_rotary_embedding.py | 47 +++++++++++++++++--
8 files changed, 103 insertions(+), 26 deletions(-)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index da900e5c59405..8565ffbb6c379 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5023,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- input : T
-- 3D tensor with shape (batch_size, sequence_length, hidden_size)
+- 3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)
- position_ids : M
- 1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
- cos_cache : T
@@ -5036,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- output : T
-- 3D tensor with shape (batch_size, sequence_length, hidden_size)
+- tensor with same shape as input.
#### Type Constraints
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 4a266af789250..47f462d75fcc4 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -63,6 +63,16 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int half_head_size = head_size / 2;
+ // Default input tensor shape is [batch, seq_len, hidden_size]
+ int head_stride = head_size;
+ int seq_stride = num_heads * head_stride;
+ int batch_stride = sequence_length * seq_stride;
+ if (parameters.transposed) {
+ // Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
+ seq_stride = head_size;
+ head_stride = sequence_length * seq_stride;
+ batch_stride = num_heads * head_stride;
+ }
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@@ -76,11 +86,10 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
const int s = static_cast((ptr / num_heads) % sequence_length);
const int n = static_cast(ptr % num_heads);
- const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
- const int data_offset = block_offset * head_size;
+ const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
- const T* input_data = input_src + data_offset;
- T* output_data = output_dest + data_offset;
+ const T* input_data = input_src + block_offset;
+ T* output_data = output_dest + block_offset;
// Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
index cf8080800e072..7b2e8289f7b06 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -18,6 +18,7 @@ struct RotaryParameters {
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+ bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};
template
@@ -33,8 +34,8 @@ Status CheckInputs(const T* input,
// Check input
const auto& input_dims = input->Shape().GetDims();
- if (input_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ",
+ if (input_dims.size() != 3 && input_dims.size() != 4) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
input_dims.size());
}
// Check position_ids
@@ -63,6 +64,14 @@ Status CheckInputs(const T* input,
int batch_size = static_cast(input_dims[0]);
int sequence_length = static_cast(input_dims[1]);
int hidden_size = static_cast(input_dims[2]);
+
+ bool transposed = false;
+ if (input_dims.size() == 4) {
+ // input is [batch, num_heads, seq, head_size]
+ sequence_length = static_cast(input_dims[2]);
+ hidden_size = static_cast(input_dims[1]) * static_cast(input_dims[3]);
+ transposed = true;
+ }
int max_sequence_length = static_cast(cos_cache_dims[0]);
int head_size = static_cast(cos_cache_dims[1]) * 2;
int num_heads = hidden_size / head_size;
@@ -111,6 +120,7 @@ Status CheckInputs(const T* input,
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
+ output_parameters->transposed = transposed;
}
return Status::OK();
@@ -118,4 +128,4 @@ Status CheckInputs(const T* input,
} // namespace rotary_embedding_helper
} // namespace contrib
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index b4b5dac1fbe19..2d12e975d88d7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -74,7 +74,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const {
parameters.max_sequence_length,
parameters.position_ids_format,
interleaved,
- device_prop.maxThreadsPerBlock);
+ device_prop.maxThreadsPerBlock,
+ parameters.transposed);
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
index c54e72dcfce13..e1b83bd8caf54 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int num_heads,
const int head_size,
const int position_ids_format,
- const bool interleaved) {
+ const bool interleaved,
+ const int batch_stride,
+ const int seq_stride,
+ const int head_stride) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently
@@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int i = threadIdx.x;
- const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
- const int data_offset = block_offset * head_size;
+ const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
- const T* input_data = input + data_offset;
- T* output_data = output + data_offset;
+ const T* input_data = input + block_offset;
+ T* output_data = output + block_offset;
// Cache is (M, H/2)
const int half_head_size = head_size / 2;
@@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
- const int max_threads_per_block) {
+ const int max_threads_per_block,
+ const bool transposed) {
constexpr int smem_size = 0;
const dim3 grid(num_heads, sequence_length, batch_size);
@@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel(
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
+ // Default input tensor shape is [batch, seq, hidden_size]
+ int head_stride = head_size;
+ int seq_stride = num_heads * head_stride;
+ int batch_stride = sequence_length * seq_stride;
+ if (transposed) {
+ // When transposed, input tensor shape is [batch, num_heads, seq, head_size]
+ seq_stride = head_size;
+ head_stride = sequence_length * seq_stride;
+ batch_stride = num_heads * head_stride;
+ }
+
assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<>>(
output, input, cos_cache, sin_cache, position_ids,
- sequence_length, num_heads, head_size, position_ids_format, interleaved
+ sequence_length, num_heads, head_size, position_ids_format, interleaved,
+ batch_stride, seq_stride, head_stride
);
return CUDA_CALL(cudaGetLastError());
@@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
- const int max_threads_per_block);
+ const int max_threads_per_block,
+ const bool transposed);
template Status LaunchRotaryEmbeddingKernel(
cudaStream_t stream,
@@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
- const int max_threads_per_block);
+ const int max_threads_per_block,
+ const bool transposed);
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
index 29ff48a8ad0fb..ee1ccc43dcbff 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
- const int max_threads_per_block);
+ const int max_threads_per_block,
+ const bool transposed);
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index a99bb36984538..b97fb0d2899fc 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1144,7 +1144,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OPTIONAL_VALUE)
.Input(0,
"input",
- "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+ "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
"T")
.Input(1,
"position_ids",
@@ -1160,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T")
.Output(0,
"output",
- "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+ "tensor with same shape as input.",
"T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
index b17ae5f69aff5..cf8128e0eebcf 100644
--- a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
+++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py
@@ -261,14 +261,15 @@ def get_eps(self):
eps = ["CPUExecutionProvider", "CUDAExecutionProvider"]
return list(filter(lambda ep: ep in ort.get_available_providers(), eps))
- def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh):
+ def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh, transposed=False):
eps = self.get_eps()
for ep in eps:
sess = ort.InferenceSession(onnx_graph, providers=[ep])
output_ort = sess.run(None, inputs_ort)[0]
- output_ort = output_ort.reshape(
- (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size)
- )
+ if not transposed:
+ output_ort = output_ort.reshape(
+ (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size)
+ )
# Compare outputs as BxSxNxH
self.assertTrue(np.allclose(expected_output_bsnh, output_ort))
@@ -445,6 +446,44 @@ def test_hf_token_rotary_one_pos_id(self):
# Compare outputs as BxSxNxH
self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy())
+ # Bonus test: Prompt step, interleaved = false, pos ids shape = (1), transposed
+ def test_hf_prompt_rotary_one_pos_id_transposed(self):
+ x_bnsh = torch.randn(
+ self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size
+ )
+ cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length)
+ pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)])
+ output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH
+
+ cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+ pos_ms = torch.tensor([0])
+ onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
+ inputs_ort = {
+ "input": x_bnsh.detach().cpu().numpy(),
+ "position_ids": pos_ms.detach().cpu().numpy(),
+ }
+
+ # Compare outputs as BxNxSxH
+ self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True)
+
+ # Bonus test: Token generation step, interleaved = false, pos ids shape = (1), transposed
+ def test_hf_token_rotary_one_pos_id_transposed(self):
+ x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size)
+ cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length)
+ pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)])
+ output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxSxNxH
+
+ cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
+ pos_ms = torch.tensor([2])
+ onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
+ inputs_ort = {
+ "input": x_bnsh.detach().cpu().numpy(),
+ "position_ids": pos_ms.detach().cpu().numpy(),
+ }
+
+ # Set tranposed=True to compare outputs as BxSxNxH
+ self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True)
+
if __name__ == "__main__":
unittest.main()