Fork runner for more reliable CPU ProstT5
milot-mirdita committed Jan 10, 2025
1 parent 3b76601 commit 7d84f0d
Showing 5 changed files with 317 additions and 272 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -115,6 +115,7 @@ if(ENABLE_PROSTT5)
set(GGML_STATIC ON)
set(BUILD_SHARED_LIBS OFF)
set(GGML_BLAS OFF)
set(GGML_OPENMP OFF)
if (NOT NATIVE_ARCH)
set(GGML_NATIVE OFF)
if (HAVE_AVX2)
279 changes: 68 additions & 211 deletions src/strucclustutils/ProstT5.cpp
@@ -34,24 +34,14 @@ static char number_to_char(unsigned int n) {
static int encode(llama_context * ctx, std::vector<llama_token> & enc_input, std::string & result) {
const struct llama_model * model = llama_get_model(ctx);

// clear previous kv_cache values (irrelevant for embeddings)
// llama_kv_cache_clear(ctx);
// llama_set_embeddings(ctx, true);
// run model
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
if (llama_encode(ctx, llama_batch_get_one(enc_input.data(), enc_input.size())) < 0) {
// LOG_ERR("%s : failed to encode\n", __func__);
return 1;
}
} else {
// LOG_ERR("%s : no encoder\n", __func__);
if (llama_encode(ctx, llama_batch_get_one(enc_input.data(), enc_input.size())) < 0) {
// LOG_ERR("%s : failed to encode\n", __func__);
return 1;
}
// Log the embeddings (assuming n_embd is the embedding size per token)

// LOG_INF("%s: n_tokens = %zu, n_seq = %d\n", __func__, enc_input.size(), 1);
float* embeddings = llama_get_embeddings(ctx);
if (embeddings == nullptr) {
// LOG_ERR("%s : failed to retrieve embeddings\n", __func__);
return 1;
}
int * arg_max_idx = new int[enc_input.size()];
@@ -69,8 +59,8 @@ static int encode(llama_context * ctx, std::vector<llama_token> & enc_input, std
for (int i = 0; i < seq_len - 1; ++i) {
result.push_back(number_to_char(arg_max_idx[i]));
}
delete [] arg_max_idx;
delete [] arg_max;
delete[] arg_max_idx;
delete[] arg_max;
return 0;
}

@@ -110,182 +100,18 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
return devices;
}

struct lora_adapter_info {
std::string path;
float scale;
};

struct lora_adapter_container : lora_adapter_info {
struct llama_lora_adapter * adapter;
};

struct init_result {
struct llama_model * model = nullptr;
struct llama_context * context = nullptr;
std::vector<lora_adapter_container> lora_adapters;
};

struct cpu_params {
int n_threads = -1;
bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
bool mask_valid = false; // Default: any CPU
enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
bool strict_cpu = false; // Use strict CPU placement
uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
};

struct common_params {
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_parallel = 1; // number of parallel sequences to decode
// float rope_freq_base = 0.0f; // RoPE base frequency
// float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
// float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
// float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
// float yarn_beta_fast = 32.0f; // YaRN low correction dim
// float yarn_beta_slow = 1.0f; // YaRN high correction dim
// int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = 0.1f; // KV cache defragmentation threshold

// // offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading

int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs

enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs

struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;

ggml_backend_sched_eval_callback cb_eval = nullptr;
void * cb_eval_user_data = nullptr;

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

std::string model = ""; // model path // NOLINT
std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT

bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<lora_adapter_info> lora_adapters; // lora adapter path with user defined scale

bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data

ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V

bool embedding = true; // get only sentence embedding
};

static struct init_result init_from_params(common_params & params) {
init_result iparams;
auto mparams = llama_model_default_params();

if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
mparams.n_gpu_layers = params.n_gpu_layers;
mparams.kv_overrides = NULL;

llama_model * model = nullptr;

model = llama_load_model_from_file(params.model.c_str(), mparams);

if (model == NULL) {
// LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
return iparams;
}

auto cparams = llama_context_default_params();

cparams.n_ctx = params.n_ctx;
cparams.n_seq_max = params.n_parallel;
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
// cparams.rope_scaling_type = params.rope_scaling_type;
// cparams.rope_freq_base = params.rope_freq_base;
// cparams.rope_freq_scale = params.rope_freq_scale;
// cparams.yarn_ext_factor = params.yarn_ext_factor;
// cparams.yarn_attn_factor = params.yarn_attn_factor;
// cparams.yarn_beta_fast = params.yarn_beta_fast;
// cparams.yarn_beta_slow = params.yarn_beta_slow;
// cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
cparams.defrag_thold = params.defrag_thold;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;

cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;

llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
// LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
return iparams;
}


// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
lora_adapter_container loaded_la;
loaded_la.path = la.path;
loaded_la.scale = la.scale;
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
if (loaded_la.adapter == nullptr) {
// LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_free_model(model);
return iparams;
}
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
llama_lora_adapter_clear(lctx);
for (auto & la : iparams.lora_adapters) {
if (la.scale != 0.0f) {
llama_lora_adapter_set(lctx, la.adapter, la.scale);
}
}
}
// struct lora_adapter_info {
// std::string path;
// float scale;
// };

iparams.model = model;
iparams.context = lctx;
// struct lora_adapter_container : lora_adapter_info {
// struct llama_lora_adapter* adapter;
// };

return iparams;
}
// struct init_result {
// std::vector<lora_adapter_container> lora_adapters;
// };

LlamaInitGuard::LlamaInitGuard(bool verbose) {
if (!verbose) {
@@ -299,57 +125,88 @@ LlamaInitGuard::~LlamaInitGuard() {
llama_backend_free();
}

ProstT5::ProstT5(const std::string& model_file, std::string & device) {
common_params params;
params.n_ubatch = params.n_batch;
params.warmup = false;
params.model = model_file;
params.cpuparams.n_threads = 1;
params.use_mmap = true;
params.devices = parse_device_list(device);
ProstT5Model::ProstT5Model(const std::string& model_file, std::string& device) {
auto mparams = llama_model_default_params();
std::vector<ggml_backend_dev_t> devices = parse_device_list(device);
if (!devices.empty()) {
mparams.devices = devices.data();
}

int gpus = 0;
for (const auto& dev : params.devices) {
for (const auto& dev : devices) {
if (!dev) {
continue;
}
gpus += ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU;
}
if (gpus > 0) {
params.n_gpu_layers = 24;
mparams.n_gpu_layers = 24;
} else {
params.n_gpu_layers = 0;
}
mparams.n_gpu_layers = 0;
}
mparams.use_mmap = true;
model = llama_load_model_from_file(model_file.c_str(), mparams);

// for (auto & la : params.lora_adapters) {
// lora_adapter_container loaded_la;
// loaded_la.path = la.path;
// loaded_la.scale = la.scale;
// loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
// if (loaded_la.adapter == nullptr) {
// llama_free_model(model);
// return;
// }
// lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
// }
}

// load the model
init_result llama_init = init_from_params(params);
ProstT5Model::~ProstT5Model() {
llama_free_model(model);
}

model = llama_init.model;
ctx = llama_init.context;
ProstT5::ProstT5(ProstT5Model& model, int threads) : model(model) {
auto cparams = llama_context_default_params();
cparams.n_threads = threads;
cparams.n_threads_batch = threads;
cparams.n_ubatch = 4096;
cparams.n_batch = 4096;
cparams.n_ctx = 4096;
cparams.embeddings = true;
cparams.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;

ctx = llama_new_context_with_model(model.model, cparams);
// batch = llama_batch_init(4096, 0, 1);
// if (!params.lora_init_without_apply) {
// llama_lora_adapter_clear(lctx);
// for (auto & la : iparams.lora_adapters) {
// if (la.scale != 0.0f) {
// llama_lora_adapter_set(lctx, la.adapter, la.scale);
// }
// }
// }
};

ProstT5::~ProstT5() {
llama_free(ctx);
llama_free_model(model);
}

std::string ProstT5::predict(const std::string& aa) {
std::string result;
std::vector<llama_token> embd_inp;
embd_inp.reserve(aa.length() + 2);
embd_inp.emplace_back(llama_token_get_token(model, "<AA2fold>"));
llama_token unk_aa = llama_token_get_token(model, "▁X");
embd_inp.emplace_back(llama_token_get_token(model.model, "<AA2fold>"));
llama_token unk_aa = llama_token_get_token(model.model, "▁X");
for (size_t i = 0; i < aa.length(); ++i) {
std::string current_char("");
current_char.append(1, toupper(aa[i]));
llama_token token = llama_token_get_token(model, current_char.c_str());
llama_token token = llama_token_get_token(model.model, current_char.c_str());
if (token == LLAMA_TOKEN_NULL) {
embd_inp.emplace_back(unk_aa);
} else {
embd_inp.emplace_back(token);
}
}
embd_inp.emplace_back(llama_token_get_token(model, "</s>"));

embd_inp.emplace_back(llama_token_get_token(model.model, "</s>"));
encode(ctx, embd_inp, result);
return result;
}
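
A minimal sketch of how the refactored classes above might be wired together: a single ProstT5Model loads the weights once, and each worker wraps it in its own lightweight ProstT5 context before calling predict(). The class names and signatures (LlamaInitGuard, ProstT5Model, ProstT5, predict) come from this diff; the header include, model path, device string, thread count, and input sequence are placeholders, and how workers are actually spawned (the "fork runner" of the commit title) lives in the files not shown here.

    #include <string>
    #include "ProstT5.h"   // assumed include path for the classes in this diff

    int main() {
        LlamaInitGuard guard(false);                 // backend init/teardown, non-verbose
        std::string device = "";                     // device selection string (placeholder)
        ProstT5Model model("prostt5-f16.gguf", device); // weights loaded once, shareable

        ProstT5 runner(model, 1);                    // one context per worker, 1 thread here
        std::string aa = "MKTAYIAKQR";               // amino-acid sequence (placeholder)
        std::string threeDi = runner.predict(aa);    // per-residue 3Di states as characters
        return threeDi.empty() ? 1 : 0;
    }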
14 changes: 11 additions & 3 deletions src/strucclustutils/ProstT5.h
@@ -16,18 +16,26 @@ class LlamaInitGuard {
LlamaInitGuard& operator=(const LlamaInitGuard&) = delete;
};

class ProstT5Model {
public:
ProstT5Model(const std::string& model_file, std::string& device);
~ProstT5Model();

llama_model* model;
};

class ProstT5 {
public:
ProstT5(const std::string& model_file, std::string & device);
ProstT5(ProstT5Model& model, int threads);
~ProstT5();

static std::vector<std::string> getDevices();

std::string predict(const std::string& aa);
void perf();

llama_model * model;
llama_context * ctx;
ProstT5Model& model;
llama_context* ctx;
};

