Skip to content

Commit

Permalink
remove debug print
Browse files: browse the repository at this point in the history
  • Loading branch information
cloudhan committed Jun 13, 2024
1 parent 48092ee commit de2f30a
Showing 1 changed file with 0 additions and 26 deletions.
26 changes: 0 additions & 26 deletions onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,6 @@ namespace onnxruntime {
namespace contrib {
namespace rocm {

// Debug helper: writes "<msg>:<x>,<y>,<z>,<w>" for a longlong4 to std::cout.
// The line is terminated with std::endl, so the stream is flushed after every call.
void print(const std::string& msg, const longlong4& s) {
  // Label first, followed by the ':' separator.
  std::cout << msg << ":";
  // The four components in x, y, z, w order, comma-separated with no spaces.
  std::cout << s.x << "," << s.y << "," << s.z << "," << s.w;
  // Newline plus flush.
  std::cout << std::endl;
}

// Debug helper: writes "<msg>:<x>,<y>,<z>,<w>" for an int4 to std::cout.
// The line is terminated with std::endl, so the stream is flushed after every call.
void print(const std::string& msg, const int4& s) {
  std::ostream& out = std::cout;
  // Label and ':' separator, then the components in x, y, z, w order.
  out << msg << ":";
  out << s.x << ",";
  out << s.y << ",";
  out << s.z << ",";
  out << s.w;
  // Newline plus flush.
  out << std::endl;
}

// Debug helper: prints the `strides_for_bnsh_coord` member of a Strides value
// by forwarding it, together with the label, to another print() overload.
// NOTE(review): which overload is selected depends on the member's type, which
// is declared outside this view — confirm against the Strides definition.
void print(const std::string& msg, const Strides& s) {
  const auto& bnsh_strides = s.strides_for_bnsh_coord;
  print(msg, bnsh_strides);
}

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GroupQueryAttention, \
Expand Down Expand Up @@ -263,9 +251,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size);
key_ptr = query_ptr + key_offset;
value_ptr = key_ptr + value_offset;

print("!is_packed_qkv, key_shape ", kv_shape);
print("!is_packed_qkv, key_strides", key_strides);
}

IAllocatorUniquePtr<HipT> rotary_q_tmp;
Expand All @@ -276,9 +261,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size);
auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size);

print("do_rotary, rotary_k_shape ", int4{batch_size, sequence_length, kv_num_heads, head_size});
print("do_rotary, rotary_k_strides", rotary_k_strides);

rotary_q_tmp = GetScratchBuffer<HipT>(q_size, ctx->GetComputeStream());
rotary_k_tmp = GetScratchBuffer<HipT>(k_size, ctx->GetComputeStream());
auto rotary_position_ids_tmp = GetScratchBuffer<int64_t>(sequence_length * batch_size, ctx->GetComputeStream());
Expand Down Expand Up @@ -313,8 +295,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
key_ptr = reinterpret_cast<const HipT*>(rotary_k_tmp.get());
query_strides = rotary_q_strides;
key_strides = rotary_k_strides;
print("do_rotary, key_shape ", kv_shape);
print("do_rotary, key_strides", key_strides);
}

const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast<const int*>(seqlens_k->DataRaw()) : nullptr;
Expand All @@ -324,21 +304,16 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
auto* present_key_ptr = reinterpret_cast<HipT*>(present_key->MutableDataRaw());
auto* present_value_ptr = reinterpret_cast<HipT*>(present_value->MutableDataRaw());
if (parameters.is_prompt) {
std::cout << "is_prompt" << std::endl;
// copy prompt kv to present kv
ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(),
present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(),
present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
print("is_prompt, key_shape ", kv_shape);
print("is_prompt, key_strides", key_strides);
} else {
std::cout << "!is_prompt" << std::endl;
const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast<const HipT*>(past_key->DataRaw());
const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast<const HipT*>(past_value->DataRaw());
parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME:
if (!parameters.kv_share_buffer) {
std::cout << "!kv_share_buffer" << std::endl;
// copy past to present,
// NOTE: we do a low-performance full-buffer copy because seqlens_k indicates that the sequence lengths of
// different sequences are not the same, i.e., this cannot be expressed as a simple strided copy
Expand All @@ -347,7 +322,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* ctx) const {
ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(),
present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk));
} else {
std::cout << "kv_share_buffer" << std::endl;
// In the case of share buffer
ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr);
ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr);
Expand Down

0 comments on commit de2f30a

Please sign in to comment.