diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5526e075166b5..80cdbae29431a 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -57,6 +57,16 @@ struct ring_buffer {
         return value;
     }
 
+    T pop_back() {
+        if (sz == 0) {
+            throw std::runtime_error("ring buffer is empty");
+        }
+        pos = (pos + capacity - 1) % capacity;
+        T value = data[pos];
+        sz--;
+        return value;
+    }
+
     const T & rat(size_t i) const {
         if (i >= sz) {
             throw std::runtime_error("ring buffer: index out of bounds");
@@ -165,6 +175,12 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
                     params.penalty_repeat,
                     params.penalty_freq,
                     params.penalty_present,
+                    params.dry_penalty_last_n,
+                    params.dry_base,
+                    params.dry_multiplier,
+                    params.dry_allowed_length,
+                    params.dry_seq_breakers.data(),
+                    params.dry_seq_breakers.size(),
                     params.penalize_nl,
                     params.ignore_eos));
 
@@ -239,6 +255,19 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
     llama_sampler_reset(gsmpl->chain);
 }
 
+void gpt_sampler_reset_grmr(struct gpt_sampler * gsmpl) {
+    llama_sampler_reset(gsmpl->grmr);
+}
+
+void gpt_sampler_reinit_grmr(struct gpt_sampler * gsmpl, const struct llama_model * model, std::string grammar) {
+    // free the old grammar sampler first
+    llama_sampler_free(gsmpl->grmr);
+
+    // reinitialize it with the new grammar
+    gsmpl->params.grammar = grammar;
+    gsmpl->grmr = llama_sampler_init_grammar(model, grammar.c_str(), "root");
+}
+
 struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
     return new gpt_sampler {
         /* .params = */ gsmpl->params,
@@ -313,6 +342,10 @@ llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl)
     return &gsmpl->cur_p;
 }
 
+std::vector<llama_token> gpt_sampler_get_prev(struct gpt_sampler * gsmpl) {
+    return gsmpl->prev.to_vector();
+}
+
 llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
     return gsmpl->prev.rat(0);
 }
@@ -440,4 +473,17 @@ std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & c
     }
 
     return samplers;
-}
\ No newline at end of file
+}
+
+void gpt_sampler_rollback(
+        gpt_sampler * gsmpl,
+        int rollback_num) {
+    // clamp to the number of tokens actually stored
+    if (rollback_num > (int) gsmpl->prev.size()) {
+        rollback_num = (int) gsmpl->prev.size();
+    }
+
+    // pop the last token repeatedly
+    for (int i = 0; i < rollback_num; i++) {
+        gsmpl->prev.pop_back();
+    }
+}
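Note that gpt_sampler_rollback only rewinds the sampler's ring buffer of accepted tokens; rewinding the KV cache remains the caller's job. A minimal usage sketch, not part of the patch (speculative_step and the draft/reject counts are hypothetical; gpt_sampler_sample and gpt_sampler_accept are the existing helpers declared in common/sampling.h):

```cpp
#include "sampling.h"

// hypothetical caller: sample a few draft tokens, then undo the ones that
// were rejected so the penalty samplers no longer "remember" them
void speculative_step(struct gpt_sampler * gsmpl, struct llama_context * ctx) {
    // sample and accept 4 draft tokens
    for (int i = 0; i < 4; i++) {
        const llama_token id = gpt_sampler_sample(gsmpl, ctx, /* idx = */ -1);
        gpt_sampler_accept(gsmpl, id, /* accept_grammar = */ false);
    }

    // suppose verification rejected the last 2 draft tokens: drop them from
    // the sampler's token history (the KV cache must be rewound separately)
    gpt_sampler_rollback(gsmpl, 2);
}
```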
diff --git a/common/sampling.h b/common/sampling.h
index d67d818ce3666..6fb1d60e808fe 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -34,6 +34,11 @@ struct gpt_sampler_params {
     float penalty_repeat = 1.00f; // 1.0 = disabled
     float penalty_freq = 0.00f; // 0.0 = disabled
     float penalty_present = 0.00f; // 0.0 = disabled
+    float dry_multiplier = 0.00f; // 0.0 = disabled, recommended value: 0.8f
+    float dry_base = 1.75f;
+    uint32_t dry_allowed_length = 2;
+    std::vector<llama_token> dry_seq_breakers;
+    int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
     int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
@@ -93,6 +98,8 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl);
 // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
 void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
 void gpt_sampler_reset (struct gpt_sampler * gsmpl);
+void gpt_sampler_reset_grmr(struct gpt_sampler * gsmpl);
+void gpt_sampler_reinit_grmr(struct gpt_sampler * gsmpl, const struct llama_model * model, std::string grammar);
 
 struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
 
 // arguments can be nullptr to skip printing
@@ -114,6 +121,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
 
 // access the internal list of current candidate tokens
 llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
+std::vector<llama_token> gpt_sampler_get_prev(struct gpt_sampler * gsmpl);
 
 // get the last accepted token
 llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
@@ -128,4 +136,5 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
 std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
 
 std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
-std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
\ No newline at end of file
+std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
+void gpt_sampler_rollback(gpt_sampler * gsmpl, int rollback_num);
\ No newline at end of file
diff --git a/include/llama.h b/include/llama.h
index eb5e236db426e..e6d594ec1a5b3 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1130,6 +1130,12 @@ extern "C" {
                              float   penalty_repeat,      // 1.0 = disabled
                              float   penalty_freq,        // 0.0 = disabled
                              float   penalty_present,     // 0.0 = disabled
+                           int32_t   dry_penalty_last_n,  // DRY last n tokens to penalize (0 = disable penalty, -1 = context size)
+                             float   dry_base,
+                             float   dry_multiplier,      // 0.0 = disabled
+                          uint32_t   dry_allowed_length,
+               const llama_token *   dry_seq_breakers,
+                            size_t   dry_seq_breakers_size,
                               bool   penalize_nl,         // consider newlines as a repeatable token
                               bool   ignore_eos);         // ignore the end-of-sequence token
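The new llama.h parameters feed the penalty formula implemented in src/llama-sampling.cpp below: a candidate token that would extend a repeated sequence loses dry_multiplier * dry_base^(match_length - dry_allowed_length) from its logit, so the penalty grows exponentially with the length of the repeat. A standalone sketch of that curve under the defaults recommended above (multiplier 0.8, base 1.75, allowed length 2), just to make the numbers concrete:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // recommended defaults from common/sampling.h above
    const float dry_multiplier     = 0.80f;
    const float dry_base           = 1.75f;
    const int   dry_allowed_length = 2;

    // penalty = multiplier * base^(match_length - allowed_length)
    for (int match_length = 2; match_length <= 6; match_length++) {
        const float penalty = dry_multiplier * std::pow(dry_base, (float) (match_length - dry_allowed_length));
        printf("match length %d -> logit penalty %.2f\n", match_length, penalty);
    }
    // -> 0.80, 1.40, 2.45, 4.29, 7.50
    return 0;
}
```

With these values a 2-token repeat costs only 0.8 logits while a 6-token repeat already costs 7.5, which is what lets DRY suppress long verbatim loops while mostly ignoring short n-gram reuse.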
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index c8893962686f0..0cd208d897eca 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -11,6 +11,7 @@
 #include <cmath>
 #include <numeric>
 #include <random>
+#include <unordered_map>
 
 static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float> & probs) {
     probs.resize(cur_p->size);
@@ -433,6 +434,104 @@ void llama_sampler_penalties_impl(
     cur_p->sorted = false;
 }
 
+void llama_sampler_dry_impl(
+        llama_token_data_array * candidates,
+        const llama_token * last_tokens,
+        size_t last_tokens_size,
+        float dry_base,
+        float dry_multiplier,
+        int dry_allowed_length,
+        const llama_token * dry_seq_breakers,
+        size_t dry_seq_breakers_size) {
+    // skip the DRY sampler if we don't have a previous token
+    if (last_tokens_size < 1) return;
+
+    // get the last token
+    auto last_token = last_tokens[last_tokens_size - 1];
+
+    // if the last token is a sequence breaker, skip the whole sampler
+    if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) {
+        return;
+    }
+
+    // map each "next token" to the maximum length of the match preceding it
+    std::unordered_map<llama_token, size_t> match_lengths;
+
+    // loop through each previous token (excluding the last token)
+    for (size_t i = 0; i < last_tokens_size - 1; ++i) {
+        // skip if this token is not the same as the last token
+        if (last_tokens[i] != last_token) {
+            continue;
+        }
+
+        // get the next token (i + 1 is always less than last_tokens_size)
+        auto next_token = last_tokens[i + 1];
+
+        // if the next token is a sequence breaker, skip it
+        if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) {
+            continue;
+        }
+
+        // try to extend the match backwards (match length starts at 1 because the last token is already matched)
+        size_t match_length = 1;
+
+        // loop through the previous tokens
+        for (;; match_length++) {
+            // if we have reached the start of our last tokens, break
+            if (i < match_length) break;
+
+            // the compared token starts at our previous index, going backwards by the match length
+            auto compare_token = last_tokens[i - match_length];
+
+            // the head token starts at the end of the last tokens, going backwards by the match length, minus 1 because we start at the last token itself
+            auto head_token = last_tokens[last_tokens_size - 1 - match_length];
+
+            // break out of the match if the tokens don't match
+            if (compare_token != head_token) {
+                break;
+            }
+
+            // if the compared token is a sequence breaker, break out of the match
+            if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) {
+                break;
+            }
+        }
+
+        // check if the next token already exists in the map
+        auto it = match_lengths.find(next_token);
+
+        if (it == match_lengths.end()) {
+            // key does not exist - insert the new value
+            match_lengths[next_token] = match_length;
+        } else {
+            // key exists - keep the maximum of the new and the existing value
+            it->second = std::max(it->second, match_length);
+        }
+    }
+
+    // apply the penalties
+    for (const auto & pair : match_lengths) {
+        auto next_token   = pair.first;
+        auto match_length = pair.second;
+
+        // if the match length is greater than or equal to the configured allowed length, apply the penalty
+        if (match_length >= (size_t) dry_allowed_length) {
+            // find the next token in candidates->data
+            for (size_t i = 0; i < candidates->size; ++i) {
+                if (candidates->data[i].id == next_token) {
+                    // calculate the penalty
+                    float penalty = dry_multiplier * std::pow(dry_base, (float) (match_length - dry_allowed_length));
+
+                    // apply the DRY penalty
+                    candidates->data[i].logit -= penalty;
+                    break;
+                }
+            }
+        }
+    }
+}
+
 // llama_sampler API
 
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
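To see what the backward-matching loop actually computes, here is a self-contained toy harness (an illustration, not part of the patch: plain ints stand in for llama_token and the sequence-breaker checks are omitted) that mirrors the bookkeeping above on the history "A B C A B C A B":

```cpp
#include <algorithm>
#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
    // token history "A B C A B C A B" with A=0, B=1, C=2;
    // the model is about to repeat "C" for the third time
    const std::vector<int> last = {0, 1, 2, 0, 1, 2, 0, 1};
    const int last_token = last.back();

    std::unordered_map<int, size_t> match_lengths;

    // same bookkeeping as llama_sampler_dry_impl, minus the sequence breakers
    for (size_t i = 0; i + 1 < last.size(); ++i) {
        if (last[i] != last_token) {
            continue;
        }

        const int next_token = last[i + 1];

        // extend the match backwards from position i
        size_t match_length = 1;
        for (;; match_length++) {
            if (i < match_length) break;
            if (last[i - match_length] != last[last.size() - 1 - match_length]) break;
        }

        // keep the maximum match length seen for this next token
        auto it = match_lengths.find(next_token);
        if (it == match_lengths.end()) {
            match_lengths[next_token] = match_length;
        } else {
            it->second = std::max(it->second, match_length);
        }
    }

    for (const auto & p : match_lengths) {
        // prints: next token 2 -> max match length 5
        printf("next token %d -> max match length %zu\n", p.first, p.second);
    }
    return 0;
}
```

With dry_allowed_length = 2, only token C crosses the threshold here, so its logit would be reduced by 0.8 * 1.75^(5-2) ≈ 4.29 under the recommended defaults.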
@@ -1216,6 +1315,12 @@ struct llama_sampler_penalties {
     const float penalty_freq;
     const float penalty_present;
 
+    const int32_t dry_penalty_last_n;
+    const float dry_base;
+    const float dry_multiplier;
+    const uint32_t dry_allowed_length;
+    std::vector<llama_token> dry_seq_breakers;
+
     const bool penalize_nl;
     const bool ignore_eos;
 
@@ -1286,8 +1391,20 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
             token_count[ctx->prev.rat(i)]++;
         }
 
+        // apply the repetition, frequency, and presence penalties
         llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present);
 
+        // turn the ring buffer of last tokens into a vector
+        auto last_tokens = ctx->prev.to_vector();
+
+        // keep only the last dry_penalty_last_n tokens
+        if (last_tokens.size() > (size_t) ctx->dry_penalty_last_n) {
+            last_tokens.erase(last_tokens.begin(), last_tokens.end() - ctx->dry_penalty_last_n);
+        }
+
+        // apply the DRY penalty
+        llama_sampler_dry_impl(cur_p, last_tokens.data(), last_tokens.size(), ctx->dry_base, ctx->dry_multiplier, ctx->dry_allowed_length, ctx->dry_seq_breakers.data(), ctx->dry_seq_breakers.size());
+
         if (!ctx->penalize_nl && nl_found) {
             // restore the logit of the newline token if it was penalized
             cur_p->data[nl_idx].logit = nl_logit;
@@ -1307,6 +1424,12 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
                 ctx->penalty_repeat,
                 ctx->penalty_freq,
                 ctx->penalty_present,
+                ctx->dry_penalty_last_n,
+                ctx->dry_base,
+                ctx->dry_multiplier,
+                ctx->dry_allowed_length,
+                ctx->dry_seq_breakers.data(),
+                ctx->dry_seq_breakers.size(),
                 ctx->penalize_nl,
                 ctx->ignore_eos);
@@ -1332,6 +1455,12 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_repeat,
         float penalty_freq,
         float penalty_present,
+        int32_t dry_penalty_last_n,
+        float dry_base,
+        float dry_multiplier,
+        uint32_t dry_allowed_length,
+        const llama_token * dry_seq_breakers,
+        size_t dry_seq_breakers_size,
         bool penalize_nl,
         bool ignore_eos) {
     if (linefeed_id == LLAMA_TOKEN_NULL) {
@@ -1352,6 +1481,11 @@ struct llama_sampler * llama_sampler_init_penalties(
         /* .penalty_repeat     = */ penalty_repeat,
         /* .penalty_freq       = */ penalty_freq,
         /* .penalty_present    = */ penalty_present,
+        /* .dry_penalty_last_n = */ dry_penalty_last_n,
+        /* .dry_base           = */ dry_base,
+        /* .dry_multiplier     = */ dry_multiplier,
+        /* .dry_allowed_length = */ dry_allowed_length,
+        /* .dry_seq_breakers   = */ std::vector<llama_token>(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size),
         /* .penalize_nl        = */ penalize_nl,
         /* .ignore_eos         = */ ignore_eos,
         /* .prev               = */ ring_buffer<llama_token>(penalty_last_n),
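Finally, a hedged sketch of wiring the new parameters through the extended C API. It assumes the four pre-existing leading arguments of llama_sampler_init_penalties (n_vocab, special EOS id, linefeed id, penalty_last_n), which this patch leaves unchanged; the breaker list is left empty because real sequence-breaker token ids depend on the vocabulary:

```cpp
#include "llama.h"

#include <vector>

// sketch only: builds a penalties sampler with repetition penalties disabled
// and DRY enabled at the recommended defaults
struct llama_sampler * make_penalties_with_dry(const struct llama_model * model) {
    // assumption: real sequence breakers would be vocabulary-specific token
    // ids (e.g. newline or punctuation); left empty here
    const std::vector<llama_token> seq_breakers;

    return llama_sampler_init_penalties(
            llama_n_vocab(model),
            llama_token_eos(model),
            llama_token_nl(model),
            64,      // penalty_last_n
            1.0f,    // penalty_repeat  (disabled)
            0.0f,    // penalty_freq    (disabled)
            0.0f,    // penalty_present (disabled)
            -1,      // dry_penalty_last_n (-1 = context size)
            1.75f,   // dry_base
            0.8f,    // dry_multiplier (0.0 would disable DRY)
            2,       // dry_allowed_length
            seq_breakers.data(),
            seq_breakers.size(),
            false,   // penalize_nl
            false);  // ignore_eos
}
```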