diff --git a/examples/inference/python/export/huggingface/ls_hf_quant_gpt2_export.py b/examples/inference/python/export/huggingface/ls_hf_quant_gpt2_export.py index 2ecbb977..2efa5743 100644 --- a/examples/inference/python/export/huggingface/ls_hf_quant_gpt2_export.py +++ b/examples/inference/python/export/huggingface/ls_hf_quant_gpt2_export.py @@ -45,6 +45,7 @@ def extract_gpt_weights( eos_id=50256, pad_id=50257, max_step=50, + extra_decode_length=0, ): # load var names with open(os.path.join(os.path.dirname(model_dir), "config.json")) as f: @@ -121,6 +122,9 @@ def extract_gpt_weights( hdf5_file.create_dataset("model_conf/topp", data=topp, dtype="f4") hdf5_file.create_dataset("model_conf/topk", data=topk, dtype="i4") hdf5_file.create_dataset("model_conf/eos_id", data=eos_id, dtype="i4") + hdf5_file.create_dataset( + "model_conf/extra_decode_length", data=extra_decode_length, dtype="i4" + ) hdf5_file.close() # read-in again to double check @@ -150,6 +154,7 @@ def _print_pair(key, value): eos_id = 50256 pad_id = 50257 max_step = 50 + extra_decode_length = 0 # use positive length to activate it extract_gpt_weights( hdf5_path, args.model, @@ -159,4 +164,5 @@ def _print_pair(key, value): eos_id=eos_id, pad_id=pad_id, max_step=max_step, + extra_decode_length=extra_decode_length, ) diff --git a/lightseq/csrc/kernels/includes/cublas_algo_map.h b/lightseq/csrc/kernels/includes/cublas_algo_map.h index 73aed594..3b7a2da4 100644 --- a/lightseq/csrc/kernels/includes/cublas_algo_map.h +++ b/lightseq/csrc/kernels/includes/cublas_algo_map.h @@ -19,7 +19,9 @@ #define STRIDE 32 #define BORDER 512 -static std::string DEFAULT_URL = "https://zenodo.org/record/7219754/files/"; +static std::string DEFAULT_URL = + "http://lf3-nlp-opensource.bytetos.com/obj/nlp-opensource/lightseq/" + "igemm_configs/"; static std::string DEFAULT_DIR = std::string(std::getenv("HOME")) + "/.lightseq/igemm_configs/"; static std::string IGEMM_T4_CONFIG = "igemm_T4.cfg"; diff --git 
a/lightseq/inference/model/cublas_algo_map.h b/lightseq/inference/model/cublas_algo_map.h index 28aabead..1aa906b0 100644 --- a/lightseq/inference/model/cublas_algo_map.h +++ b/lightseq/inference/model/cublas_algo_map.h @@ -21,7 +21,9 @@ namespace cuda { #define STRIDE 32 #define BORDER 512 -static std::string DEFAULT_URL = "https://zenodo.org/record/7219754/files/"; +static std::string DEFAULT_URL = + "http://lf3-nlp-opensource.bytetos.com/obj/nlp-opensource/lightseq/" + "igemm_configs/"; static std::string DEFAULT_DIR = std::string(std::getenv("HOME")) + "/.lightseq/igemm_configs/"; static std::string IGEMM_T4_CONFIG = "igemm_T4.cfg"; diff --git a/lightseq/inference/model/gpt_encoder.cc.cu b/lightseq/inference/model/gpt_encoder.cc.cu index db5b1c48..b2652d7c 100644 --- a/lightseq/inference/model/gpt_encoder.cc.cu +++ b/lightseq/inference/model/gpt_encoder.cc.cu @@ -248,7 +248,7 @@ int GptEncoder::run_one_sample(int batch_size, int batch_seq_len) { ker_norm_layer_launcher<_DataType>( _batch_token_num, _tw._hidden_size, _stream, _p_d_query, _p_d_src_emb_wei[2], _p_d_src_emb_wei[3], _max_thread_per_block); - if (sample_one_token() == 0 || _batch_seq_len >= _tw._max_step) { + if (sample_one_token() == 0 || _batch_seq_len >= _batch_max_seq_len) { CHECK_GPU_ERROR(cudaMemcpyAsync(_p_d_sample_id_buf, _p_d_sample_id, _batch_token_num * sizeof(int), cudaMemcpyDeviceToDevice, _stream)); @@ -256,7 +256,7 @@ int GptEncoder::run_one_sample(int batch_size, int batch_seq_len) { return _batch_seq_len; } - while (_batch_seq_len < _tw._max_step) { + while (_batch_seq_len < _batch_max_seq_len) { #ifdef DEBUG_RESULT std::cout << "before sample:batch_size-" << _batch_size << " batch_seq_len-" << _batch_seq_len << std::endl; @@ -282,14 +282,13 @@ int GptEncoder::run_one_sample(int batch_size, int batch_seq_len) { ker_norm_layer_launcher<_DataType>( _batch_size, _tw._hidden_size, _stream, _p_d_query, _p_d_src_emb_wei[2], _p_d_src_emb_wei[3], _max_thread_per_block); + #ifdef DEBUG_RESULT 
print_vec(_p_d_query, "_p_d_query before logits", _batch_size * _tw._hidden_size - 10, _batch_size * _tw._hidden_size); - - if (sample_one_token_with_cache() == 0 || _batch_seq_len >= _tw._max_step) - break; #else + bool unfinish = sample_one_token_with_cache(); if (!unfinish && !_is_benchmark) break; #endif diff --git a/lightseq/inference/model/quant_gpt_encoder.cc.cu b/lightseq/inference/model/quant_gpt_encoder.cc.cu index 526f4f26..269dfdff 100644 --- a/lightseq/inference/model/quant_gpt_encoder.cc.cu +++ b/lightseq/inference/model/quant_gpt_encoder.cc.cu @@ -47,6 +47,7 @@ QuantGptEncoder::QuantGptEncoder( _h_sample_id(max_batch_size * tw._max_step, 0), _h_unfinished(1), _is_benchmark(false), + _algo_map(), _sm_gt_eq_80(getSMVersion() >= 80 ? true : false) { CHECK_GPU_ERROR(cublasLtCreate(&_cublas_lt_handle)); } @@ -179,11 +180,13 @@ void QuantGptEncoder::init_buffer() { _p_device_wei.push_back( to_gpu(_p_d_enc_wei[_weight_offset + 11], _tw._hidden_size, _stream)); + auto weight_layout = _sm_gt_eq_80 ? 
kColMajor : kColMajor32; + quantize_weight(_p_d_enc_wei[_weight_offset + 2], _int8_p_d_enc_wei[_layer_id * 4], _tw._hidden_size, _tw._hidden_size * 3, _quant_range / _enc_clip_max[_layer_id * 12], _stream, - _cublas_lt_handle); + _cublas_lt_handle, weight_layout); quantize_weight(_p_d_enc_wei[_weight_offset + 4], _int8_p_d_enc_wei[_layer_id * 4 + 1], _tw._hidden_size, @@ -195,7 +198,7 @@ void QuantGptEncoder::init_buffer() { _int8_p_d_enc_wei[_layer_id * 4 + 2], _tw._hidden_size, _tw._inner_size, _quant_range / _enc_clip_max[_layer_id * 12 + 2], _stream, - _cublas_lt_handle); + _cublas_lt_handle, weight_layout); quantize_weight(_p_d_enc_wei[_weight_offset + 10], _int8_p_d_enc_wei[_layer_id * 4 + 3], _tw._inner_size, @@ -306,6 +309,8 @@ int QuantGptEncoder::run_one_sample(int batch_size, _batch_size = batch_size; _batch_seq_len = batch_seq_len; _batch_token_num = batch_size * batch_seq_len; + _batch_max_seq_len = + min(_tw._max_step, batch_seq_len + _tw._extra_decode_length); CHECK_GPU_ERROR(cudaMemcpyAsync(_p_d_real_seq_len, _h_real_seq_len.data(), sizeof(int) * _batch_size, @@ -345,7 +350,7 @@ int QuantGptEncoder::run_one_sample(int batch_size, _p_d_self_v_cache2 = _p_d_self_v_cache1; _p_d_self_v_cache1 = ftmp; - if (sample_one_token() == 0 || _batch_seq_len >= _tw._max_step) { + if (sample_one_token() == 0 || _batch_seq_len >= _batch_max_seq_len) { CHECK_GPU_ERROR(cudaMemcpyAsync(_p_d_sample_id_buf, _p_d_sample_id, _batch_token_num * sizeof(int), cudaMemcpyDeviceToDevice, _stream)); @@ -353,7 +358,7 @@ int QuantGptEncoder::run_one_sample(int batch_size, return _batch_seq_len; } - while (_batch_seq_len < _tw._max_step) { + while (_batch_seq_len < _batch_max_seq_len) { #ifdef DEBUG_RESULT std::cout << "before sample:batch_size-" << _batch_size << " batch_seq_len-" << _batch_seq_len << std::endl; @@ -485,16 +490,25 @@ void QuantGptEncoder::self_attention() { _int8_ffn_in_buf, _p_device_wei[_weight_offset], _p_device_wei[_weight_offset + 1], 
_p_device_wei[_weight_offset + 5], _max_thread_per_block, _quant_range / _enc_clip_max[_layer_id * 12 + 4], - false, true); + false, !_sm_gt_eq_80); } - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size * 3, - _tw._hidden_size, 0, 0, 0, - _enc_clip_max[_layer_id * 12] * _enc_clip_max[_layer_id * 12 + 4] / - (_enc_clip_max[_layer_id * 12 + 8] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm( + _int8_p_d_enc_wei[_layer_id * 4], _int8_ffn_in_buf, _int8_ffn_out_buf, + 1, _tw._hidden_size * 3, _batch_token_num, _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12] * _enc_clip_max[_layer_id * 12 + 4] / + (_enc_clip_max[_layer_id * 12 + 8] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size * 3, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12] * _enc_clip_max[_layer_id * 12 + 4] / + (_enc_clip_max[_layer_id * 12 + 8] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4], _cublas_lt_handle, + _stream, _sm_gt_eq_80); + } #ifdef DEBUG_RESULT print_vec(_int8_ffn_in_buf, "attn qkv in", 20); @@ -509,7 +523,7 @@ void QuantGptEncoder::self_attention() { _p_d_self_k_cache1[_layer_id], _p_d_self_v_cache1[_layer_id], _p_d_v, _batch_seq_len, _tw._dim_per_head, _tw._head_num, _max_thread_per_block, _enc_clip_max[_layer_id * 12 + 8] / _quant_range, - _quant_range / _enc_clip_max[_layer_id * 12 + 11], true); + _quant_range / _enc_clip_max[_layer_id * 12 + 11], !_sm_gt_eq_80); /* ---step 2. 
correlation = q * k, perform softmax on correlation--- */ CHECK_GPU_ERROR(cublasGemmStridedBatchedEx( @@ -563,7 +577,7 @@ void QuantGptEncoder::self_attention() { _int8_ffn_in_buf, _p_d_query, _batch_token_num, _tw._hidden_size, _enc_clip_max[_layer_id * 12 + 9] / _quant_range, _quant_range / _enc_clip_max[_layer_id * 12 + 6], _max_thread_per_block, - _stream, false, false, true); + _stream, false, false, !_sm_gt_eq_80); return; } @@ -576,18 +590,27 @@ void QuantGptEncoder::self_attention_with_cache() { _batch_size, _tw._hidden_size, _stream, _p_d_query, _int8_ffn_in_buf, _p_device_wei[_weight_offset], _p_device_wei[_weight_offset + 1], _p_device_wei[_weight_offset + 5], _max_thread_per_block, - _quant_range / _enc_clip_max[_layer_id * 12 + 4], false, true); + _quant_range / _enc_clip_max[_layer_id * 12 + 4], false, !_sm_gt_eq_80); } /* ---step 1. qkv = ori_q * qkv_wei + bias, and reshape qkv for multi-head * gemm--- */ - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_size, _tw._hidden_size * 3, _tw._hidden_size, - 0, 0, 0, - _enc_clip_max[_layer_id * 12] * _enc_clip_max[_layer_id * 12 + 4] / - (_enc_clip_max[_layer_id * 12 + 8] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm( + _int8_p_d_enc_wei[_layer_id * 4], _int8_ffn_in_buf, _int8_ffn_out_buf, + 1, _tw._hidden_size * 3, _batch_size, _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12] * _enc_clip_max[_layer_id * 12 + 4] / + (_enc_clip_max[_layer_id * 12 + 8] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_size, _tw._hidden_size * 3, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12] * _enc_clip_max[_layer_id * 12 + 4] / + (_enc_clip_max[_layer_id * 12 + 8] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4], _cublas_lt_handle, + _stream, _sm_gt_eq_80); + } // get q, k, v by split and 
reshape qkv ker_arrange_qkv_with_cache_i8I_i8O_launcher<_DataType>( @@ -597,7 +620,7 @@ void QuantGptEncoder::self_attention_with_cache() { _p_d_self_v_cache1[_layer_id], _p_d_self_v_cache2[_layer_id], _batch_seq_len, _tw._dim_per_head, _tw._head_num, _enc_clip_max[_layer_id * 12 + 8] / _quant_range, - _quant_range / _enc_clip_max[_layer_id * 12 + 11], true); + _quant_range / _enc_clip_max[_layer_id * 12 + 11], !_sm_gt_eq_80); /* ---step 2. correlation = q * k, perform softmax on correlation correlation: [batch_size, heads_num, 1, batch_seq_len]--- */ @@ -630,20 +653,30 @@ void QuantGptEncoder::self_attention_with_cache() { _int8_ffn_in_buf, _p_d_query, _batch_size, _tw._hidden_size, _enc_clip_max[_layer_id * 12 + 9] / _quant_range, _quant_range / _enc_clip_max[_layer_id * 12 + 6], _max_thread_per_block, - _stream, false, false, true); + _stream, false, false, !_sm_gt_eq_80); return; } template void QuantGptEncoder::ffn_add_norm() { /* ---step 1. first ffn layer--- */ - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_token_num, _tw._inner_size, _tw._hidden_size, - 0, 0, 0, - _enc_clip_max[_layer_id * 12 + 2] * _enc_clip_max[_layer_id * 12 + 6] / - (_enc_clip_max[_layer_id * 12 + 10] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 2], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 2], _int8_ffn_in_buf, + _int8_ffn_out_buf, 1, _tw._inner_size, _batch_token_num, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12 + 2] * + _enc_clip_max[_layer_id * 12 + 6] / + (_enc_clip_max[_layer_id * 12 + 10] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_token_num, _tw._inner_size, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12 + 2] * _enc_clip_max[_layer_id * 12 + 6] / + (_enc_clip_max[_layer_id * 12 + 10] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 2], 
+ _cublas_lt_handle, _stream, _sm_gt_eq_80); + } #ifdef DEBUG_RESULT print_vec(_int8_ffn_in_buf, "ffn1 in", 20); @@ -655,7 +688,7 @@ void QuantGptEncoder::ffn_add_norm() { _batch_token_num, _stream, _int8_ffn_out_buf, _int8_ffn_in_buf, _p_device_wei[_weight_offset + 9], _tw._inner_size, _enc_clip_max[_layer_id * 12 + 10] / _quant_range, - _quant_range / _enc_clip_max[_layer_id * 12 + 7], true, false); + _quant_range / _enc_clip_max[_layer_id * 12 + 7], !_sm_gt_eq_80, false); /* ---step 2. second ffn layer--- */ cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 3], _int8_ffn_in_buf, @@ -670,6 +703,7 @@ void QuantGptEncoder::ffn_add_norm() { const _DataType *scale_ptr, *bias_ptr, *res_bias_ptr; float clip_max, dequant_scale; + bool use_col32; dequant_scale = _enc_clip_max[_layer_id * 12 + 3] * _enc_clip_max[_layer_id * 12 + 7] / (_quant_range * _quant_range); @@ -678,19 +712,21 @@ void QuantGptEncoder::ffn_add_norm() { bias_ptr = _p_device_emb[3]; res_bias_ptr = nullptr; clip_max = _output_ln_clip_max; + use_col32 = true; } else { scale_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer]; bias_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer + 1]; res_bias_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer + 5]; clip_max = _enc_clip_max[(_layer_id + 1) * 12 + 4]; + use_col32 = !_sm_gt_eq_80; } ker_residual_bias_ln_i32I_i8O_launcher<_DataType>( _int32_ffn_out_buf, scale_ptr, bias_ptr, res_bias_ptr, _int8_ffn_in_buf, _p_d_query, _batch_token_num, _tw._hidden_size, dequant_scale, _quant_range / clip_max, _max_thread_per_block, _stream, false, false, - true, _scaled_ffn2_colsum[_layer_id]); + use_col32, _scaled_ffn2_colsum[_layer_id]); return; } @@ -698,19 +734,29 @@ void QuantGptEncoder::ffn_add_norm() { template void QuantGptEncoder::ffn_add_norm_with_cache() { /* ---step 1. 
first ffn layer--- */ - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_size, _tw._inner_size, _tw._hidden_size, 0, - 0, 0, - _enc_clip_max[_layer_id * 12 + 2] * _enc_clip_max[_layer_id * 12 + 6] / - (_enc_clip_max[_layer_id * 12 + 10] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 2], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 2], _int8_ffn_in_buf, + _int8_ffn_out_buf, 1, _tw._inner_size, _batch_size, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 12 + 2] * + _enc_clip_max[_layer_id * 12 + 6] / + (_enc_clip_max[_layer_id * 12 + 10] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_size, _tw._inner_size, _tw._hidden_size, 0, + 0, 0, + _enc_clip_max[_layer_id * 12 + 2] * _enc_clip_max[_layer_id * 12 + 6] / + (_enc_clip_max[_layer_id * 12 + 10] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 2], + _cublas_lt_handle, _stream, _sm_gt_eq_80); + } ker_bias_gelu_i8I_i8O_launcher<_DataType>( _batch_size, _stream, _int8_ffn_out_buf, _int8_ffn_in_buf, _p_device_wei[_weight_offset + 9], _tw._inner_size, _enc_clip_max[_layer_id * 12 + 10] / _quant_range, - _quant_range / _enc_clip_max[_layer_id * 12 + 7], true, false); + _quant_range / _enc_clip_max[_layer_id * 12 + 7], !_sm_gt_eq_80, false); /* ---step 2. 
second ffn layer--- */ cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 3], _int8_ffn_in_buf, @@ -719,6 +765,7 @@ void QuantGptEncoder::ffn_add_norm_with_cache() { const _DataType *scale_ptr, *bias_ptr, *res_bias_ptr; float clip_max, dequant_scale; + bool use_col32; dequant_scale = _enc_clip_max[_layer_id * 12 + 3] * _enc_clip_max[_layer_id * 12 + 7] / (_quant_range * _quant_range); @@ -727,19 +774,21 @@ void QuantGptEncoder::ffn_add_norm_with_cache() { bias_ptr = _p_device_emb[3]; res_bias_ptr = nullptr; clip_max = _output_ln_clip_max; + use_col32 = true; } else { scale_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer]; bias_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer + 1]; res_bias_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer + 5]; clip_max = _enc_clip_max[(_layer_id + 1) * 12 + 4]; + use_col32 = !_sm_gt_eq_80; } ker_residual_bias_ln_i32I_i8O_launcher<_DataType>( _int32_ffn_out_buf, scale_ptr, bias_ptr, res_bias_ptr, _int8_ffn_in_buf, _p_d_query, _batch_size, _tw._hidden_size, dequant_scale, _quant_range / clip_max, _max_thread_per_block, _stream, false, false, - true, _scaled_ffn2_colsum[_layer_id]); + use_col32, _scaled_ffn2_colsum[_layer_id]); return; } diff --git a/lightseq/inference/model/quant_gpt_encoder.h b/lightseq/inference/model/quant_gpt_encoder.h index 1b450f40..67c032d6 100644 --- a/lightseq/inference/model/quant_gpt_encoder.h +++ b/lightseq/inference/model/quant_gpt_encoder.h @@ -16,6 +16,7 @@ #include "../proto/quant_gpt_weight.h" #include "../tools/util.h" +#include "cublas_algo_map.h" namespace lightseq { namespace cuda { @@ -45,6 +46,7 @@ class QuantGptEncoder { cudaStream_t _cache_stream; cublasHandle_t _hd; cublasLtHandle_t _cublas_lt_handle; + cublasAlgoMap _algo_map; const bool _sm_gt_eq_80; const _DataType _fone; @@ -110,6 +112,7 @@ class QuantGptEncoder { int _batch_size; int _batch_token_num; + int _batch_max_seq_len; int _layer_id; int _weight_offset; bool _is_benchmark; diff --git 
a/lightseq/inference/model/quant_vit_encoder.cc.cu b/lightseq/inference/model/quant_vit_encoder.cc.cu index 23c2ed30..1aee1045 100644 --- a/lightseq/inference/model/quant_vit_encoder.cc.cu +++ b/lightseq/inference/model/quant_vit_encoder.cc.cu @@ -35,6 +35,7 @@ QuantVitEncoder::QuantVitEncoder( _atten_scaler((_DataType)sqrt(1.f / tw._dim_per_head)), _max_batch_dim(max_batch_size * tw._max_step * tw._hidden_size), _max_thread_per_block(1024), + _algo_map(), _sm_gt_eq_80(getSMVersion() >= 80 ? true : false) { CHECK_GPU_ERROR(cublasLtCreate(&_cublas_lt_handle)); } @@ -102,29 +103,31 @@ void QuantVitEncoder::init_buffer() { _p_device_wei.push_back( to_gpu(_p_d_enc_wei[_weight_offset + 11], _tw._hidden_size, _stream)); + auto weight_layout = _sm_gt_eq_80 ? kColMajor : kColMajor32; + quantize_weight(_p_d_enc_wei[_weight_offset + 2], _int8_p_d_enc_wei[_layer_id * 4], _tw._hidden_size, _tw._hidden_size * 3, _quant_range / _enc_clip_max[_layer_id * 11], _stream, - _cublas_lt_handle); + _cublas_lt_handle, weight_layout); quantize_weight(_p_d_enc_wei[_weight_offset + 4], _int8_p_d_enc_wei[_layer_id * 4 + 1], _tw._hidden_size, _tw._hidden_size, _quant_range / _enc_clip_max[_layer_id * 11 + 1], _stream, - _cublas_lt_handle); + _cublas_lt_handle, weight_layout); quantize_weight(_p_d_enc_wei[_weight_offset + 8], _int8_p_d_enc_wei[_layer_id * 4 + 2], _tw._hidden_size, _tw._inner_size, _quant_range / _enc_clip_max[_layer_id * 11 + 2], _stream, - _cublas_lt_handle); + _cublas_lt_handle, weight_layout); quantize_weight(_p_d_enc_wei[_weight_offset + 10], _int8_p_d_enc_wei[_layer_id * 4 + 3], _tw._inner_size, _tw._hidden_size, _quant_range / _enc_clip_max[_layer_id * 11 + 3], _stream, - _cublas_lt_handle); + _cublas_lt_handle, weight_layout); if (_tw._use_gelu) { _scaled_ffn2_colsum[_layer_id] = nullptr; @@ -239,7 +242,7 @@ void QuantVitEncoder::self_attention() { _int8_ffn_in_buf, _p_device_wei[_weight_offset], _p_device_wei[_weight_offset + 1], _p_device_wei[_weight_offset + 5], 
_max_thread_per_block, _quant_range / _enc_clip_max[_layer_id * 11 + 4], - _tw._is_post_ln, true); + _tw._is_post_ln, !_sm_gt_eq_80); } CHECK_GPU_ERROR(cudaGetLastError()); @@ -256,20 +259,29 @@ void QuantVitEncoder::self_attention() { /* ---step 1. qkv = ori_q * qkv_wei + bias, and reshape qkv for multi-head * gemm--- */ - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size * 3, - _tw._hidden_size, 0, 0, 0, - _enc_clip_max[_layer_id * 11] * _enc_clip_max[_layer_id * 11 + 4] / - (_enc_clip_max[_layer_id * 11 + 8] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm( + _int8_p_d_enc_wei[_layer_id * 4], _int8_ffn_in_buf, _int8_ffn_out_buf, + 1, _tw._hidden_size * 3, _batch_token_num, _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 11] * _enc_clip_max[_layer_id * 11 + 4] / + (_enc_clip_max[_layer_id * 11 + 8] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size * 3, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 11] * _enc_clip_max[_layer_id * 11 + 4] / + (_enc_clip_max[_layer_id * 11 + 8] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4], _cublas_lt_handle, + _stream, _sm_gt_eq_80); + } // get q, k, v by split and reshape qkv ker_arrange_encself_qkv_i8I_launcher<_DataType>( _batch_token_num, _tw._hidden_size, _stream, _int8_ffn_out_buf, _p_device_wei[_weight_offset + 3], _p_d_q, _max_batch_dim, _batch_seq_len, _tw._dim_per_head, _tw._head_num, _max_thread_per_block, - _enc_clip_max[_layer_id * 11 + 8] / _quant_range, true); + _enc_clip_max[_layer_id * 11 + 8] / _quant_range, !_sm_gt_eq_80); /* ---step 2. 
correlation = q * k, perform softmax on correlation--- */ CHECK_GPU_ERROR(cublasGemmStridedBatchedEx( @@ -298,7 +310,7 @@ void QuantVitEncoder::self_attention() { ker_arrange_atten_output_i8O_launcher<_DataType>( _batch_token_num, _tw._hidden_size, _stream, _p_d_q, _int8_ffn_in_buf, _batch_seq_len, _tw._dim_per_head, _tw._head_num, _max_thread_per_block, - _quant_range / _enc_clip_max[_layer_id * 11 + 5], true); + _quant_range / _enc_clip_max[_layer_id * 11 + 5], !_sm_gt_eq_80); #ifdef DEBUG_RESULT for (int i = 0; i < _batch_size; i++) { // batch_id @@ -312,13 +324,23 @@ void QuantVitEncoder::self_attention() { #endif /* ---step 4. new_q = ori_q + new_q * output_wei--- */ - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size, - _tw._hidden_size, 0, 0, 0, - _enc_clip_max[_layer_id * 11 + 1] * _enc_clip_max[_layer_id * 11 + 5] / - (_enc_clip_max[_layer_id * 11 + 9] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 1], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 1], _int8_ffn_in_buf, + _int8_ffn_out_buf, 1, _tw._hidden_size, _batch_token_num, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 11 + 1] * + _enc_clip_max[_layer_id * 11 + 5] / + (_enc_clip_max[_layer_id * 11 + 9] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 11 + 1] * _enc_clip_max[_layer_id * 11 + 5] / + (_enc_clip_max[_layer_id * 11 + 9] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 1], + _cublas_lt_handle, _stream, _sm_gt_eq_80); + } #ifdef DEBUG_RESULT for (int i = 0; i < _batch_size; i++) { // batch_id @@ -337,7 +359,7 @@ void QuantVitEncoder::self_attention() { _int8_ffn_in_buf, _p_d_output, _batch_token_num, _tw._hidden_size, _enc_clip_max[_layer_id * 11 + 9] / _quant_range, 
_quant_range / _enc_clip_max[_layer_id * 11 + 6], _max_thread_per_block, - _stream, _tw._is_post_ln, true, true); + _stream, _tw._is_post_ln, !_sm_gt_eq_80, !_sm_gt_eq_80); return; } @@ -356,27 +378,38 @@ void QuantVitEncoder::ffn_add_norm() { #endif /* ---step 1. first ffn layer--- */ - cublasLtMM_withAlgo_i8IO( - _int8_ffn_out_buf, 1, _batch_token_num, _tw._inner_size, _tw._hidden_size, - 0, 0, 0, - _enc_clip_max[_layer_id * 11 + 2] * _enc_clip_max[_layer_id * 11 + 6] / - (_enc_clip_max[_layer_id * 11 + 10] * _quant_range), - _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 2], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 2], _int8_ffn_in_buf, + _int8_ffn_out_buf, 1, _tw._inner_size, _batch_token_num, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 11 + 2] * + _enc_clip_max[_layer_id * 11 + 6] / + (_enc_clip_max[_layer_id * 11 + 10] * _quant_range), + _cublas_lt_handle, _stream, _algo_map); + } else { + cublasLtMM_withAlgo_i8IO( + _int8_ffn_out_buf, 1, _batch_token_num, _tw._inner_size, + _tw._hidden_size, 0, 0, 0, + _enc_clip_max[_layer_id * 11 + 2] * _enc_clip_max[_layer_id * 11 + 6] / + (_enc_clip_max[_layer_id * 11 + 10] * _quant_range), + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 2], + _cublas_lt_handle, _stream, _sm_gt_eq_80); + } if (_tw._use_gelu) { ker_bias_gelu_i8I_i8O_launcher<_DataType>( _batch_token_num, _stream, _int8_ffn_out_buf, _int8_ffn_in_buf, _p_device_wei[_weight_offset + 9], _tw._inner_size, _enc_clip_max[_layer_id * 11 + 10] / _quant_range, - _quant_range / _enc_clip_max[_layer_id * 11 + 7], true, true); + _quant_range / _enc_clip_max[_layer_id * 11 + 7], !_sm_gt_eq_80, + !_sm_gt_eq_80); } else { ker_bias_relu_i8I_i8O_launcher<_DataType>( _batch_token_num, _stream, _int8_ffn_out_buf, _int8_ffn_in_buf, _p_device_wei[_weight_offset + 9], _tw._inner_size, _enc_clip_max[_layer_id * 11 + 10] / _quant_range, _quant_range / _enc_clip_max[_layer_id * 11 + 
7], - _enc_clip_max[_layer_id * 11 + 7], true, true, true); + _enc_clip_max[_layer_id * 11 + 7], !_sm_gt_eq_80, !_sm_gt_eq_80, true); } #ifdef DEBUG_RESULT @@ -391,10 +424,17 @@ void QuantVitEncoder::ffn_add_norm() { #endif /* ---step 2. second ffn layer--- */ - cublasLtMM_withAlgo(_int32_ffn_out_buf, 1, _batch_token_num, _tw._hidden_size, - _tw._inner_size, 0, 0, 0, _int8_ffn_in_buf, - _int8_p_d_enc_wei[_layer_id * 4 + 3], _cublas_lt_handle, - _stream, _sm_gt_eq_80); + if (_sm_gt_eq_80) { + cublaslt_gemm(_int8_p_d_enc_wei[_layer_id * 4 + 3], _int8_ffn_in_buf, + _int32_ffn_out_buf, 1, _tw._hidden_size, _batch_token_num, + _tw._inner_size, 0, 0, 0, 1, _cublas_lt_handle, _stream, + _algo_map); + } else { + cublasLtMM_withAlgo(_int32_ffn_out_buf, 1, _batch_token_num, + _tw._hidden_size, _tw._inner_size, 0, 0, 0, + _int8_ffn_in_buf, _int8_p_d_enc_wei[_layer_id * 4 + 3], + _cublas_lt_handle, _stream, _sm_gt_eq_80); + } const _DataType *scale_ptr, *bias_ptr, *res_bias_ptr; float clip_max, dequant_scale; @@ -414,7 +454,8 @@ void QuantVitEncoder::ffn_add_norm() { ker_residual_bias_ln_i32I_launcher<_DataType>( _int32_ffn_out_buf, scale_ptr, bias_ptr, _p_d_output, _p_d_output, _batch_token_num, _tw._hidden_size, dequant_scale, - _max_thread_per_block, _stream, true, _scaled_ffn2_colsum[_layer_id]); + _max_thread_per_block, _stream, !_sm_gt_eq_80, + _scaled_ffn2_colsum[_layer_id]); } else { scale_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer]; bias_ptr = _p_device_wei[(_layer_id + 1) * _tw._weight_per_enc_layer + 1]; @@ -426,7 +467,8 @@ void QuantVitEncoder::ffn_add_norm() { _int32_ffn_out_buf, scale_ptr, bias_ptr, res_bias_ptr, _int8_ffn_in_buf, _p_d_output, _batch_token_num, _tw._hidden_size, dequant_scale, _quant_range / clip_max, _max_thread_per_block, _stream, - _tw._is_post_ln, true, true, _scaled_ffn2_colsum[_layer_id]); + _tw._is_post_ln, !_sm_gt_eq_80, !_sm_gt_eq_80, + _scaled_ffn2_colsum[_layer_id]); #ifdef DEBUG_RESULT for (int i = 0; i < 
_batch_size; i++) { // batch_id diff --git a/lightseq/inference/model/quant_vit_encoder.h b/lightseq/inference/model/quant_vit_encoder.h index fb65b1f1..bb9f283e 100644 --- a/lightseq/inference/model/quant_vit_encoder.h +++ b/lightseq/inference/model/quant_vit_encoder.h @@ -15,6 +15,7 @@ #include "../proto/quant_vit_weight.h" #include "../tools/util.h" +#include "cublas_algo_map.h" /** @file @@ -46,6 +47,7 @@ class QuantVitEncoder { cudaStream_t _stream; cublasHandle_t _hd; cublasLtHandle_t _cublas_lt_handle; + cublasAlgoMap _algo_map; const bool _sm_gt_eq_80; const _DataType _fone; diff --git a/lightseq/inference/proto/quant_gpt.proto b/lightseq/inference/proto/quant_gpt.proto index ba2c63b7..1ba41664 100644 --- a/lightseq/inference/proto/quant_gpt.proto +++ b/lightseq/inference/proto/quant_gpt.proto @@ -62,6 +62,7 @@ message QuantGptModelConf { float topp = 4; int32 topk = 5; int32 eos_id = 6; + int32 extra_decode_length = 7; } message QuantGpt { diff --git a/lightseq/inference/proto/quant_gpt_weight.cc b/lightseq/inference/proto/quant_gpt_weight.cc index 9da21402..9fd28b4a 100644 --- a/lightseq/inference/proto/quant_gpt_weight.cc +++ b/lightseq/inference/proto/quant_gpt_weight.cc @@ -275,6 +275,15 @@ void QuantGptWeight::hdf5_get_model_config(hid_t hdf5_file) { _sampling_method = _sampling_method_read; } + int _extra_decode_length_read; + read_hdf5_dataset_scalar(hdf5_file, "model_conf/extra_decode_length", + H5T_NATIVE_INT, &_extra_decode_length_read); + if (_extra_decode_length_read > 0) { + _extra_decode_length = _extra_decode_length_read; + } else { + _extra_decode_length = _max_step; + } + int _topk_read; read_hdf5_dataset_scalar(hdf5_file, "model_conf/topk", H5T_NATIVE_INT, &_topk_read); diff --git a/lightseq/inference/proto/quant_gpt_weight.h b/lightseq/inference/proto/quant_gpt_weight.h index 748b566b..33fd3ae2 100644 --- a/lightseq/inference/proto/quant_gpt_weight.h +++ b/lightseq/inference/proto/quant_gpt_weight.h @@ -82,6 +82,7 @@ class QuantGptWeight 
{ int _hidden_size; int _inner_size; int _max_step; + int _extra_decode_length; int _src_vocab_size; int _n_enc_layer; // number of encoder layer int _dim_per_head;