From 29eaec985c3fb1062a2e8201ced69fdf218e5135 Mon Sep 17 00:00:00 2001 From: luoyao Date: Thu, 26 Dec 2024 20:38:40 +0800 Subject: [PATCH] update llama model --- src/models/llm/llama/llama3.inl | 377 ++++++++++++++++++++++++++------ 1 file changed, 308 insertions(+), 69 deletions(-) diff --git a/src/models/llm/llama/llama3.inl b/src/models/llm/llama/llama3.inl index e7ae8cf..54f094b 100644 --- a/src/models/llm/llama/llama3.inl +++ b/src/models/llm/llama/llama3.inl @@ -87,9 +87,19 @@ public: * */ ~Impl() { - llama_sampler_free(_m_sampler); - llama_free(_m_ctx); - llama_free_model(_m_model); + if (nullptr != _m_smpl_chain) { + llama_sampler_free(_m_smpl_chain); + _m_smpl_chain = nullptr; + } + if (nullptr != _m_ctx) { + llama_free(_m_ctx); + _m_ctx = nullptr; + } + if (nullptr != _m_model) { + llama_free_model(_m_model); + _m_model = nullptr; + } + llama_backend_free(); } /*** @@ -200,14 +210,21 @@ private: // llama model context llama_context* _m_ctx = nullptr; // llama sampler params - llama_sampler_chain_params _m_sampler_params = llama_sampler_chain_default_params(); + common_params_sampling _m_smpl_params{}; // llama sampler - llama_sampler* _m_sampler = nullptr; + llama_sampler* _m_smpl_chain = nullptr; + llama_sampler* _m_smpl_grmr = nullptr; // init flag bool _m_successfully_initialized = false; private: + /*** + * + * @param need_grama + */ + StatusCode init_sampler(); + /*** * * @param prompt_tokens @@ -215,6 +232,15 @@ private: * @return */ StatusCode llama_generate(std::vector& prompt_tokens, std::string& generate_out); + + /*** + * + * @param idx + * @param out_sampled_token + * @param grammar_first + * @return + */ + StatusCode llama_sample(int idx, llama_token& out_sampled_token, bool grammar_first=false); }; /*** @@ -245,12 +271,22 @@ StatusCode Llama3::Impl::init(const toml::value& config) { return StatusCode::MODEL_INIT_FAILED; } - // init model + // init llama backend + llama_backend_init(); + ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; + llama_numa_init(numa); + + // load llama model auto n_gpu_layers = static_cast(model_cfg.at("n_gpu_layers").as_integer()); auto main_gpu_device_id = static_cast(model_cfg.at("main_gpu_device").as_integer()); - _m_model_params.n_gpu_layers = n_gpu_layers; + _m_model_params.devices = nullptr; // all available devices + _m_model_params.n_gpu_layers = n_gpu_layers; // number of layers to store in VRAM _m_model_params.main_gpu = main_gpu_device_id; + _m_model_params.split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split model cross gpus _m_model_params.vocab_only = false; + _m_model_params.use_mmap = true; // use mmap for faster loads + _m_model_params.use_mlock = false; // use mlock to keep model in memory + _m_model_params.check_tensors = false; if (model_cfg.contains("vocab_only")) { _m_model_params.vocab_only = model_cfg.at("vocab_only").as_boolean(); } @@ -260,39 +296,50 @@ StatusCode Llama3::Impl::init(const toml::value& config) { return StatusCode::MODEL_INIT_FAILED; } - if (!_m_model_params.vocab_only) { - // init sampler - _m_sampler = llama_sampler_chain_init(_m_sampler_params); - if (_m_sampler == nullptr) { - LOG(ERROR) << "failed to create the llama sampler"; - return StatusCode::MODEL_INIT_FAILED; - } - auto temp = static_cast(model_cfg.at("sampler_temp").as_floating()); - auto init_min_p = 0.05f; - auto min_keep = 1; - llama_sampler_chain_add(_m_sampler, llama_sampler_init_min_p(init_min_p, min_keep)); - llama_sampler_chain_add(_m_sampler, llama_sampler_init_temp(temp)); - llama_sampler_chain_add(_m_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); - - // init ctx params - if (!config.contains("CONTEXT")) { - LOG(ERROR) << "Config file does not contain CONTEXT section"; - _m_successfully_initialized = false; - return StatusCode::MODEL_INIT_FAILED; - } - toml::value ctx_cfg = config.at("CONTEXT"); - auto ctx_size = llama_n_ctx_train(_m_model); - if (ctx_cfg.contains("context_size")) { - ctx_size = static_cast(ctx_cfg.at("context_size").as_integer()); - } - _m_ctx_params.n_ctx = ctx_size; - _m_ctx_params.n_batch = ctx_size; - _m_ctx = llama_new_context_with_model(_m_model, _m_ctx_params); - if (_m_ctx == nullptr) { - LOG(ERROR) << "failed to create the llama_context"; - return StatusCode::MODEL_INIT_FAILED; - } + // init llama model ctx + if (!config.contains("CONTEXT")) { + LOG(ERROR) << "Config file does not contain CONTEXT section"; + _m_successfully_initialized = false; + return StatusCode::MODEL_INIT_FAILED; + } + toml::value ctx_cfg = config.at("CONTEXT"); + auto ctx_size = llama_n_ctx_train(_m_model); + if (ctx_cfg.contains("context_size")) { + ctx_size = static_cast(ctx_cfg.at("context_size").as_integer()); + } + _m_ctx_params.n_ctx = ctx_size <= llama_n_ctx_train(_m_model) ? ctx_size : llama_n_ctx_train(_m_model); // context size + _m_ctx_params.n_batch = _m_ctx_params.n_ctx / 2; // logical batch size for prompt processing (must be >=32 to use BLAS) + _m_ctx_params.n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS) + _m_ctx_params.logits_all = false; // return logits for all tokens in the batch + _m_ctx_params.embeddings = false; // get only sentence embedding + _m_ctx_params.flash_attn = false; // flash attention + _m_ctx_params.no_perf = true; // no performance metrics + _m_ctx_params.offload_kqv = true; // disable KV offloading + if (_m_model_params.vocab_only) { + _m_ctx_params = llama_context_default_params(); + } + _m_ctx = llama_new_context_with_model(_m_model, _m_ctx_params); + if (_m_ctx == nullptr) { + LOG(ERROR) << "failed to create the llama_context"; + return StatusCode::MODEL_INIT_FAILED; + } + + // init sampler + auto smpl_cfg = config.at("SAMPLER"); + _m_smpl_params.min_keep = static_cast(smpl_cfg.at("min_keep").as_integer()); + _m_smpl_params.top_k = static_cast(smpl_cfg.at("top_k").as_integer()); + _m_smpl_params.top_p = static_cast(smpl_cfg.at("top_p").as_floating()); + _m_smpl_params.min_p = static_cast(smpl_cfg.at("min_p").as_floating()); + _m_smpl_params.temp = static_cast(smpl_cfg.at("temp").as_floating()); + _m_smpl_params.no_perf = smpl_cfg.at("no_perf").as_boolean(); + init_sampler(); + + std::string result = "logits "; + for (int i = 0; i < llama_sampler_chain_n(_m_smpl_chain); i++) { + const auto * smpl = llama_sampler_chain_get(_m_smpl_chain, i); + result += std::string("-> ") + llama_sampler_name(smpl) + " "; } + LOG(INFO) << result; _m_successfully_initialized = true; return StatusCode::OK; @@ -329,7 +376,7 @@ StatusCode Llama3::Impl::run(const INPUT& in, OUTPUT& out) { out = llama_impl::transform_output(generate_out); return status; - } else if constexpr(std::is_same&>::value) { + } else if constexpr(std::is_same& >::value) { // run llama3 generate std::string generate_out; auto status = llama_generate(in, generate_out); @@ -511,7 +558,19 @@ StatusCode Llama3::Impl::get_embedding( */ template StatusCode Llama3::Impl::text_completion(const std::string &prompt, std::string &generate_output) { - return run(prompt, generate_output); + if constexpr(std::is_same::value) { + return run(prompt, generate_output); + } else if constexpr(std::is_same&>::value) { + std::vector tokens; + auto status = tokenize(prompt, tokens, true); + if (status != StatusCode::OK) { + return status; + } + return run(tokens, generate_output); + } else { + LOG(ERROR) << "wrong input data type"; + return StatusCode::MODEL_RUN_SESSION_FAILED; + } } /*** @@ -526,24 +585,28 @@ template StatusCode Llama3::Impl::chat_completion(Dialog &dialog, std::string &generate_output) { // template format dialog std::string fmt_prompt; - auto status = apply_chat_template(dialog, false, fmt_prompt); + bool add_ass = dialog.messages.back().role == "user"; + auto status = apply_chat_template(dialog, add_ass, fmt_prompt); if (status != StatusCode::OK) { LOG(ERROR) << "apply chat template for dialog failed, status code: " << status; return status; } - // tokenize prompts - std::vector prompt_tokens; - status = tokenize(fmt_prompt, prompt_tokens); - if (status != StatusCode::OK) { - LOG(ERROR) << "tokenize dialog failed, status code: " << status; - return status; + if constexpr(std::is_same::value) { + return run(fmt_prompt, generate_output); + } else if constexpr(std::is_same&>::value) { + // tokenize prompts + std::vector prompt_tokens; + status = tokenize(fmt_prompt, prompt_tokens, true); + if (status != StatusCode::OK) { + LOG(ERROR) << "tokenize dialog failed, status code: " << status; + return status; + } + return run(prompt_tokens, generate_output); + } else { + LOG(ERROR) << "wrong input data type"; + return StatusCode::MODEL_RUN_SESSION_FAILED; } - - // chat completion - status = run(prompt_tokens, generate_output); - - return status; } /*** @@ -559,21 +622,23 @@ template StatusCode Llama3::Impl::apply_chat_template(const Dialog &dialog, bool add_ass, std::string &out_formatted_str) { // allocate string buffer int32_t alloc_size = 0; - bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; + bool fallback = false; // indicate if we must fallback to default chatml for (auto& msg : dialog.messages) { - alloc_size += static_cast(static_cast((std::strlen(msg.role) + std::strlen(msg.content))) * 1.25); + chat.push_back({msg.role.c_str(), msg.content.c_str()}); + alloc_size += static_cast(static_cast(msg.role.size() + msg.content.size()) * 1.25f); } std::vector buf(alloc_size); // run the first time to get the total output length int32_t res = llama_chat_apply_template( - _m_model, nullptr, dialog.messages.data(), dialog.size(), add_ass, buf.data(), alloc_size); + _m_model, nullptr, chat.data(), chat.size(), add_ass, buf.data(), static_cast(buf.size())); // error: chat template is not supported if (res < 0) { LOG(WARNING) << "failed to apply model's default chat template. Will try again with chatml template"; - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), alloc_size); + res = llama_chat_apply_template( + nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), static_cast(buf.size())); fallback = true; if (res < 0) { LOG(ERROR) << "failed to apply default chatml template"; @@ -587,7 +652,7 @@ StatusCode Llama3::Impl::apply_chat_template(const Dialog &dialog res = llama_chat_apply_template( fallback ? nullptr : _m_model, fallback ? "chatml" : nullptr, - chat.data(), chat.size(), add_ass, buf.data(), alloc_size); + chat.data(), chat.size(), add_ass, buf.data(), static_cast(buf.size())); } if (res < 0) { LOG(ERROR) << "failed to apply default chatml template"; @@ -624,6 +689,97 @@ void Llama3::Impl::clear_kv_cache_cell() const { llama_kv_cache_clear(_m_ctx); } +/*** + * + * @tparam INPUT + * @tparam OUTPUT + * @param need_grama + */ +template +StatusCode Llama3::Impl::init_sampler() { + auto lsmpl_params = llama_sampler_chain_default_params(); + lsmpl_params.no_perf = _m_smpl_params.no_perf; + _m_smpl_chain = llama_sampler_chain_init(lsmpl_params); + + // add sampler to chain + llama_sampler_chain_add( + _m_smpl_chain, + llama_sampler_init_logit_bias( + llama_n_vocab(_m_model), + _m_smpl_params.logit_bias.size(), + _m_smpl_params.logit_bias.data() + ) + ); + llama_sampler_chain_add( + _m_smpl_chain, + llama_sampler_init_penalties( + llama_n_vocab (_m_model), + llama_token_eos(_m_model), + llama_token_nl (_m_model), + _m_smpl_params.penalty_last_n, + _m_smpl_params.penalty_repeat, + _m_smpl_params.penalty_freq, + _m_smpl_params.penalty_present, + _m_smpl_params.penalize_nl, + _m_smpl_params.ignore_eos + ) + ); + auto& params = _m_smpl_params; + if (params.mirostat == 0) { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto& str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_dry(_m_model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_top_k(params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_top_p(params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_min_p(params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_typical(params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_temp_ext(params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_infill(_m_model)); + break; + default: + LOG(WARNING) << fmt::format("unknown sampler type: {}", static_cast(cnstr)); + break; + } + } + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_dist(params.seed)); + } else if (params.mirostat == 1) { + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_mirostat(llama_n_vocab(_m_model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + } else if (params.mirostat == 2) { + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(_m_smpl_chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); + } else { + LOG(ERROR) << "unknown mirostat version"; + return StatusCode::MODEL_INIT_FAILED; + } + _m_smpl_grmr = llama_sampler_init_grammar(_m_model, params.grammar.c_str(), "root"); + + return StatusCode::OK; +} + /*** * * @tparam INPUT @@ -645,6 +801,7 @@ StatusCode Llama3::Impl::llama_generate(std::vector& return StatusCode::LLM_CONTEXT_SIZE_EXCEEDED; } + // run decoder model auto status = llama_decode(_m_ctx, batch); if (status == 1) { LOG(WARNING) << "llama generate failed. could not find a KV slot for the batch " @@ -654,22 +811,33 @@ StatusCode Llama3::Impl::llama_generate(std::vector& return StatusCode::MODEL_RUN_SESSION_FAILED; } - // sample the next token - new_token_id = llama_sampler_sample(_m_sampler, _m_ctx, -1); - - // is it an end of generation? + // sample from model output logits + StatusCode sample_status = llama_sample(-1, new_token_id, false); + if (sample_status != StatusCode::OK) { + LOG(ERROR) << "llama sample failed"; + return StatusCode::MODEL_RUN_SESSION_FAILED; + } if (llama_token_is_eog(_m_model, new_token_id)) { break; } - // convert the token to a string, print it and add it to the response - char buf[256]; - int n = llama_token_to_piece(_m_model, new_token_id, buf, sizeof(buf), 0, true); - if (n < 0) { - LOG(ERROR) << "failed to convert token to piece"; - return StatusCode::MODEL_RUN_SESSION_FAILED; + // convert token to output string + std::string piece; + piece.resize(piece.capacity()); + bool enable_special_token_output = false; + auto n_chars = llama_token_to_piece( + _m_model, new_token_id, &piece[0], static_cast(piece.size()), 0, enable_special_token_output); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece( + _m_model, new_token_id, &piece[0], static_cast(piece.size()), 0, enable_special_token_output); + if (check != -n_chars) { + LOG(ERROR) << fmt::format("decode token to string failed, check nums: {}, n_chars: {}", check, -n_chars); + return StatusCode::MODEL_RUN_SESSION_FAILED; + } + } else { + piece.resize(n_chars); } - std::string piece(buf, n); generate_out += piece; // prepare the next batch with the sampled token @@ -679,6 +847,77 @@ StatusCode Llama3::Impl::llama_generate(std::vector& return StatusCode::OK; } +/*** + * + * @tparam INPUT + * @tparam OUTPUT + * @param idx + * @param out_sampled_token + * @param grammar_first + * @return + */ +template +StatusCode Llama3::Impl::llama_sample(int idx, llama_token &out_sampled_token, bool grammar_first) { + // get logits + std::vector cur; + llama_token_data_array cur_p; + auto* logits = llama_get_logits_ith(_m_ctx, idx); + int n_vocab = llama_n_vocab(llama_get_model(_m_ctx)); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + cur_p = { cur.data(), cur.size(), -1, false }; + + // chain sample + if (grammar_first && _m_smpl_grmr != nullptr) { + llama_sampler_apply(_m_smpl_grmr, &cur_p); + } + llama_sampler_apply(_m_smpl_chain, &cur_p); + if (cur_p.selected == -1) { + LOG(ERROR) << "no selected token during sampling - check your sampling configuration"; + return StatusCode::MODEL_RUN_SESSION_FAILED; + } + const llama_token id = cur_p.data[cur_p.selected].id; + if (grammar_first && _m_smpl_grmr != nullptr) { + out_sampled_token = id; + return StatusCode::OK; + } + + // check if sampled token fits the grammar + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; + llama_sampler_apply(_m_smpl_grmr, &single_token_data_array); + bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + out_sampled_token = id; + return StatusCode::OK; + } + + // resampling: + // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain + logits = llama_get_logits_ith(_m_ctx, idx); + n_vocab = llama_n_vocab(llama_get_model(_m_ctx)); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + cur_p = { cur.data(), cur.size(), -1, false }; + llama_sampler_apply(_m_smpl_grmr, &cur_p); + llama_sampler_apply(_m_smpl_chain, &cur_p); + if (cur_p.selected == -1) { + LOG(ERROR) << "no selected token during sampling - check your sampling configuration"; + return StatusCode::MODEL_RUN_SESSION_FAILED; + } + out_sampled_token = cur_p.data[cur_p.selected].id; + + // sampler accept + llama_sampler_accept(_m_smpl_grmr, out_sampled_token); + llama_sampler_accept(_m_smpl_chain, out_sampled_token); + + return StatusCode::OK; +} + /************* Export Function Sets *************/ /***