
Commit

update datatype
MaybeShewill-CV committed Dec 25, 2024
1 parent 9cd4a04 commit 0902577
Showing 1 changed file with 77 additions and 6 deletions.
83 changes: 77 additions & 6 deletions src/models/llm/llm_datatype.hpp
@@ -23,6 +23,11 @@ struct ModelStatus {
int32_t embed_dims;
};

struct ChatMessage {
std::string role;
std::string content;
};

class Dialog {
public:
/***
@@ -52,15 +57,15 @@ class Dialog {
*
* @param msg
*/
-explicit Dialog(const llama_chat_message &msg) { messages.push_back(msg); }
+explicit Dialog(const ChatMessage &msg) { messages.push_back(msg); }

/***
*
* @param role
* @param content
*/
Dialog(const std::string &role, const std::string &content) {
-llama_chat_message msg = {role.c_str(), content.c_str()};
+ChatMessage msg = {role, content};
messages.push_back(msg);
}
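
The substance of the change so far: llama.cpp's llama_chat_message only holds const char * views, so the old constructor stored raw pointers into the caller's strings, which dangled as soon as those strings were destroyed; the owning ChatMessage (std::string members) removes that hazard. When a dialog eventually has to be handed back to a llama.cpp API such as llama_chat_apply_template, a thin conversion suffices. A minimal sketch, assuming a to_llama helper that is not part of this header:

#include <llama.h>            // llama_chat_message: { const char *role; const char *content; }
#include "llm_datatype.hpp"   // this header, for ChatMessage

// Hypothetical helper: the returned views borrow from msg and must not outlive it.
inline llama_chat_message to_llama(const ChatMessage &msg) {
    return llama_chat_message{msg.role.c_str(), msg.content.c_str()};
}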

@@ -70,7 +75,7 @@
* @param content
*/
Dialog(const char* role, const char* content) {
-llama_chat_message msg = {role, content};
+ChatMessage msg = {role, content};
messages.push_back(msg);
}

@@ -100,13 +105,13 @@ class Dialog {
* @param index
* @return
*/
-inline llama_chat_message &operator[](size_t index) { return messages[index]; }
+inline ChatMessage &operator[](size_t index) { return messages[index]; }

/***
*
* @param msg
*/
-inline void push_back(const llama_chat_message &msg) { messages.push_back(msg); }
+inline void push_back(const ChatMessage &msg) { messages.push_back(msg); }

/***
*
@@ -126,9 +131,75 @@
inline size_t size() const { return messages.size(); }

public:
-std::vector<llama_chat_message> messages;
+std::vector<ChatMessage> messages;
};
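
With the class now storing owning ChatMessage values throughout, basic usage might look like the following sketch (illustrative only; it relies on nothing beyond the constructors, push_back, operator[], and size() declared above):

#include <cstdio>
#include "llm_datatype.hpp"   // this header, for Dialog and ChatMessage

int main() {
    Dialog dialog("system", "You are a helpful assistant.");
    dialog.push_back({"user", "Hello!"});
    // Each entry owns its strings, so nothing dangles once the
    // temporaries passed in above are gone.
    for (size_t i = 0; i < dialog.size(); ++i) {
        std::printf("%s: %s\n", dialog[i].role.c_str(), dialog[i].content.c_str());
    }
    return 0;
}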

namespace llama {

enum common_sampler_type {
COMMON_SAMPLER_TYPE_NONE = 0,
COMMON_SAMPLER_TYPE_DRY = 1,
COMMON_SAMPLER_TYPE_TOP_K = 2,
COMMON_SAMPLER_TYPE_TOP_P = 3,
COMMON_SAMPLER_TYPE_MIN_P = 4,
//COMMON_SAMPLER_TYPE_TFS_Z = 5,
COMMON_SAMPLER_TYPE_TYPICAL_P = 6,
COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9,
};

struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled
float xtc_probability = 0.00f; // 0.0 = disabled
float xtc_threshold = 0.10f; // > 0.5 disables XTC
float typ_p = 1.00f; // typical_p, 1.0 = disabled
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition:
float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length)
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;

std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY


std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE,
};

std::string grammar; // optional BNF-like grammar to constrain sampling

std::vector<llama_logit_bias> logit_bias; // logit biases to apply
};

} // namespace llama

}
}
}
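
The new llama::common_params_sampling mirrors llama.cpp's common sampling parameters, and the samplers vector fixes the order in which the stages run (DRY, top-k, typical-p, top-p, min-p, XTC, then temperature by default). A sketch of how a caller might override the defaults before building a sampler chain; the field values are illustrative, not recommendations:

#include "llm_datatype.hpp"   // this header; presumably pulls in llama.h for LLAMA_DEFAULT_SEED

llama::common_params_sampling make_sampling_params() {
    llama::common_params_sampling params;  // defaults as declared above
    params.temp           = 0.2f;   // closer to greedy decoding
    params.top_k          = 20;     // narrower candidate pool
    params.penalty_last_n = 256;    // scan further back for repeats
    params.penalty_repeat = 1.1f;   // mild repetition penalty
    // Drop XTC from the chain but keep the default stage order otherwise.
    params.samplers = {
        llama::COMMON_SAMPLER_TYPE_DRY,
        llama::COMMON_SAMPLER_TYPE_TOP_K,
        llama::COMMON_SAMPLER_TYPE_TYPICAL_P,
        llama::COMMON_SAMPLER_TYPE_TOP_P,
        llama::COMMON_SAMPLER_TYPE_MIN_P,
        llama::COMMON_SAMPLER_TYPE_TEMPERATURE,
    };
    return params;
}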
