fix integer overflow in Attention (#20921)
### Description
The offsets used in the Attention CPU kernel are computed with data type `int`, which can overflow for
large sequence lengths. This change computes the offsets, and the related thread-pool cost estimates,
with `SafeInt<ptrdiff_t>` instead.


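To make the failure mode concrete, here is a minimal, self-contained sketch with illustrative sizes and variable names (not code from this repository). An offset of the form `i * sequence_length * total_sequence_length` exceeds `INT_MAX` once the sequence grows past roughly 46,000 tokens, so 32-bit arithmetic silently wraps; the same expression widened to `ptrdiff_t`, which is what the patch enforces through `SafeInt<ptrdiff_t>`, stays correct on 64-bit platforms.

```cpp
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const int sequence_length = 70000;        // hypothetical long sequence
  const int total_sequence_length = 70000;  // T == S when there is no past
  const int i = 1;                          // batch * head index

  // 32-bit arithmetic: 1 * 70000 * 70000 = 4.9e9 > INT_MAX (~2.1e9).
  // Computed in uint32_t here to show the wrap-around without signed-overflow UB.
  const std::uint32_t wrapped =
      static_cast<std::uint32_t>(i) * sequence_length * total_sequence_length;

  // 64-bit arithmetic, as the patch does via SafeInt<ptrdiff_t>: widen first, then multiply.
  const std::ptrdiff_t correct =
      static_cast<std::ptrdiff_t>(i) * sequence_length * total_sequence_length;

  std::printf("INT_MAX          : %d\n", INT_MAX);   // 2147483647
  std::printf("offset in 32 bits: %u\n", wrapped);   // 605032704 -- silently wrong
  std::printf("offset in 64 bits: %td\n", correct);  // 4900000000
  return 0;
}
```

`SafeInt` adds one more safeguard on top of the plain widening shown here: if the 64-bit arithmetic itself ever overflowed, it would fail loudly through its configured error handler rather than wrap.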
yufenglee authored and baijumeswani committed Jun 20, 2024
1 parent e6252bb commit 1cb6839
Showing 3 changed files with 113 additions and 108 deletions.
106 changes: 52 additions & 54 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -69,9 +69,8 @@ class AttentionCPUBase : public AttentionBase {
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));

const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data<int32_t>() : nullptr;
- gsl::span<const int64_t> mask_index_dims = mask_index != nullptr
- ? mask_index->Shape().GetDims()
- : gsl::span<const int64_t>{};
+ gsl::span<const int64_t> mask_index_dims =
+ mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span<const int64_t>{};
const T* past_data = past != nullptr ? past->Data<T>() : nullptr;
T* present_data = present != nullptr ? present->MutableData<T>() : nullptr;
const T* past_key_data = past_key != nullptr ? past_key->Data<T>() : nullptr;
@@ -84,22 +83,19 @@ class AttentionCPUBase : public AttentionBase {
relative_position_bias_data = relative_position_bias->Data<T>();
}

- ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
- mask_index_data, mask_index_dims, static_cast<T*>(mask_data), causal,
- batch_size, sequence_length, kv_sequence_length, past_sequence_length,
- qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
- present_data, present_key_data, tp, relative_position_bias_data);
+ ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K, mask_index_data, mask_index_dims,
+ static_cast<T*>(mask_data), causal, batch_size, sequence_length, kv_sequence_length,
+ past_sequence_length, qk_head_size == 0 ? v_head_size : qk_head_size, past_data,
+ past_key_data, present_data, present_key_data, tp, relative_position_bias_data);

// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));

- ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data),
- static_cast<T*>(attention_probs), V,
- batch_size, sequence_length, kv_sequence_length, past_sequence_length,
- v_head_size, v_hidden_size, past_data, past_value_data,
- present_data, present_value_data, tp);
+ ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
+ V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size,
+ v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp);

return Status::OK();
}
@@ -138,16 +134,17 @@ class AttentionCPUBase : public AttentionBase {
{
// mask_data is nullptr when mask_index is nullptr and not unidirectional, otherwise its shape is BxSxT
if (mask_data != nullptr) {
- PrepareMask(mask_index, mask_index_dims, mask_data,
- causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
+ PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length,
+ mask_filter_value_);
}

const int loop_len = batch_size * num_heads_;
const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

TensorOpCost unit_cost;
- const size_t probs_matrix_bytes = SafeInt<size_t>(sequence_length) * total_sequence_length * sizeof(T);
- unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * head_size * total_sequence_length);
+ const ptrdiff_t probs_matrix_bytes = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * sizeof(T);
+ unit_cost.compute_cycles =
+ static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * head_size * total_sequence_length);
unit_cost.bytes_loaded = static_cast<double>((sequence_length + total_sequence_length) * head_size * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(probs_matrix_bytes);

@@ -172,15 +169,13 @@ class AttentionCPUBase : public AttentionBase {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i) / num_heads_;

- const int output_offset = static_cast<int>(i) * sequence_length * total_sequence_length;
- const int mask_offset = batch_index * sequence_length * total_sequence_length;
+ const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * total_sequence_length;
+ const ptrdiff_t mask_offset = SafeInt<ptrdiff_t>(batch_index) * sequence_length * total_sequence_length;
T* output = attention_probs + output_offset;

// Broadcast mask data: (Bx)SxT -> (BxNx)SxT
if (mask_data != nullptr) {
- memcpy(output,
- mask_data + mask_offset,
- probs_matrix_bytes);
+ memcpy(output, mask_data + mask_offset, probs_matrix_bytes);
}

const T* k = K + kv_input_chunk_length * i;
@@ -197,8 +192,8 @@ class AttentionCPUBase : public AttentionBase {
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
- Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f,
- output, nullptr);
+ Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output,
+ nullptr);

if (relative_position_bias_data != nullptr) {
for (int j = 0; j < sequence_length * total_sequence_length; j++) {
@@ -249,8 +244,10 @@ class AttentionCPUBase : public AttentionBase {

// The cost of Gemm
TensorOpCost unit_cost;
- unit_cost.compute_cycles = static_cast<double>(2 * sequence_length * v_head_size * total_sequence_length);
- unit_cost.bytes_loaded = static_cast<double>((sequence_length + v_head_size) * total_sequence_length * sizeof(T));
+ unit_cost.compute_cycles =
+ static_cast<double>(SafeInt<ptrdiff_t>(2) * sequence_length * v_head_size * total_sequence_length);
+ unit_cost.bytes_loaded =
+ static_cast<double>(SafeInt<ptrdiff_t>(sequence_length + v_head_size) * total_sequence_length * sizeof(T));
unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));

if (present || present_value) {
@@ -264,35 +261,36 @@ class AttentionCPUBase : public AttentionBase {
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;

- ThreadPool::TryParallelFor(tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
- for (std::ptrdiff_t i = begin; i != end; ++i) {
- const T* v = V + kv_input_chunk_length * i;
- if (nullptr != present) {
- // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
- v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
- } else if (nullptr != present_value) {
- v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
- }
+ ThreadPool::TryParallelFor(
+ tp, SafeInt<ptrdiff_t>(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const T* v = V + kv_input_chunk_length * i;
+ if (nullptr != present) {
+ // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
+ v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
+ } else if (nullptr != present_value) {
+ v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
+ }

- T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
- ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
- math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
- attention_probs + attention_probs_offset,
- v, current_tmp_data, nullptr);
-
- // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
- const int batch_index = static_cast<int>(i / num_heads_);
- const int head_index = static_cast<int>(i % num_heads_);
- T* src = current_tmp_data;
- ptrdiff_t dest_offset = (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
- T* dest = output + dest_offset;
- for (int j = 0; j < sequence_length; j++) {
- memcpy(dest, src, bytes_to_copy_trans);
- src += v_head_size;
- dest += v_hidden_size;
- }
- }
- });
+ T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
+ ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * i;
+ math::MatMul<T>(sequence_length, v_head_size, total_sequence_length,
+ attention_probs + attention_probs_offset, v, current_tmp_data, nullptr);
+
+ // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
+ const int batch_index = static_cast<int>(i / num_heads_);
+ const int head_index = static_cast<int>(i % num_heads_);
+ T* src = current_tmp_data;
+ ptrdiff_t dest_offset =
+ (SafeInt<ptrdiff_t>(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size;
+ T* dest = output + dest_offset;
+ for (int j = 0; j < sequence_length; j++) {
+ memcpy(dest, src, bytes_to_copy_trans);
+ src += v_head_size;
+ dest += v_hidden_size;
+ }
+ }
+ });
}
};
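One detail worth noting about the `TensorOpCost` lines in the hunks above: wrapping the product in `static_cast<double>` does not prevent the overflow, because the multiplication is still evaluated in `int` before the cast runs. Widening the leftmost operand, as `SafeInt<ptrdiff_t>(2) * ...` does, promotes every subsequent multiplication to 64-bit arithmetic. A minimal sketch with hypothetical sizes, using plain `std::int64_t` in place of `SafeInt<ptrdiff_t>` (which additionally traps on overflow):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes, chosen only to push 2 * S * H * T past the 32-bit range.
  const int sequence_length = 40000;
  const int head_size = 64;
  const int total_sequence_length = 40000;

  // Pre-fix pattern: the product is evaluated in 32-bit arithmetic and has already
  // wrapped by the time the cast to double runs. (Shown with uint32_t; with signed
  // int, as in the original expression, the overflow is undefined behavior.)
  const double wrapped = static_cast<double>(
      static_cast<std::uint32_t>(2) * sequence_length * head_size * total_sequence_length);

  // Post-fix pattern: widen the first operand so every subsequent '*' is 64-bit.
  const double correct = static_cast<double>(
      static_cast<std::int64_t>(2) * sequence_length * head_size * total_sequence_length);

  std::printf("32-bit chain: %.0f\n", wrapped);  // 2936537088 -- wrapped
  std::printf("64-bit chain: %.0f\n", correct);  // 204800000000
  return 0;
}
```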
