Skip to content

Commit

Permalink
feat: add --parallelism to control throughput and vram usage (#727)
Browse files Browse the repository at this point in the history
* feat: add --parallelism to control throughput and vram usage

* update default

* Revert "update default"

This reverts commit 349792c.

* cargo fmt
  • Loading branch information
wsxiaoys authored Nov 8, 2023
1 parent 3fb8445 commit 8ab35b2
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 37 deletions.
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/include/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ class TextInferenceEngine {
virtual rust::Vec<StepOutput> step() = 0;
};

std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
} // namespace
29 changes: 11 additions & 18 deletions crates/llama-cpp-bindings/src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,6 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}

namespace {
int get_parallelism() {
const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
if (parallelism) {
return std::stoi(parallelism);
} else {
return 4;
}
}

static size_t N_CONCURRENT_REQUESTS = get_parallelism();

constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history.

Expand Down Expand Up @@ -95,10 +84,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;

class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
model_(std::move(model)),
ctx_(std::move(ctx)) {
batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
ctx_(std::move(ctx)),
parallelism_(parallelism) {
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
// warm up
{
batch_.n_tokens = 16;
Expand Down Expand Up @@ -155,7 +145,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
}

// Add pending requests.
while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
while (pending_requests_.size() > 0 && requests_.size() < parallelism_) {
Request request = std::move(pending_requests_.front());
pending_requests_.pop_front();

Expand Down Expand Up @@ -283,6 +273,8 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::vector<Request> requests_;
std::deque<Request> pending_requests_;
std::unordered_set<uint32_t> stopped_requests_;

uint32_t parallelism_;
};

static int g_llama_cpp_log_level = 0;
Expand Down Expand Up @@ -310,7 +302,7 @@ struct BackendInitializer {

} // namespace

std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
static BackendInitializer initializer;

llama_model_params model_params = llama_model_default_params();
Expand All @@ -322,13 +314,14 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
}

llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = N_CTX * N_CONCURRENT_REQUESTS;
ctx_params.n_ctx = N_CTX * parallelism;
ctx_params.n_batch = N_BATCH;
llama_context* ctx = llama_new_context_with_model(model, ctx_params);

return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free)
owned<llama_context>(ctx, llama_free),
parallelism
);
}

Expand Down
9 changes: 7 additions & 2 deletions crates/llama-cpp-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ mod ffi {

type TextInferenceEngine;

fn create_engine(use_gpu: bool, model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn create_engine(
use_gpu: bool,
model_path: &str,
parallelism: u8,
) -> UniquePtr<TextInferenceEngine>;

fn add_request(
self: Pin<&mut TextInferenceEngine>,
Expand All @@ -43,6 +47,7 @@ unsafe impl Sync for ffi::TextInferenceEngine {}
pub struct LlamaTextGenerationOptions {
model_path: String,
use_gpu: bool,
parallelism: u8,
}

pub struct LlamaTextGeneration {
Expand All @@ -52,7 +57,7 @@ pub struct LlamaTextGeneration {

impl LlamaTextGeneration {
pub fn new(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path);
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path);
}
Expand Down
16 changes: 12 additions & 4 deletions crates/tabby/src/serve/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@ pub async fn create_engine(
if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine =
create_ggml_engine(&args.device, model_path.display().to_string().as_str());
let engine = create_ggml_engine(
&args.device,
model_path.display().to_string().as_str(),
args.parallelism,
);
let engine_info = EngineInfo::read(path.join("tabby.json"));
(engine, engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(&args.device, &model_path);
let engine = create_ggml_engine(&args.device, &model_path, args.parallelism);
(
engine,
EngineInfo {
Expand Down Expand Up @@ -57,10 +60,15 @@ impl EngineInfo {
}
}

fn create_ggml_engine(device: &super::Device, model_path: &str) -> Box<dyn TextGeneration> {
fn create_ggml_engine(
device: &super::Device,
model_path: &str,
parallelism: u8,
) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu())
.parallelism(parallelism)
.build()
.unwrap();

Expand Down
16 changes: 4 additions & 12 deletions crates/tabby/src/serve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,13 @@ pub struct ServeArgs {
#[clap(long, default_value_t=Device::Cpu)]
device: Device,

/// DEPRECATED: Do not use.
#[deprecated(since = "0.5.0")]
#[clap(long, hide(true))]
device_indices: Vec<i32>,
/// Parallelism for model serving - increasing this number will have a significant impact on the
/// memory requirement e.g., GPU vRAM.
#[clap(long, default_value_t = 1)]
parallelism: u8,
}

pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);

if args.device != Device::ExperimentalHttp {
if fs::metadata(&args.model).is_ok() {
info!("Loading model from local path {}", &args.model);
Expand Down Expand Up @@ -252,12 +250,6 @@ async fn api_router(args: &ServeArgs, config: &Config) -> Router {
.layer(opentelemetry_tracing_layer())
}

fn valid_args(args: &ServeArgs) {
if !args.device_indices.is_empty() {
warn!("--device-indices is deprecated and will be removed in future release.");
}
}

fn start_heartbeat(args: &ServeArgs) {
let state = HealthState::new(args);
tokio::spawn(async move {
Expand Down

0 comments on commit 8ab35b2

Please sign in to comment.