Skip to content

Commit

Permalink
rope support 4D input tensor (microsoft#18454)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

change RotaryEmbeddings op implementation, add support for 4D input
tensor that is with shape of [batch, num_heads, seq_len, head_size].

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Current RotaryEmbedding op only support 3d input tensor with shape
[batch, seq_len, hidden_size]

For llamav2 model, when using FusionRotaryEmbeddings to only fuse
RotaryEmbeddings op, there will be a transpose operation for query and
key, and then the input tensor of RotaryEmbeddings becomes 4D [batch,
num_heads, seq_len, head_size].

This scenario can't be supported by current RotaryEmbeddings
implementation. So it needs to support 4D input tensor.
  • Loading branch information
kailums authored and kleiti committed Mar 22, 2024
1 parent 512a5da commit dd9933d
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 26 deletions.
4 changes: 2 additions & 2 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -5023,7 +5023,7 @@ This version of the operator has been available since version 1 of the 'com.micr

<dl>
<dt><tt>input</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)</dd>
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
Expand All @@ -5036,7 +5036,7 @@ This version of the operator has been available since version 1 of the 'com.micr

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dd>tensor with same shape as input.</dd>
</dl>

#### Type Constraints
Expand Down
17 changes: 13 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int half_head_size = head_size / 2;
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
// Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
}

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
Expand All @@ -76,11 +86,10 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const int s = static_cast<int>((ptr / num_heads) % sequence_length);
const int n = static_cast<int>(ptr % num_heads);

const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
const int data_offset = block_offset * head_size;
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

const T* input_data = input_src + data_offset;
T* output_data = output_dest + data_offset;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;

// Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
Expand Down
16 changes: 13 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct RotaryParameters {
int num_heads; // num_heads = hidden_size / head_size
int max_sequence_length; // Sequence length used by cos/sin cache
int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};

template <typename T>
Expand All @@ -33,8 +34,8 @@ Status CheckInputs(const T* input,

// Check input
const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ",
if (input_dims.size() != 3 && input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 or 4 dimensions, got ",
input_dims.size());
}
// Check position_ids
Expand Down Expand Up @@ -63,6 +64,14 @@ Status CheckInputs(const T* input,
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);

bool transposed = false;
if (input_dims.size() == 4) {
// input is [batch, num_heads, seq, head_size]
sequence_length = static_cast<int>(input_dims[2]);
hidden_size = static_cast<int>(input_dims[1]) * static_cast<int>(input_dims[3]);
transposed = true;
}
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
int num_heads = hidden_size / head_size;
Expand Down Expand Up @@ -111,11 +120,12 @@ Status CheckInputs(const T* input,
output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
}

return Status::OK();
}

} // namespace rotary_embedding_helper
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
parameters.max_sequence_length,
parameters.position_ids_format,
interleaved,
device_prop.maxThreadsPerBlock);
device_prop.maxThreadsPerBlock,
parameters.transposed);

return Status::OK();
}
Expand Down
35 changes: 26 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int num_heads,
const int head_size,
const int position_ids_format,
const bool interleaved) {
const bool interleaved,
const int batch_stride,
const int seq_stride,
const int head_stride) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently

Expand All @@ -37,11 +40,10 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH

const int i = threadIdx.x;

const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
const int data_offset = block_offset * head_size;
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

const T* input_data = input + data_offset;
T* output_data = output + data_offset;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;

// Cache is (M, H/2)
const int half_head_size = head_size / 2;
Expand Down Expand Up @@ -83,7 +85,8 @@ Status LaunchRotaryEmbeddingKernel(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block) {
const int max_threads_per_block,
const bool transposed) {

constexpr int smem_size = 0;
const dim3 grid(num_heads, sequence_length, batch_size);
Expand All @@ -94,10 +97,22 @@ Status LaunchRotaryEmbeddingKernel(
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.

// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (transposed) {
// When transposed, input tensor shape is [batch, num_heads, seq, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
}

assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, smem_size, stream>>>(
output, input, cos_cache, sin_cache, position_ids,
sequence_length, num_heads, head_size, position_ids_format, interleaved
sequence_length, num_heads, head_size, position_ids_format, interleaved,
batch_stride, seq_stride, head_stride
);

return CUDA_CALL(cudaGetLastError());
Expand All @@ -117,7 +132,8 @@ template Status LaunchRotaryEmbeddingKernel<float>(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block);
const int max_threads_per_block,
const bool transposed);

template Status LaunchRotaryEmbeddingKernel<half>(
cudaStream_t stream,
Expand All @@ -133,7 +149,8 @@ template Status LaunchRotaryEmbeddingKernel<half>(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block);
const int max_threads_per_block,
const bool transposed);


} // namespace cuda
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Status LaunchRotaryEmbeddingKernel(
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block);
const int max_threads_per_block,
const bool transposed);

} // namespace cuda
} // namespace contrib
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OPTIONAL_VALUE)
.Input(0,
"input",
"3D tensor with shape (batch_size, sequence_length, hidden_size)",
"3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
"T")
.Input(1,
"position_ids",
Expand All @@ -1160,7 +1160,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"T")
.Output(0,
"output",
"3D tensor with shape (batch_size, sequence_length, hidden_size)",
"tensor with same shape as input.",
"T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,15 @@ def get_eps(self):
eps = ["CPUExecutionProvider", "CUDAExecutionProvider"]
return list(filter(lambda ep: ep in ort.get_available_providers(), eps))

def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh):
def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh, transposed=False):
eps = self.get_eps()
for ep in eps:
sess = ort.InferenceSession(onnx_graph, providers=[ep])
output_ort = sess.run(None, inputs_ort)[0]
output_ort = output_ort.reshape(
(self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size)
)
if not transposed:
output_ort = output_ort.reshape(
(self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size)
)

# Compare outputs as BxSxNxH
self.assertTrue(np.allclose(expected_output_bsnh, output_ort))
Expand Down Expand Up @@ -445,6 +446,44 @@ def test_hf_token_rotary_one_pos_id(self):
# Compare outputs as BxSxNxH
self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy())

# Bonus test: Prompt step, interleaved = false, pos ids shape = (1), transposed
def test_hf_prompt_rotary_one_pos_id_transposed(self):
x_bnsh = torch.randn(
self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size
)
cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length)
pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)])
output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH

cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
pos_ms = torch.tensor([0])
onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
inputs_ort = {
"input": x_bnsh.detach().cpu().numpy(),
"position_ids": pos_ms.detach().cpu().numpy(),
}

# Compare outputs as BxNxSxH
self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True)

# Bonus test: Token generation step, interleaved = false, pos ids shape = (1), transposed
def test_hf_token_rotary_one_pos_id_transposed(self):
x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size)
cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length)
pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)])
output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxSxNxH

cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache()
pos_ms = torch.tensor([2])
onnx_graph = self.create_onnx_graph(x_bnsh.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False)
inputs_ort = {
"input": x_bnsh.detach().cpu().numpy(),
"position_ids": pos_ms.detach().cpu().numpy(),
}

# Set tranposed=True to compare outputs as BxSxNxH
self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.detach().cpu().numpy(), transposed=True)


if __name__ == "__main__":
unittest.main()

0 comments on commit dd9933d

Please sign in to comment.