diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index b483e6de81a47..cf71b6bcf7c7d 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -11,6 +11,8 @@ set(contrib_ops_excluded_files "bert/attention_softmax.h" "bert/attention_softmax.cu" "bert/attention_prepare_qkv.cu" + "bert/decoder_attention_impl.h" + "bert/decoder_attention_impl.cu" "bert/decoder_masked_multihead_attention.h" "bert/decoder_masked_multihead_attention.cc" "bert/decoder_masked_self_attention.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 366d8fee1473b..b4a4ae208ceb1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -153,408 +153,285 @@ size_t GetAttentionWorkspaceSize( } template -__global__ void AddBiasTransAppendKvToPresentSmall( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int N = blockDim.y; - const int S = gridDim.x; - const int B = gridDim.y; - - constexpr int M = 3; // Matrix count in qkv - const int m = blockIdx.z + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + assert(data.qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.mask_index == nullptr); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? 
biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, + sequence_length, stream, + data.scratch); -template -__global__ void AddBiasTransAppendKvToPresent( - const T* qkv, const T* biases, T* present, - const int head_size, const int past_sequence_length, const int max_sequence_length) { - // Input: BxSxMxNxH (Format 1) - // Output: (2, B, N, [P..P+S) of MaxS, H), - // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size - const int n = blockIdx.x; - const int s = blockIdx.y; - const int b = (blockIdx.z >> 1); - const int N = gridDim.x; - const int S = gridDim.y; - const int B = (gridDim.z >> 1); - - constexpr int M = 3; // Matrix count in qkv - const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 - - const int NH = N * head_size; - const int NHS = NH * S; - - qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); - if (biases) { - biases += (m * NH + n * head_size); - } + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - const int MsH = max_sequence_length * head_size; - const int NMsH = N * MsH; - const int BNMsH = B * NMsH; - present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, parameters.kv_sequence_length, stream, + kv_sequence_offset); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - for (int h = threadIdx.x; h < head_size; h += blockDim.x) { - T bias = (biases ? biases[h] : (T)0.0f); - present[h] = qkv[h] + bias; - } -} + DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); -// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. -// bias is of shape (3, NxH) or nullptr -// append to present of (2, B, N, (P..T) of M, H), -template -Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int past_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const T* biases, - const T* qkv_buffer, - T* present) { - assert(head_size <= (1 << 30)); - - int64_t nh = (int64_t)head_size * num_heads; - if (nh <= max_threads_per_block) { - const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v - const dim3 block(max_threads_per_block / num_heads, num_heads, 1); - - AddBiasTransAppendKvToPresentSmall<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); - } else { - const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v - const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); - AddBiasTransAppendKvToPresent<<>>( - qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = + reinterpret_cast(data.fused_cross_attention_kernel); + + // When there is no bias, we can directly use q and packed kv from inputs. 
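+  // Otherwise q and packed kv point at the workspace buffers that PrepareQkv stored in data.q / data.k.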
+ void const* query = data.q; + void const* packed_kv = data.k; + if (data.value == nullptr && data.bias == nullptr) { + query = data.query; + packed_kv = data.key; } - return CUDA_CALL(cudaGetLastError()); + run_fused_cross_attention( + query, // Q + packed_kv, // packed KV + q_sequence_offset, // cumulated sequence length of Q + kv_sequence_offset, // cumulated sequence length of KV + data.output, // output + cross_attention_kernel, // kernels + batch_size, // batch size + parameters.num_heads, // number of heads + parameters.head_size, // head size of Q/K/V + sequence_length, // sequence length of Q + parameters.kv_sequence_length, // sequence length of KV + stream); + + DUMP_TENSOR("trt cross output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + return Status::OK(); } -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const float* bias, - const float* qkv_buffer, - float* present); - -template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, - const int max_sequence_length, - const int total_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const int max_threads_per_block, - const half* bias, - const half* qkv_buffer, - half* present); +template <> +Status FusedTrtCrossAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused cross attention does not support float tensor"); +} template -Status QkvToContext( - const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, - Stream* ort_stream, +Status FusedTrtSelfAttention( + cudaStream_t stream, contrib::AttentionParameters& parameters, AttentionData& data) { - auto stream = static_cast(ort_stream->GetHandle()); - constexpr size_t element_size = sizeof(T); - const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int total_sequence_length = parameters.total_sequence_length; - const int num_heads = parameters.num_heads; - const int qk_head_size = parameters.head_size; - const int v_head_size = parameters.v_head_size; - const bool past_present_share_buffer = parameters.past_present_share_buffer; - const float mask_filter_value = parameters.mask_filter_value; - void* fused_runner = data.fused_runner; - - // At most one fused kernel is enabled. - assert((int(data.use_flash_attention) + - int(data.use_memory_efficient_attention) + - int(fused_runner != nullptr) + - int(data.fused_cross_attention_kernel != nullptr)) <= 1); - - const int batches = batch_size * num_heads; - - bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - - QkvData qkv; - ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, qkv)); - T* scratch1 = data.has_qkv_workspace ? 
qkv.after_v : data.workspace; - - int present_size_per_batch_k = 0; - int present_size_per_batch_v = 0; - if (!past_present_share_buffer) { - present_size_per_batch_k = total_sequence_length * qk_head_size; - present_size_per_batch_v = total_sequence_length * v_head_size; - ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, - sequence_length, total_sequence_length, parameters.pass_past_in_kv, - stream, max_threads_per_block, data, qkv)); - - } else { // past_present_share_buffer - assert(qk_head_size == v_head_size); - assert(data.fused_cross_attention_kernel == nullptr); - assert(!use_fused_kernel); - assert(data.gemm_buffer != nullptr); - assert(!data.use_memory_efficient_attention); - assert(!data.use_flash_attention); - assert(data.has_qkv_workspace); - - if (nullptr != data.past_key || nullptr != data.present_key) { - // TODO: support this case. - ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); - } - - if (data.present != data.past) { - // For easy testing. Production should better avoid this path. - int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; - cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - } - - // append last k v to present - ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( - stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, - batch_size, qk_head_size, num_heads, max_threads_per_block, - use_fused_causal ? nullptr : data.bias, // For fused causal, bias has been added to gemm_buffer - data.gemm_buffer, data.present)); + const bool causal = parameters.is_unidirectional; - present_size_per_batch_k = parameters.max_sequence_length * qk_head_size; - present_size_per_batch_v = present_size_per_batch_k; - qkv.k = data.present; - qkv.v = data.present + batches * present_size_per_batch_k; - } + int* sequence_offset = reinterpret_cast(data.scratch); - // Q, K and V are ready now DUMP_TENSOR_INIT(); - - if (data.fused_cross_attention_kernel != nullptr) { - assert(qkv.format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(data.mask_index == nullptr); - - int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - scratch1); - - DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); - - int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, - data.mask_index, batch_size, kv_sequence_length, stream, - kv_sequence_offset); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); - - FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = - reinterpret_cast(data.fused_cross_attention_kernel); - - // When there is no bias, we can directly use q and packed kv from inputs. 
- void const* query = qkv.q; - void const* packed_kv = qkv.k; - if (data.value == nullptr && data.bias == nullptr) { - query = data.query; - packed_kv = data.key; - } - - run_fused_cross_attention( - query, // Q - packed_kv, // packed KV - q_sequence_offset, // cumulated sequence length of Q - kv_sequence_offset, // cumulated sequence length of KV - data.output, // output - cross_attention_kernel, // kernels - batch_size, // batch size - num_heads, // number of heads - qk_head_size, // head size of Q/K/V - sequence_length, // sequence length of Q - kv_sequence_length, // sequence length of KV - stream); - - DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size); - return Status::OK(); + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); + LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } else { + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } + DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); + CUDA_RETURN_IF_ERROR(cudaGetLastError()); - // Run TRT fused attention. - if (use_fused_kernel || use_fused_causal) { - int* sequence_offset = reinterpret_cast(scratch1); - if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { - DUMP_TENSOR_D("mask", reinterpret_cast(data.mask_index), batch_size, sequence_length); - LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); - } else { - sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, - data.mask_index, batch_size, sequence_length, stream, - sequence_offset); - } - DUMP_TENSOR_D("sequence_offset", sequence_offset, 1, (data.mask_index != nullptr ? 2 : 1) * batch_size + 1); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); + FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(data.fused_runner); - FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); + const int S = causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); - const int S = use_fused_causal ? sequence_length : fused_fp16_runner->getSFromMaxSeqLen(sequence_length); + // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. + const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); - // B = 2 * batch_size when there is padding in input, and B = batch_size when padding is removed. - const int B = (nullptr == data.mask_index ? batch_size : 2 * batch_size); + fused_fp16_runner->setup(S, B); - fused_fp16_runner->setup(S, B); + if (!causal) { + assert(data.qkv_format == AttentionQkvFormat::QKV_BSN3H); - if (use_fused_kernel) { - assert(qkv.format == AttentionQkvFormat::QKV_BSN3H); - - // When there is no bias, we can directly use packed qkv from inputs. 
- void const* packed_qkv = qkv.q; - if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { - packed_qkv = data.query; - } - - fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); - } else { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); - fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); + // When there is no bias, we can directly use packed qkv from inputs. + void const* packed_qkv = data.q; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; } - return Status::OK(); + + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); + DUMP_TENSOR("fused output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); + } else { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); + DUMP_TENSOR("fused causal output", data.output, + batch_size, sequence_length, parameters.num_heads, parameters.v_head_size); } + return Status::OK(); +} - // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. - const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) - : parameters.scale; +// Template Specialization for float type +template <> +Status FusedTrtSelfAttention( + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, + "Trt fused attention does not support float tensor"); +} #if USE_FLASH_ATTENTION - if (data.use_flash_attention) { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - assert(nullptr == data.mask_index); - assert(nullptr == data.relative_position_bias); - assert(parameters.head_size == parameters.v_head_size); - - void* query = reinterpret_cast(qkv.q); - void* key = reinterpret_cast(qkv.k); - void* value = reinterpret_cast(qkv.v); - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { - query = reinterpret_cast(const_cast(data.query)); - } - - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - constexpr bool is_causal = false; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, key, value, data.output, reinterpret_cast(scratch1), - parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, is_causal)); +template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + assert(parameters.head_size == parameters.v_head_size); + + void* query = reinterpret_cast(data.q); + void* key = reinterpret_cast(data.k); + void* value = reinterpret_cast(data.v); + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { + query = reinterpret_cast(const_cast(data.query)); + } - DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, key, value, data.output, reinterpret_cast(data.scratch), + parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional)); + + DUMP_TENSOR("flash attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} - return Status::OK(); - } +template <> +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor"); +} #endif #if USE_MEMORY_EFFICIENT_ATTENTION - if (data.use_memory_efficient_attention) { - // We only enable fused cross attention when there is no key padding mask. - // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - - const void* query = qkv.q; - const void* key = qkv.k; - const void* value = qkv.v; - // For packed KV, we can use query input directly. 
- if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { - assert(data.bias == nullptr); - query = data.query; - } +template +Status EfficientAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + // We only enable fused cross attention when there is no key padding mask. + // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + + const void* query = data.q; + const void* key = data.k; + const void* value = data.v; + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); + query = data.query; + } - DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", qkv.k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", qkv.v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); - - MemoryEfficientAttentionParams p; - p.sm = device_prop.major * 10 + device_prop.minor; - p.is_half = sizeof(T) == 2; - p.batch_size = parameters.batch_size; - p.num_heads = parameters.num_heads; - p.sequence_length = parameters.sequence_length; - p.kv_sequence_length = parameters.total_sequence_length; - p.qk_head_size = parameters.head_size; - p.v_head_size = parameters.v_head_size; - p.causal = parameters.is_unidirectional; - p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("k(BSNH)", data.k, + parameters.batch_size, parameters.total_sequence_length, parameters.num_heads, parameters.head_size); + DUMP_TENSOR_D("v(BSNH)", data.v, + parameters.batch_size, parameters.total_sequence_length, + parameters.num_heads, parameters.v_head_size); + + MemoryEfficientAttentionParams p; + p.sm = device_prop.major * 10 + device_prop.minor; + p.is_half = sizeof(T) == 2; + p.batch_size = parameters.batch_size; + p.num_heads = parameters.num_heads; + p.sequence_length = parameters.sequence_length; + p.kv_sequence_length = parameters.total_sequence_length; + p.qk_head_size = parameters.head_size; + p.v_head_size = parameters.v_head_size; + p.causal = parameters.is_unidirectional; + p.scale = scale; + p.seqlen_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); - p.query = query; - p.key = key; - p.value = value; - p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; - p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; - p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) - ? 
scratch1 - : nullptr; - p.stream = stream; - run_memory_efficient_attention(p); - DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); - - return Status::OK(); - } + : const_cast(reinterpret_cast( + data.mask_index + parameters.batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast( + data.mask_index + 2 * parameters.batch_size + 1)); + p.query = query; + p.key = key; + p.value = value; + p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; + p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; + p.output = data.output; + p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float)) + ? data.scratch + : nullptr; + p.stream = stream; + run_memory_efficient_attention(p); + DUMP_TENSOR("efficient attention output", data.output, + parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); + + return Status::OK(); +} #endif - // The following are unfused attention. - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH); +template +Status UnfusedAttention( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data, + float scale) { + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH); + + auto stream = static_cast(ort_stream->GetHandle()); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + const int batches = batch_size * num_heads; + const int* mask_index = data.mask_index; gsl::span& mask_index_dims = data.mask_index_dims; // Raw attention mask could be 2D (BxT) or 3D (BxSxT) or 4D(Bx1xMxM), where M is the max sequence length. bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2); - // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxT + // Compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch: BxNxSxT // Q: BxNxSxH, K (present_k): BxNxTxH, Q*K': BxNxSxT float one = 1.0f; float zero = 0.f; @@ -563,22 +440,31 @@ Status QkvToContext( cublasSetStream(cublas, stream); - DUMP_TENSOR_D("q[BNSH]", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k[BNSH]", k, batch_size, num_heads, total_sequence_length, qk_head_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q[BNSH]", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k[BNSH]", data.k, batch_size, num_heads, total_sequence_length, qk_head_size); + + const int present_sequence_length = parameters.past_present_share_buffer + ? 
parameters.max_sequence_length + : total_sequence_length; + const int present_size_per_batch_k = present_sequence_length * qk_head_size; + const int present_size_per_batch_v = present_sequence_length * v_head_size; + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, total_sequence_length, sequence_length, qk_head_size, - &alpha, qkv.k, qk_head_size, present_size_per_batch_k, - qkv.q, qk_head_size, sequence_length * qk_head_size, - &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &alpha, data.k, qk_head_size, present_size_per_batch_k, + data.q, qk_head_size, sequence_length * qk_head_size, + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", qkv.q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", qkv.k, batch_size, num_heads, qk_head_size, sequence_length); - DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length); + constexpr size_t element_size = sizeof(T); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); - T* scratch2 = scratch1 + (bytes / element_size); + T* scratch2 = data.scratch + (bytes / element_size); // Apply softmax and store result R to scratch2: BxNxSxT if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask @@ -588,14 +474,15 @@ Status QkvToContext( const TransformerOptions* options = TransformerOptions::GetInstance(); bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax(); - T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. + // replace Q*K' in place with masked score for persistent softmax. + T* persistent_softmax_workspace = data.scratch; ORT_RETURN_IF_ERROR( ComputeSoftmaxWithRawMask( ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, + data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension, parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, - mask_filter_value)); + parameters.mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index assert(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. 
@@ -603,277 +490,123 @@ Status QkvToContext( ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D( stream, total_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional)); + data.scratch, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( ComputeSoftmax( stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); + parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional)); } DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", qkv.v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v - T* temp_output = qkv.q; + T* temp_output = data.q; CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, v_head_size, sequence_length, total_sequence_length, - &one, qkv.v, v_head_size, present_size_per_batch_v, + &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, temp_output, data.output); + device_prop.maxThreadsPerBlock, false, temp_output, data.output); DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } template -Status DecoderQkvToContext( +Status QkvToContext( const cudaDeviceProp& device_prop, - Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const T* gemm_query_buffer, - const T* gemm_kv_buffer, - const bool* key_padding_mask, - const T* key_cache, - const T* value_cache, - T* qkv_buffer, - T* workspace_buffer, - T* output, - T* new_key_cache, - T* new_value_cache) { + Stream* ort_stream, + contrib::AttentionParameters& parameters, + AttentionData& data) { + auto stream = static_cast(ort_stream->GetHandle()); const int max_threads_per_block = device_prop.maxThreadsPerBlock; - const int BN = batch_size * num_heads; - const int BHN = BN * head_size; - const int BNS = BN * sequence_length; - const int k_buffer_offset = sequence_length * BHN; - const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int total_sequence_length = parameters.total_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; - T* temp_qkv_buffer = workspace_buffer; - auto stream = static_cast(ort_stream->GetHandle()); + // At most one fused kernel is enabled. 
+ assert((int(data.use_flash_attention) + + int(data.use_memory_efficient_attention) + + int(fused_runner != nullptr) + + int(data.fused_cross_attention_kernel != nullptr)) <= 1); - const T* q = qkv_buffer; - // transpose q and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); - - const T* k = qkv_buffer + k_buffer_offset; - const T* v = qkv_buffer + v_buffer_offset; - if (!has_layer_state || !use_past) { - if (!static_kv) { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } else { - // transpose kv and copy them to qkv_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); - } - } else { - if (!static_kv) { - // transpose kv and copy them to temp_buffer - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); - // concat cache-k with k and copy to qkv_buffer - if (nullptr != key_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - key_cache, - temp_qkv_buffer, - qkv_buffer + k_buffer_offset)); - } - // concat cache-v with v and copy to qkv_buffer - if (nullptr != value_cache) { - ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, - sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, 1, - value_cache, - temp_qkv_buffer + k_buffer_offset, - qkv_buffer + v_buffer_offset)); - } + ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block)); + + if (!parameters.past_present_share_buffer) { + ORT_RETURN_IF_ERROR(ConcatPastToPresent(batch_size, num_heads, qk_head_size, v_head_size, + sequence_length, total_sequence_length, parameters.pass_past_in_kv, + stream, max_threads_per_block, data)); + + } else { // past_present_share_buffer + assert(qk_head_size == v_head_size); + assert(data.fused_cross_attention_kernel == nullptr); + assert(nullptr == fused_runner || parameters.is_unidirectional); + assert(data.gemm_buffer != nullptr); + assert(!data.use_memory_efficient_attention); + assert(!data.use_flash_attention); + assert(data.has_qkv_workspace); + + if (nullptr != data.past_key || nullptr != data.present_key) { + // TODO: support this case. + ORT_THROW("buffer sharing for no bias case between past and present is not supported yet."); } - } - if (has_layer_state) { - if (use_past && static_kv) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - } else { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + if (data.present != data.past) { + // For easy testing. Production should better avoid this path. 
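+      // Copy the whole past KV cache (2 x B x N x max_S x H elements) into present before appending new K/V.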
+ int64_t kv_size = 2LL * (int64_t)batch_size * num_heads * parameters.max_sequence_length * qk_head_size; + cudaMemcpyAsync(data.present, data.past, kv_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); } - } - // scratch1: BxNxSxL buffer - // scratch2: BxNxSxL buffer - // scratch3: BxNxSxH buffer - T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; - T* scratch2 = scratch1 + BNS * kv_sequence_length; - T* scratch3 = scratch2 + BNS * kv_sequence_length; - - // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL - // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL - const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int temp_matrix_size = sequence_length * kv_sequence_length; - float one = 1.0f; - float zero = 0.f; + // For fused causal, bias has been added to gemm_buffer. + const T* bias = (nullptr != fused_runner && parameters.is_unidirectional) ? nullptr : data.bias; - float alpha = rsqrt_head_size; - const int strideA = kv_sequence_length * head_size; - const int strideB = sequence_length * head_size; - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, key_cache, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, - kv_sequence_length, sequence_length, head_size, - &alpha, k, head_size, strideA, - q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + // append last k v to present + ORT_RETURN_IF_ERROR(LaunchAddBiasTransAppendKvToPresent( + stream, parameters.max_sequence_length, parameters.past_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, max_threads_per_block, + bias, data.gemm_buffer, data.present)); + + data.k = data.present; + data.v = data.present + batch_size * num_heads * parameters.max_sequence_length * qk_head_size; } - constexpr bool is_unidirectional = false; - const T* add_before_softmax = nullptr; - if (has_key_padding_mask) { - constexpr int mask_dimension = 2; - constexpr int max_sequence_length = 0; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( - ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, - 1.0f, mask_dimension, max_sequence_length, false, nullptr, - mask_filter_value)); - } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax( - stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + // Q, K and V are ready now + if (data.fused_cross_attention_kernel != nullptr) { + return FusedTrtCrossAttention(stream, parameters, data); } - // compute P*V (as V*P), and store in scratch3: BxNxSxH - if (use_past && static_kv) { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, value_cache, head_size, strideA, - scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); - } else { - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( - cublas, CUBLAS_OP_N, CUBLAS_OP_N, - head_size, sequence_length, kv_sequence_length, - &one, v, head_size, strideA, - scratch2, 
kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + // Run TRT fused attention. + if (nullptr != fused_runner) { + return FusedTrtSelfAttention(stream, parameters, data); } - // scratch3 is BxNxSxH, transpose to output SxBxNxH - return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, - max_threads_per_block, true, scratch3, output); -} + // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation. + const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& device_prop, - Stream* stream, - cublasHandle_t& cublas, - const size_t element_size, - const int batch_size, - const int sequence_length, - const int kv_sequence_length, - const int num_heads, - const int head_size, - const bool static_kv, - const bool use_past, - const bool has_layer_state, - const bool has_key_padding_mask, - const float mask_filter_value, - const void* gemm_query_buffer, - const void* gemm_kv_buffer, - const bool* key_padding_mask, - const void* key_cache, - const void* value_cache, - void* qkv_buffer, - void* workspace_buffer, - void* output, - void* new_key_cache, - void* new_value_cache) { - if (element_size == 2) { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); - } else { - return DecoderQkvToContext( - device_prop, - stream, - cublas, - element_size, - batch_size, - sequence_length, - kv_sequence_length, - num_heads, - head_size, - static_kv, - use_past, - has_layer_state, - has_key_padding_mask, - mask_filter_value, - reinterpret_cast(gemm_query_buffer), - reinterpret_cast(gemm_kv_buffer), - key_padding_mask, - reinterpret_cast(key_cache), - reinterpret_cast(value_cache), - reinterpret_cast(qkv_buffer), - reinterpret_cast(workspace_buffer), - reinterpret_cast(output), - reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); +#if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data, scale); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.use_memory_efficient_attention) { + return EfficientAttention(device_prop, stream, parameters, data, scale); } +#endif + + return UnfusedAttention(device_prop, cublas, ort_stream, parameters, data, scale); } // Template Instantiation diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index c361a47c364d3..d0a5fb51a25d6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -81,24 +81,20 @@ struct AttentionData { mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache = nullptr; mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache = nullptr; -}; -// Intermediate data pointers available after PrepareQKV -template -struct QkvData { + // Intermediate data T* q = nullptr; T* k = nullptr; T* v 
= nullptr; - T* after_v = nullptr; // pointer right after v - AttentionQkvFormat format = AttentionQkvFormat::Q_K_V_BSNH; + T* scratch = nullptr; + AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; }; template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv_data); + int max_threads_per_block); template Status QkvToContext( @@ -108,33 +104,6 @@ Status QkvToContext( contrib::AttentionParameters& parameters, AttentionData& data); -Status LaunchDecoderAttentionKernel( - const cudaDeviceProp& prop, // Device Properties - Stream* stream, // ORT Stream - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - // BxNxSxH => BxSxNxH or SxBxNxH (reversed_bs is true) Status LaunchTransCtx(cudaStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, @@ -184,14 +153,27 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); + +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present); template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) T* out, longlong4 out_strides, // coord (b,n,s,h) int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu similarity index 67% rename from onnxruntime/contrib_ops/cuda/bert/attention_concat.cu rename to onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu index 8378ee2691c6b..89be0f1115f41 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_concat.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_kv_cache.cu @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" using namespace onnxruntime::cuda; @@ -244,48 +245,48 @@ Status LaunchConcatPastToPresent(cudaStream_t stream, present); } -#ifndef USE_ROCM // exclude from hipify +#ifndef USE_ROCM // exclude the following from hipify since they are not used in ROCM EP + template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv) { + AttentionData& data) { // Concat past key value to present (2xBxNxLxH), where L is kv_sequence_length and T is total_sequence_length. // past_k (BxNxPxH) + k (BxNxLxH) => present_k (BxNxTxH) // past_v (BxNxPxH) + v (BxNxLxH) => present_v (BxNxTxH) // When there is past state, the head size for Q/K/V shall be same: H == H_v. if (nullptr != data.present) { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BNSH || qkv.format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH || + data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); + ORT_RETURN_IF_ERROR( LaunchConcatPastToPresent( stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, qkv.k, data.present)); + max_threads_per_block, data.past, data.k, data.present)); // Update pointers to present_k and present_v. - qkv.k = data.present; - qkv.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; - } - - if (nullptr != data.past_key || nullptr != data.present_key) { + data.k = data.present; + data.v = data.present + batch_size * num_heads * total_sequence_length * qk_head_size; + } else if (nullptr != data.past_key || nullptr != data.present_key) { if (nullptr != data.past_key && nullptr == data.present_key) { - qkv.k = const_cast(data.past_key); - qkv.v = const_cast(data.past_value); + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); } else if (nullptr == data.past_key && nullptr != data.present_key) { - if (qkv.format == AttentionQkvFormat::Q_K_V_BNSH) { - qkv.k = data.present_key; - qkv.v = data.present_value; + if (data.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) { + data.k = data.present_key; + data.v = data.present_value; } else { - assert(qkv.format == AttentionQkvFormat::Q_K_V_BSNH); - qkv.k = data.temp_k_workspace; - qkv.v = data.temp_v_workspace; + assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + data.k = data.temp_k_workspace; + data.v = data.temp_v_workspace; } } else if (pass_past_in_kv) { // past_key and past_value are used directly as key and value in attention computations - qkv.k = const_cast(data.past_key); - qkv.v = const_cast(data.past_value); + data.k = const_cast(data.past_key); + data.v = const_cast(data.past_value); // This path has a memory copy from past_key and past_value to present_key and present_value // Avoid this path since the memory copy is unnecessary because past_key == present_key and @@ -298,14 +299,14 @@ Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, 1, data.past_key, qkv.k, data.present_key)); + max_threads_per_block, 1, 
data.past_key, data.k, data.present_key)); ORT_RETURN_IF_ERROR( LaunchConcatTensorToTensor(stream, total_sequence_length, sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, 1, data.past_value, qkv.v, data.present_value)); + max_threads_per_block, 1, data.past_value, data.v, data.present_value)); // Update pointers to present_k and present_v. - qkv.k = data.present_key; - qkv.v = data.present_value; + data.k = data.present_key; + data.v = data.present_value; } } @@ -317,15 +318,147 @@ template Status ConcatPastToPresent(int batch_size, int num_heads, int qk int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); template Status ConcatPastToPresent(int batch_size, int num_heads, int qk_head_size, int v_head_size, int sequence_length, int total_sequence_length, bool pass_past_in_kv, cudaStream_t stream, int max_threads_per_block, - AttentionData& data, - QkvData& qkv); + AttentionData& data); + +// ---------------------------------------------------------------------------------- +// Below kernels are for past and present sharing buffer +// ---------------------------------------------------------------------------------- + +template +__global__ void AddBiasTransAppendKvToPresentSmall( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + constexpr int M = 3; // Matrix count in qkv + const int m = blockIdx.z + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? 
biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +template +__global__ void AddBiasTransAppendKvToPresent( + const T* qkv, const T* biases, T* present, + const int head_size, const int past_sequence_length, const int max_sequence_length) { + // Input: BxSxMxNxH (Format 1) + // Output: (2, B, N, [P..P+S) of MaxS, H), + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + const int n = blockIdx.x; + const int s = blockIdx.y; + const int b = (blockIdx.z >> 1); + const int N = gridDim.x; + const int S = gridDim.y; + const int B = (gridDim.z >> 1); + + constexpr int M = 3; // Matrix count in qkv + const int m = (blockIdx.z & 0x1) + 1; // k = 1, v = 2 + + const int NH = N * head_size; + const int NHS = NH * S; + + qkv += (n * head_size + (s * M + m) * NH + b * M * NHS); + if (biases) { + biases += (m * NH + n * head_size); + } + + const int MsH = max_sequence_length * head_size; + const int NMsH = N * MsH; + const int BNMsH = B * NMsH; + present += ((past_sequence_length + s) * head_size + n * MsH + b * NMsH + (m - 1) * BNMsH); + + for (int h = threadIdx.x; h < head_size; h += blockDim.x) { + T bias = (biases ? biases[h] : (T)0.0f); + present[h] = qkv[h] + bias; + } +} + +// qkv buffer is merged tensor of shape (B,S,3,N,H), k v is the second/third of the 3. +// bias is of shape (3, NxH) or nullptr +// append to present of (2, B, N, (P..T) of M, H), +template +Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int past_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const T* biases, + const T* qkv_buffer, + T* present) { + assert(head_size <= (1 << 30)); + + int64_t nh = (int64_t)head_size * num_heads; + if (nh <= max_threads_per_block) { + const dim3 grid(sequence_length, batch_size, 2); // 2 for k and v + const dim3 block(max_threads_per_block / num_heads, num_heads, 1); + + AddBiasTransAppendKvToPresentSmall<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } else { + const dim3 grid(num_heads, sequence_length, batch_size * 2); // 2 for k and v + const dim3 block(std::min(head_size, max_threads_per_block), 1, 1); + AddBiasTransAppendKvToPresent<<>>( + qkv_buffer, biases, present, head_size, past_sequence_length, max_sequence_length); + } + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const float* bias, + const float* qkv_buffer, + float* present); + +template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, + const int max_sequence_length, + const int total_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const int max_threads_per_block, + const half* bias, + const half* qkv_buffer, + half* present); #endif } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index cd4137ab11de9..5c65a30918ece 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. 
All rights reserved. // Licensed under the MIT License. -#include -#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" @@ -406,22 +405,25 @@ Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, // Query (BxSxNxH) => Q (BxNxSxH) constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, + true, -1); // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + LaunchAddBiasTranspose( + stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); @@ -435,8 +437,8 @@ template Status PrepareQkv(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv) { + int max_threads_per_block) { + data.scratch = data.workspace; if (data.has_qkv_workspace) { const int size_per_batch_q = parameters.sequence_length * parameters.head_size; const int size_per_batch_k = parameters.kv_sequence_length * parameters.head_size; @@ -445,27 +447,27 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - qkv.q = data.workspace; - qkv.k = data.workspace + elements_q; - qkv.v = qkv.k + elements_k; - qkv.after_v = qkv.v + elements_v; + data.q = data.workspace; + data.k = data.workspace + elements_q; + data.v = data.k + elements_k; + data.scratch = data.v + elements_v; } if (nullptr != data.gemm_buffer) { // Attention operator ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, - qkv.format)); + data.qkv_format)); } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv 
ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } else { // multihead attention operator, no past, separated Q/K/V inputs ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, - qkv.q, qkv.k, qkv.v, qkv.format)); + data.q, data.k, data.v, data.qkv_format)); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -477,15 +479,13 @@ template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv); + int max_threads_per_block); template Status PrepareQkv( contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, - int max_threads_per_block, - QkvData& qkv); + int max_threads_per_block); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu index e7d2255fb46b8..01ea02f48d3ab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu @@ -18,7 +18,6 @@ limitations under the License. */ #include -#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index f907d300607f9..3f703ae3d05e6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -1,8 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/decoder_attention.h" +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include "core/framework/op_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -85,7 +85,8 @@ Status CheckInputs(const TensorShape& query_shape, } if (kv_weights_dims[0] != hidden_size || kv_weights_dims[1] != 2 * static_cast(hidden_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "kv_weights shall have shape (hidden size, 2 * hidden size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_weights shall have shape (hidden size, 2 * hidden size)"); } const auto& bias_dims = bias_shape.GetDims(); @@ -137,7 +138,8 @@ Status CheckInputs(const TensorShape& query_shape, const auto& value_cache_dims = value_cache->Shape().GetDims(); if (value_cache_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value_cache' is expected to have 4 dimension, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'value_cache' is expected to have 4 dimension, got ", value_cache_dims.size()); } @@ -353,10 +355,12 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { } } - size_t bytes = element_size * batch_size * (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; + size_t bytes = element_size * batch_size * + (static_cast(sequence_length) + static_cast(2) * kv_sequence_length) * hidden_size; auto qkv_buffer_p = GetScratchBuffer(bytes, context->GetComputeStream()); - bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * (static_cast(2) * head_size + static_cast(kv_sequence_length)); + bytes = element_size * 2 * batch_size * sequence_length * num_heads_ * + (static_cast(2) * head_size + static_cast(kv_sequence_length)); auto workspace_p = GetScratchBuffer(bytes, context->GetComputeStream()); Tensor* output(context->Output(0, query_shape)); diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu new file mode 100644 index 0000000000000..1dc22a9c8ea98 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/bert/decoder_attention_impl.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/bert/attention_softmax.h" + +using namespace onnxruntime::contrib::attention_softmax_cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status DecoderQkvToContext( + const cudaDeviceProp& device_prop, + Stream* ort_stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, + const bool has_key_padding_mask, + const float mask_filter_value, + const T* gemm_query_buffer, + const T* gemm_kv_buffer, + const bool* key_padding_mask, + const T* key_cache, + const T* value_cache, + T* qkv_buffer, + T* workspace_buffer, + T* output, + T* new_key_cache, + T* new_value_cache) { + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int BN = batch_size * num_heads; + const int BHN = BN * head_size; + const int BNS = BN * sequence_length; + const int k_buffer_offset = sequence_length * BHN; + const int v_buffer_offset = (sequence_length + kv_sequence_length) * BHN; + + T* temp_qkv_buffer = workspace_buffer; + auto stream = static_cast(ort_stream->GetHandle()); + + const T* q = qkv_buffer; + // transpose q and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_query_buffer, qkv_buffer)); + + const T* k = qkv_buffer + k_buffer_offset; + const T* v = qkv_buffer + v_buffer_offset; + if (!has_layer_state || !use_past) { + if (!static_kv) { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } else { + // transpose kv and copy them to qkv_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, kv_sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, qkv_buffer + k_buffer_offset)); + } + } else { + if (!static_kv) { + // transpose kv and copy them to temp_buffer + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 2, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, gemm_kv_buffer, temp_qkv_buffer)); + // concat cache-k with k and copy to qkv_buffer + if (nullptr != key_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + key_cache, + temp_qkv_buffer, + qkv_buffer + k_buffer_offset)); + } + // concat cache-v with v and copy to qkv_buffer + if (nullptr != value_cache) { + ORT_RETURN_IF_ERROR(LaunchConcatTensorToTensor(stream, kv_sequence_length, + sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, 1, + value_cache, + temp_qkv_buffer + k_buffer_offset, + qkv_buffer + v_buffer_offset)); + } + } + } + + if (has_layer_state) { + if (use_past && static_kv) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } else { + 
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + } + } + + // scratch1: BxNxSxL buffer + // scratch2: BxNxSxL buffer + // scratch3: BxNxSxH buffer + T* scratch1 = temp_qkv_buffer + 3 * BHN * sequence_length; + T* scratch2 = scratch1 + BNS * kv_sequence_length; + T* scratch3 = scratch2 + BNS * kv_sequence_length; + + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxL + // Q: BxNxSxH, K (present_k): BxNxLxH, Q*K': BxNxSxL + const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); + const int temp_matrix_size = sequence_length * kv_sequence_length; + float one = 1.0f; + float zero = 0.f; + + float alpha = rsqrt_head_size; + const int strideA = kv_sequence_length * head_size; + const int strideB = sequence_length * head_size; + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, key_cache, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_T, CUBLAS_OP_N, + kv_sequence_length, sequence_length, head_size, + &alpha, k, head_size, strideA, + q, head_size, strideB, + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + } + + constexpr bool is_unidirectional = false; + const T* add_before_softmax = nullptr; + if (has_key_padding_mask) { + constexpr int mask_dimension = 2; + constexpr int max_sequence_length = 0; + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, key_padding_mask, add_before_softmax, + false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + 1.0f, mask_dimension, max_sequence_length, false, nullptr, + mask_filter_value)); + } else { + ORT_RETURN_IF_ERROR(ComputeSoftmax( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, + is_unidirectional)); + } + + // compute P*V (as V*P), and store in scratch3: BxNxSxH + if (use_past && static_kv) { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, value_cache, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } else { + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + cublas, CUBLAS_OP_N, CUBLAS_OP_N, + head_size, sequence_length, kv_sequence_length, + &one, v, head_size, strideA, + scratch2, kv_sequence_length, temp_matrix_size, + &zero, scratch3, head_size, strideB, BN, device_prop)); + } + + // scratch3 is BxNxSxH, transpose to output SxBxNxH + return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads, + max_threads_per_block, true, scratch3, output); +} + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& device_prop, + Stream* stream, + cublasHandle_t& cublas, + const size_t element_size, + const int batch_size, + const int sequence_length, + const int kv_sequence_length, + const int num_heads, + const int head_size, + const bool static_kv, + const bool use_past, + const bool has_layer_state, 
+ const bool has_key_padding_mask, + const float mask_filter_value, + const void* gemm_query_buffer, + const void* gemm_kv_buffer, + const bool* key_padding_mask, + const void* key_cache, + const void* value_cache, + void* qkv_buffer, + void* workspace_buffer, + void* output, + void* new_key_cache, + void* new_value_cache) { + if (element_size == 2) { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } else { + return DecoderQkvToContext( + device_prop, + stream, + cublas, + element_size, + batch_size, + sequence_length, + kv_sequence_length, + num_heads, + head_size, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, + mask_filter_value, + reinterpret_cast(gemm_query_buffer), + reinterpret_cast(gemm_kv_buffer), + key_padding_mask, + reinterpret_cast(key_cache), + reinterpret_cast(value_cache), + reinterpret_cast(qkv_buffer), + reinterpret_cast(workspace_buffer), + reinterpret_cast(output), + reinterpret_cast(new_key_cache), + reinterpret_cast(new_value_cache)); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..9db9ccb45e330 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "contrib_ops/cuda/bert/attention_impl.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +Status LaunchDecoderAttentionKernel( + const cudaDeviceProp& prop, // Device Properties + Stream* stream, // ORT Stream + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu index 9a150c9e6cd77..b0ed3ff82226a 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.cu @@ -30,6 +30,7 @@ limitations under the License. 
#include "contrib_ops/cpu/bert/attention_base.h" #include "contrib_ops/rocm/bert/attention_impl.h" #include "contrib_ops/rocm/bert/attention_softmax.h" +#include "contrib_ops/rocm/bert/decoder_attention_impl.h" using namespace onnxruntime::rocm; diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 19b2bc34efaec..3164e8c211099 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -28,34 +28,6 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length); -Status LaunchDecoderAttentionKernel( - const hipDeviceProp_t& prop, // Device Properties - RocmTuningContext* tuning_ctx, // context for tuning - Stream* stream, // ORT Stream - rocblas_handle& rocblas, // Rocblas handle - const size_t element_size, // Element size of input tensor - const int batch_size, // Batch size (B) - const int sequence_length, // Sequence length (S) - const int kv_sequence_length, // Key/Value/Cache sequence length - const int num_heads, // Number of attention heads (N) - const int head_size, // Hidden layer size per head (H) - const bool static_kv, // Whether cross attention or not - const bool use_past, // Whether use cache or not - const bool has_layer_state, // Whether output cache or not - const bool has_key_padding_mask, // Whether use key_padding_mask or not - const float mask_filter_value, // Mask filter value - const void* gemm_query_buffer, // Query buffer - const void* gemm_kv_buffer, // Key and value buffer - const bool* key_padding_mask, // Key padding mask - const void* key_cache, // Input key cache - const void* value_cache, // Input value cache - void* qkv_buffer, // Temporary buffer - void* workspace_buffer, // Temporary buffer - void* output, // Output tensor - void* new_key_cache, // New_key_cache tensor - void* new_value_cache // New_value_cache tensor -); - Status LaunchTransCtx(hipStream_t stream, const int sequence_length, const int batch_size, const int head_size, const int num_heads, const int max_threads_per_block, const bool reversed_bs, const float* input, float* output); diff --git a/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h new file mode 100644 index 0000000000000..d71c6d8440a44 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include "contrib_ops/cpu/bert/attention_common.h" +#include "core/providers/rocm/shared_inc/rocm_utils.h" +#include "core/providers/rocm/tunable/rocm_tunable.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +Status LaunchDecoderAttentionKernel( + const hipDeviceProp_t& prop, // Device Properties + RocmTuningContext* tuning_ctx, // context for tuning + Stream* stream, // ORT Stream + rocblas_handle& rocblas, // Rocblas handle + const size_t element_size, // Element size of input tensor + const int batch_size, // Batch size (B) + const int sequence_length, // Sequence length (S) + const int kv_sequence_length, // Key/Value/Cache sequence length + const int num_heads, // Number of attention heads (N) + const int head_size, // Hidden layer size per head (H) + const bool static_kv, // Whether cross attention or not + const bool use_past, // Whether use cache or not + const bool has_layer_state, // Whether output cache or not + const bool has_key_padding_mask, // Whether use key_padding_mask or not + const float mask_filter_value, // Mask filter value + const void* gemm_query_buffer, // Query buffer + const void* gemm_kv_buffer, // Key and value buffer + const bool* key_padding_mask, // Key padding mask + const void* key_cache, // Input key cache + const void* value_cache, // Input value cache + void* qkv_buffer, // Temporary buffer + void* workspace_buffer, // Temporary buffer + void* output, // Output tensor + void* new_key_cache, // New_key_cache tensor + void* new_value_cache // New_value_cache tensor +); + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime
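
Note on buffer sizing for the LaunchDecoderAttentionKernel entry points split out into decoder_attention_impl.{h,cu}: the caller (DecoderAttention::ComputeInternal, see the decoder_attention.cc hunk above) allocates qkv_buffer and workspace_buffer with byte counts that line up with the pointer arithmetic inside DecoderQkvToContext (k_buffer_offset, v_buffer_offset, and scratch1/scratch2/scratch3). The sketch below restates that sizing as a standalone helper so the relationship is easy to audit; the struct and function names are hypothetical and are not part of this change.

```cpp
#include <cstddef>

// Illustrative only: these names do not exist in ONNX Runtime.
struct DecoderAttentionScratchSizes {
  size_t qkv_buffer_bytes;        // transposed Q (BxNxSxH) followed by K and V (BxNxLxH each)
  size_t workspace_buffer_bytes;  // temp transposed KV + two BxNxSxL softmax buffers + one BxNxSxH context buffer
};

inline DecoderAttentionScratchSizes ComputeDecoderAttentionScratchSizes(
    size_t element_size,     // sizeof(half) or sizeof(float)
    int batch_size,          // B
    int sequence_length,     // S (query length)
    int kv_sequence_length,  // L (key/value or cache length)
    int num_heads,           // N
    int head_size) {         // H
  const size_t BNH = static_cast<size_t>(batch_size) * num_heads * head_size;
  const size_t BNS = static_cast<size_t>(batch_size) * num_heads * sequence_length;

  DecoderAttentionScratchSizes sizes{};

  // qkv_buffer: Q occupies S*BNH elements, then K and V occupy L*BNH elements each,
  // matching k_buffer_offset = S*BNH and v_buffer_offset = (S+L)*BNH in DecoderQkvToContext.
  sizes.qkv_buffer_bytes =
      element_size *
      (static_cast<size_t>(sequence_length) + 2 * static_cast<size_t>(kv_sequence_length)) * BNH;

  // workspace_buffer: 3*S*BNH elements are reserved ahead of scratch1 for the temporary
  // transposed KV, scratch1 and scratch2 are BxNxSxL each, and scratch3 is BxNxSxH.
  // Total = 4*S*BNH + 2*BNS*L, i.e. the 2*B*S*N*(2*H + L) sizing used in decoder_attention.cc.
  sizes.workspace_buffer_bytes =
      element_size *
      (4 * static_cast<size_t>(sequence_length) * BNH +
       2 * BNS * static_cast<size_t>(kv_sequence_length));

  return sizes;
}
```

The identity 3*S*B*N*H (temporary KV region) + 2*B*N*S*L (scratch1 and scratch2) + S*B*N*H (scratch3) = 2*B*S*N*(2*H + L) is what keeps the allocation in decoder_attention.cc and the offsets in decoder_attention_impl.cu consistent. The ROCm declaration added at the end of this diff takes the same qkv_buffer and workspace_buffer parameters, though its internal layout is defined by the existing ROCm implementation rather than by this file.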