diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index 7be8208484..519669be9e 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -13,6 +13,27 @@ struct half4 { half w; }; +struct half8 { + half x; + half y; + half z; + half w; + half a; + half b; + half c; + half d; +}; +struct float8 { + float x; + float y; + float z; + float w; + float a; + float b; + float c; + float d; +}; + ////////////////data type/////////////// template struct VEC_K {}; @@ -69,6 +90,10 @@ template <> struct Vec_fp32_ { using Type = float4; }; +template <> +struct Vec_fp32_ { + using Type = float8; +}; template struct VEC_V {}; @@ -78,7 +103,7 @@ struct VEC_V { }; template <> struct VEC_V { - using Type = half4; + using Type = half8; }; ////////////////data structures half/////////////// @@ -181,8 +206,17 @@ inline __device__ float4 fma(float a, float4 b, float4 c) { return d; } -inline __device__ float4 fma(half a, float4 b, float4 c) { - assert(false); +inline __device__ float8 fma(float a, float8 f1, float8 f2) { + float8 res; + res.x = fma(a, f1.x, f2.x); + res.y = fma(a, f1.y, f2.y); + res.z = fma(a, f1.z, f2.z); + res.w = fma(a, f1.w, f2.w); + res.a = fma(a, f1.a, f2.a); + res.b = fma(a, f1.b, f2.b); + res.c = fma(a, f1.c, f2.c); + res.d = fma(a, f1.d, f2.d); + return res; } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -211,6 +245,19 @@ inline __device__ float4 add(float4 a, float4 b) { return c; } +inline __device__ float8 add(float8 f1, float8 f2) { + float8 res; + res.x = add(f1.x, f2.x); + res.y = add(f1.y, f2.y); + res.z = add(f1.z, f2.z); + res.w = add(f1.w, f2.w); + res.a = add(f1.a, f2.a); + res.b = add(f1.b, f2.b); + res.c = add(f1.c, f2.c); + res.d = add(f1.d, f2.d); + return res; +} + inline __device__ float sum(float v) { return v; } @@ -271,6 +318,18 @@ inline __device__ float4 cast_to_float(half4 u) { tmp.w = __half2float(u.w); return tmp; } +inline __device__ float8 cast_to_float(half8 u) { + float8 tmp; + tmp.x = __half2float(u.x); + tmp.y = __half2float(u.y); + tmp.z = __half2float(u.z); + tmp.w = __half2float(u.w); + tmp.a = __half2float(u.a); + tmp.b = __half2float(u.b); + tmp.c = __half2float(u.c); + tmp.d = __half2float(u.d); + return tmp; +} inline __device__ void convert_from_float(float4 &dst, float4 src) { dst = src; @@ -281,6 +340,9 @@ inline __device__ void convert_from_float(float &dst, float src) { inline __device__ void convert_from_float(float2 &dst, float2 src) { dst = src; } +inline __device__ void convert_from_float(float8 &dst, float8 src) { + dst = src; +} inline __device__ void convert_from_float(half4 &dst, float4 src) { dst.x = __float2half(src.x); @@ -288,6 +350,17 @@ inline __device__ void convert_from_float(half4 &dst, float4 src) { dst.z = __float2half(src.z); dst.w = __float2half(src.w); } + +inline __device__ void convert_from_float(half8 &dst, float8 src) { + dst.x = __float2half(src.x); + dst.y = __float2half(src.y); + dst.z = __float2half(src.z); + dst.w = __float2half(src.w); + dst.a = __float2half(src.a); + dst.b = __float2half(src.b); + dst.c = __float2half(src.c); + dst.d = __float2half(src.d); +} inline __device__ void convert_from_float(half2 &dst, float2 src) { dst.x = __float2half(src.x); dst.y = __float2half(src.y); @@ -317,7 +390,8 @@ inline __device__ float qk_dot_(K_vec const (&q)[N], 
K_vec const (&k)[N]) { // use float32 to get better accuracy using Vec_sum = typename Vec_fp32_::Type; // Compute the parallel products for Q*K^T (treat vector lanes separately). - Vec_sum qk_vec = mul(cast_to_float(q[0]), cast_to_float(k[0])); + Vec_sum qk_vec = + mul(cast_to_float(q[0]), cast_to_float(k[0])); #pragma unroll for (int ii = 1; ii < N; ++ii) { qk_vec = FlexFlow::fma(cast_to_float(q[ii]), cast_to_float(k[ii]), qk_vec); @@ -375,7 +449,6 @@ inline __device__ float block_sum(float *red_smem, float sum) { return __shfl_sync(uint32_t(-1), sum, 0); } -// utils template inline size_t smem_size_in_bytes(int hidden_size_per_head, int max_sequence_length, @@ -384,69 +457,69 @@ inline size_t smem_size_in_bytes(int hidden_size_per_head, // The amount of shared memory needed to store the Q*K^T values in float. size_t qk_sz = div_up(max_sequence_length + 1, 4) * 16; - // The extra memory needed if we are not using floats for the final logits. - - // store the extra memory if half percision size_t logits_sz = qk_sz; - // if (sizeof(DT) != 4) { - // logits_sz = div_up(max_sequence_length + 1, 4) * 4 * sizeof(DT); - // } // The total size needed during softmax. size_t softmax_sz = qk_sz + logits_sz; + size_t q_size = hidden_size_per_head * sizeof(DT); // The number of partial rows to reduce in the final reduction. int rows_per_red = threads_per_block / threads_per_value; // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(DT) / 2; - + size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(float) / 2; // The max. - return max(softmax_sz, red_sz); + return max(softmax_sz, red_sz) + q_size; } template inline void smem_size_in_bytes_tree(int hidden_size_per_head, + int max_sequence_length, int threads_per_value, int threads_per_block, TreeVerifyBatchConfig const *bc, int shared_mem[]) { int max_query_length = 0; - int max_total_length = 0; + // int max_total_length = 0; for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } max_query_length = max(max_query_length, bc->requestsInfo[i].num_tokens_in_batch); - max_total_length = max(max_total_length, - bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch); + // max_total_length = max(max_total_length, + // bc->requestsInfo[i].first_token_depth_in_request + + // bc->requestsInfo[i].num_tokens_in_batch); } - int max_qk_length = max_query_length * max_total_length; + // todo fix this + int max_qk_length = 1200; + // The amount of shared memory needed to store the Q*K^T values in float. size_t qk_sz = div_up(max_qk_length + 1, 4) * 16; - // The extra memory needed if we are not using floats for the final logits. - // store the extra memory if half percision size_t logits_sz = qk_sz; - // if (sizeof(DT) != 4) { - // logits_sz = div_up(max_qk_length + 1, 4) * 4 * sizeof(DT); - // } // The total size needed during softmax. size_t softmax_sz = qk_sz + logits_sz; + size_t q_size = hidden_size_per_head * sizeof(DT); + // The number of partial rows to reduce in the final reduction. int rows_per_red = threads_per_block / threads_per_value; // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(DT) / 2; + // use 4 + size_t red_sz = rows_per_red * hidden_size_per_head * sizeof(float) / 2; // The max. 
shared_mem[0] = qk_sz; - shared_mem[1] = max(softmax_sz, red_sz); + shared_mem[1] = max(softmax_sz, red_sz) + q_size; } +template <typename T, int Dh> +struct threads_per_value_t { + static int const value = Dh * sizeof(T) / 16; +}; + } // namespace FlexFlow #endif // _FLEXFLOW_OPS_KERNELS_INC_MULTIHEAD_SELF_UTILS_H \ No newline at end of file diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index cc48c85907..532c7c9556 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -174,7 +174,6 @@ __global__ void compute_attention_kernel_generation_kernel( } qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = mask ? 0.f : qk; - ; // if (blockIdx.y == 0 && blockIdx.x == 0) { // printf("qk projkkkhead1 %.10f, %d %d\n", qk, tlength, ti); // } @@ -259,7 +258,6 @@ __global__ void compute_attention_kernel_generation_kernel( // printf("softmax %.10f\n", qk_smem[8]); // printf("softmax %.10f\n", qk_smem[9]); // printf("softmax %.10f\n", qk_smem[10]); - // printf("softmax %.10f\n", qk_smem[11]); // } // value projection @@ -325,13 +323,6 @@ __global__ void compute_attention_kernel_generation_kernel( } } - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("V value10 %.10f\n", out.x); - // printf("V value10 %.10f\n", out.y); - // printf("V value10 %.10f\n", out.z); - // printf("V value10 %.10f\n", out.w); - // } - // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { convert_from_float( @@ -721,17 +712,17 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, } #define LAUNCH_ATTENTION_SCORE_KERNEL( \ - DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_sz = smem_size_in_bytes<DT>
(m->qProjSize, \ BatchConfig::max_sequence_length(), \ - THDS_PER_VALUE, \ + THREADS_PER_VALUE, \ THDS_PER_BLOCK); \ compute_attention_kernel_generation_kernel \ + THREADS_PER_VALUE> \ <<>>( \ static_cast
<DT *>(m->devQKVProjArray), \ static_cast<DT *>
(m->keyCache), \ @@ -754,24 +745,17 @@ void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; size_t smem_sz; - switch (per_head_size) { - case 64: - LAUNCH_ATTENTION_SCORE_KERNEL(DT, 64, 64, 4, 16, 128, stream); - break; - case 128: - LAUNCH_ATTENTION_SCORE_KERNEL(DT, 128, 128, 4, 32, 128, stream); - break; - default: - assert(false); + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); } - - // // check for errors - // cudaError_t error = cudaGetLastError(); - // if (error != cudaSuccess) { - - // fprintf(stderr, "ERROR: %s \n", cudaGetErrorString(error)); - // assert(false); - // } } template @@ -845,15 +829,15 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, cudaStream_t stream) { // here because we need position info in inference 1 - cudaEvent_t t_start, t_end1, t_end2, t_end3, t_end4, t_end5, t_end6; - cudaEventCreate(&t_start); - cudaEventCreate(&t_end1); - cudaEventCreate(&t_end2); - cudaEventCreate(&t_end3); - cudaEventCreate(&t_end4); - cudaEventCreate(&t_end5); - cudaEventCreate(&t_end6); - cudaEventRecord(t_start, stream); + // cudaEvent_t t_start, t_end1, t_end2, t_end3, t_end4, t_end5, t_end6; + // cudaEventCreate(&t_start); + // cudaEventCreate(&t_end1); + // cudaEventCreate(&t_end2); + // cudaEventCreate(&t_end3); + // cudaEventCreate(&t_end4); + // cudaEventCreate(&t_end5); + // cudaEventCreate(&t_end6); + // cudaEventRecord(t_start, stream); if (m->offload && m->biasSize > 0) { cudaMemcpyAsync( @@ -873,12 +857,12 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, sizeof(BatchConfig::PerRequestInfo), cudaMemcpyHostToDevice, stream); - float elapsed4 = 0; - cudaEventRecord(t_end4, stream); - checkCUDA(cudaEventSynchronize(t_end4)); - checkCUDA(cudaEventElapsedTime(&elapsed4, t_start, t_end4)); - printf("IncMultiHeadSelfAttention copy element kernel time = %.9fms\n", - elapsed4); + // float elapsed4 = 0; + // cudaEventRecord(t_end4, stream); + // checkCUDA(cudaEventSynchronize(t_end4)); + // checkCUDA(cudaEventElapsedTime(&elapsed4, t_start, t_end4)); + // printf("IncMultiHeadSelfAttention copy element kernel time = %.9fms\n", + // elapsed4); // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, @@ -889,61 +873,69 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, static_cast
(m->devQKVProjArray), bias_ptr, stream); - float elapsed1 = 0; - cudaEventRecord(t_end1, stream); - checkCUDA(cudaEventSynchronize(t_end1)); - checkCUDA(cudaEventElapsedTime(&elapsed1, t_start, t_end1)); - printf("IncMultiHeadSelfAttention qkv kernel time = %.9fms\n", elapsed1); + // float elapsed1 = 0; + // cudaEventRecord(t_end1, stream); + // checkCUDA(cudaEventSynchronize(t_end1)); + // checkCUDA(cudaEventElapsedTime(&elapsed1, t_start, t_end1)); + // printf("IncMultiHeadSelfAttention qkv kernel time = %.9fms\n", elapsed1); // phase 2: Update key/val cache update_kv_cache_kernel
<DT>(m, bc, stream); - float elapsed2 = 0; - cudaEventRecord(t_end2, stream); - checkCUDA(cudaEventSynchronize(t_end2)); - checkCUDA(cudaEventElapsedTime(&elapsed2, t_start, t_end2)); - printf("IncMultiHeadSelfAttention update kv cache time = %.9fms\n", elapsed2); + // float elapsed2 = 0; + // cudaEventRecord(t_end2, stream); + // checkCUDA(cudaEventSynchronize(t_end2)); + // checkCUDA(cudaEventElapsedTime(&elapsed2, t_start, t_end2)); + // printf("IncMultiHeadSelfAttention update kv cache time = %.9fms\n", + // elapsed2); + // printf("num of generation tokens: %d\n", bc->num_generation_tokens); if (bc->num_generation_tokens > 0) { // phase 3: Compute attention score for generation tokens compute_attention_kernel_generation
<DT>( m, bc, static_cast<DT *>
(m->attn_heads), stream); } - float elapsed3 = 0; - cudaEventRecord(t_end3, stream); - checkCUDA(cudaEventSynchronize(t_end3)); - checkCUDA(cudaEventElapsedTime(&elapsed3, t_start, t_end3)); - printf("IncMultiHeadSelfAttention attention score time = %.9fms\n", elapsed3); + + // float elapsed3 = 0; + // cudaEventRecord(t_end3, stream); + // checkCUDA(cudaEventSynchronize(t_end3)); + // checkCUDA(cudaEventElapsedTime(&elapsed3, t_start, t_end3)); + // printf("IncMultiHeadSelfAttention attention score time = %.9fms\n", + // elapsed3); if (bc->num_tokens > bc->num_generation_tokens) { // phase 4: Compute attention score for prompt tokens; compute_attention_kernel_prompt( m, bc, shard_id, bias_ptr, weight_ptr, stream); } - float elapsed5 = 0; - cudaEventRecord(t_end5, stream); - checkCUDA(cudaEventSynchronize(t_end5)); - checkCUDA(cudaEventElapsedTime(&elapsed5, t_start, t_end5)); - printf("IncMultiHeadSelfAttention is there a thing? time = %.9fms\n", - elapsed5); + // float elapsed5 = 0; + // cudaEventRecord(t_end5, stream); + // checkCUDA(cudaEventSynchronize(t_end5)); + // checkCUDA(cudaEventElapsedTime(&elapsed5, t_start, t_end5)); + // printf("IncMultiHeadSelfAttention is there a thing? time = %.9fms\n", + // elapsed5); // compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); compute_o_prod_bias( m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); - float elapsed6 = 0; - cudaEventRecord(t_end6, stream); - checkCUDA(cudaEventSynchronize(t_end6)); - checkCUDA(cudaEventElapsedTime(&elapsed6, t_start, t_end6)); - printf("IncMultiHeadSelfAttention final projection time = %.9fms\n", - elapsed6); - - cudaEventDestroy(t_start); - cudaEventDestroy(t_end1); - cudaEventDestroy(t_end2); - cudaEventDestroy(t_end3); - cudaEventDestroy(t_end4); - cudaEventDestroy(t_end5); - cudaEventDestroy(t_end6); + // float elapsed6 = 0; + // cudaEventRecord(t_end6, stream); + // checkCUDA(cudaEventSynchronize(t_end6)); + // checkCUDA(cudaEventElapsedTime(&elapsed6, t_start, t_end6)); + // printf("IncMultiHeadSelfAttention final projection time = %.9fms\n", + // elapsed6); + + // cudaEventDestroy(t_start); + // cudaEventDestroy(t_end1); + // cudaEventDestroy(t_end2); + // cudaEventDestroy(t_end3); + // cudaEventDestroy(t_end4); + // cudaEventDestroy(t_end5); + // cudaEventDestroy(t_end6); + + // if(bc->num_active_tokens() == 1){ + // assert(false); + // } } } // namespace IncMultiHeadAttention diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 7fb529045c..aa6ea1837d 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -81,10 +81,6 @@ __global__ void compute_attention_kernel_fused_kernel( request_infos[request_idx].num_tokens_in_batch; int const qlength = request_infos[request_idx].num_tokens_in_batch; - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("tree metadata %d, %d\n", qlength, tlength); - // } - int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { first_token_idx += request_infos[request_idx].num_tokens_in_batch; @@ -201,9 +197,11 @@ __global__ void compute_attention_kernel_fused_kernel( // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { qk_smem[ti * qlength + qi] *= inv_sum; } + __syncthreads(); } @@ -713,8 +711,12 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, #define 
LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ - smem_size_in_bytes_tree
<DT>( \ - m->qProjSize, THDS_PER_VALUE, THDS_PER_BLOCK, bc, smem_sz); \ + smem_size_in_bytes_tree<DT>
(m->qProjSize, \ + BatchConfig::max_sequence_length(), \ + THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ + bc, \ + smem_sz); \ compute_attention_kernel_fused_kernelqk production size, 1->total shared size int smem_sz[2]; - switch (per_head_size) { - case 64: - LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL(DT, 64, 64, 4, 16, 128, stream); - break; - case 128: - LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( - DT, 128, 128, 4, 32, 128, stream); - break; - default: - assert(false); - } - // print_tensor((float *)m->attn_heads, 32, "qkv"); - - // check for errors - cudaError_t error = cudaGetLastError(); - if (error != cudaSuccess) { - - fprintf(stderr, "ERROR: %s \n", cudaGetErrorString(error)); - assert(false); + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); } } @@ -814,6 +807,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, bias_ptr = static_cast
<DT *>(m->bias_ptr); } } + // copy committed tokens info to GPU for the commit_tokens kernel // Note that m->num_active_tokens stores the number of active // tokens in the previous batch, which is needed for committing // @@ -875,13 +869,6 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, bias_ptr, processed_tokens_in_batch, stream); - // if(bc->num_active_tokens() == 5){ - // print_tensor((float *)output_ptr, 32, "output"); - // } - // phase 3: Compute attention score - // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - // compute_attention_kernel( - // m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } } // namespace TreeIncMultiHeadAttention @@ -901,10 +888,11 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( bool use_bias = *m->qkv_bias || *m->final_bias; cudaEvent_t t_start, t_end; - cudaEventCreate(&t_start); - cudaEventCreate(&t_end); - cudaEventRecord(t_start, stream); + // cudaEventCreate(&t_start); + // cudaEventCreate(&t_end); + // cudaEventRecord(t_start, stream); if (m->profiling) { + // cudaEvent_t t_start, t_end; cudaEventCreate(&t_start); cudaEventCreate(&t_end); cudaEventRecord(t_start, stream); @@ -964,10 +955,10 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); } - cudaEventRecord(t_end, stream); - checkCUDA(cudaEventSynchronize(t_end)); - cudaEventDestroy(t_start); - cudaEventDestroy(t_end); + // cudaEventRecord(t_end, stream); + // checkCUDA(cudaEventSynchronize(t_end)); + // cudaEventDestroy(t_start); + // cudaEventDestroy(t_end); // if(bc->num_active_tokens() == 5){