From a39ab216aa624308fda7fa84439c6b61dc98b87a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 2 Oct 2024 15:49:55 +0200 Subject: [PATCH] llama : reduce compile time and binary size (#9712) * llama : speed up compile time * fix build * fix build (2) --- src/llama.cpp | 22 +++++++++++----------- src/unicode-data.cpp | 10 ++++++---- src/unicode-data.h | 8 ++++---- src/unicode.cpp | 21 ++++++++++++++------- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 4c0a1bb618277..af19e5c863399 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -610,7 +610,7 @@ enum llm_tensor { LLM_TENSOR_CLS_OUT, }; -static const std::map> LLM_TENSOR_NAMES = { +static const std::map> LLM_TENSOR_NAMES = { { LLM_ARCH_LLAMA, { @@ -1566,32 +1566,32 @@ struct LLM_TN { return LLM_TENSOR_NAMES.at(arch).at(tensor); } - std::string operator()(llm_tensor tensor, const std::string & suffix) const { + std::string operator()(llm_tensor tensor, const char * suffix) const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix; + return std::string(LLM_TENSOR_NAMES.at(arch).at(tensor)) + "." + suffix; } std::string operator()(llm_tensor tensor, int bid) const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid); + return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid); } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { + std::string operator()(llm_tensor tensor, const char * suffix, int bid) const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix; + return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid) + "." 
+ suffix; } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { + std::string operator()(llm_tensor tensor, const char * suffix, int bid, int xid) const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix; + return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor), bid, xid) + "." + suffix; } }; @@ -4916,7 +4916,7 @@ struct llama_model_loader { static const int TENSOR_NOT_REQUIRED = 1; static const int TENSOR_DUPLICATED = 2; - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, int flags = 0) { + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0) { const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { @@ -4926,7 +4926,7 @@ struct llama_model_loader { return create_tensor_for(ctx, cur, flags & TENSOR_DUPLICATED); } - struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::vector & ne, size_t offset, bool required = true) { + struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true) { const struct ggml_tensor * cur = check_tensor_dims(name, ne, required); if (cur == NULL) { @@ -4939,7 +4939,7 @@ struct llama_model_loader { std::array dims; for (size_t i = 0; i < GGML_MAX_DIMS; ++i) { - dims[i] = i < ne.size() ? ne[i] : 1; + dims[i] = i < ne.size() ? 
ne.begin()[i] : 1; } struct ggml_tensor * tensor = ggml_view_4d(ctx, base, diff --git a/src/unicode-data.cpp b/src/unicode-data.cpp index 02bdf782380fe..07424bbab54cc 100644 --- a/src/unicode-data.cpp +++ b/src/unicode-data.cpp @@ -7,7 +7,7 @@ #include #include -const std::vector> unicode_ranges_flags = { // start, flags // last=next_start-1 +const std::initializer_list> unicode_ranges_flags = { // start, flags // last=next_start-1 {0x000000, 0x0080}, {0x000020, 0x0008}, {0x000021, 0x0020}, @@ -2311,7 +2311,8 @@ const std::unordered_set unicode_set_whitespace = { 0x003000, }; -const std::unordered_map unicode_map_lowercase = { +// list is always in ascending order, to enable binary search +const std::initializer_list> unicode_map_lowercase = { {0x000041, 0x000061}, {0x000042, 0x000062}, {0x000043, 0x000063}, @@ -3747,7 +3748,8 @@ const std::unordered_map unicode_map_lowercase = { {0x01E921, 0x01E943}, }; -const std::unordered_map unicode_map_uppercase = { +// list is always in ascending order, to enable binary search +const std::initializer_list> unicode_map_uppercase = { {0x000061, 0x000041}, {0x000062, 0x000042}, {0x000063, 0x000043}, @@ -5200,7 +5202,7 @@ const std::unordered_map unicode_map_uppercase = { {0x01E943, 0x01E921}, }; -const std::vector unicode_ranges_nfd = { // start, last, nfd +const std::initializer_list unicode_ranges_nfd = { // start, last, nfd {0x000000, 0x000000, 0x000000}, {0x0000C0, 0x0000C5, 0x000041}, {0x0000C7, 0x0000C7, 0x000043}, diff --git a/src/unicode-data.h b/src/unicode-data.h index e27fe1770710a..f6973ebd2e350 100644 --- a/src/unicode-data.h +++ b/src/unicode-data.h @@ -13,8 +13,8 @@ struct range_nfd { static const uint32_t MAX_CODEPOINTS = 0x110000; -extern const std::vector> unicode_ranges_flags; +extern const std::initializer_list> unicode_ranges_flags; extern const std::unordered_set unicode_set_whitespace; -extern const std::unordered_map unicode_map_lowercase; -extern const std::unordered_map unicode_map_uppercase; -extern
const std::vector unicode_ranges_nfd; +extern const std::initializer_list> unicode_map_lowercase; +extern const std::initializer_list> unicode_map_uppercase; +extern const std::initializer_list unicode_ranges_nfd; diff --git a/src/unicode.cpp b/src/unicode.cpp index f4e941cd15261..50b35bbbc918c 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -123,11 +123,11 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) { static std::vector unicode_cpt_flags_array() { std::vector cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED); - assert (unicode_ranges_flags.front().first == 0); - assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS); + assert (unicode_ranges_flags.begin()[0].first == 0); + assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS); for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) { - const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags - const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags + const auto range_ini = unicode_ranges_flags.begin()[i-1]; // codepoint_ini, flags + const auto range_end = unicode_ranges_flags.begin()[i]; // codepoint_end, flags for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) { cpt_flags[cpt] = range_ini.second; } @@ -597,7 +597,7 @@ std::vector unicode_cpts_normalize_nfd(const std::vector & c std::vector result(cpts.size()); for (size_t i = 0; i < cpts.size(); ++i) { const uint32_t cpt = cpts[i]; - auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1; + auto it = std::upper_bound(unicode_ranges_nfd.begin(), unicode_ranges_nfd.end(), cpt, comp) - 1; result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt; } return result; @@ -639,8 +639,15 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) { } uint32_t unicode_tolower(uint32_t cp) { - auto it = unicode_map_lowercase.find(cp); - return it == unicode_map_lowercase.end() ? 
cp : it->second; + // binary search + auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp, + [](const std::pair & pair, uint32_t value) { + return pair.first < value; + }); + if (it != unicode_map_lowercase.end() && it->first == cp) { + return it->second; + } + return cp; // Return the original code point if no lowercase mapping is found } std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) {