Skip to content

Commit

Permalink
feat: support cont batching
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys committed Oct 29, 2023
1 parent 2a23daf commit d2138b8
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 103 deletions.
8 changes: 4 additions & 4 deletions crates/llama-cpp-bindings/include/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();

virtual void start(rust::Slice<const uint32_t> input_token_ids) = 0;
virtual uint32_t step() = 0;
virtual void end() = 0;
virtual void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) = 0;
virtual void stop_request(uint32_t request_id) = 0;
virtual rust::Vec<uint32_t> step() = 0;

virtual uint32_t eos_token() const = 0;
virtual uint32_t eos_token_id() const = 0;
};

std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path);
Expand Down
183 changes: 135 additions & 48 deletions crates/llama-cpp-bindings/src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <functional>
#include <vector>
#include <deque>
#include <unordered_set>

#include <ggml.h>
#include <llama.h>
Expand All @@ -10,8 +12,34 @@ namespace llama {
TextInferenceEngine::~TextInferenceEngine() {}

namespace {
static size_t N_BATCH = 512; // # per batch inference.
static size_t N_CTX = 4096; // # max kv history.
int get_parallelism() {
const char* parallelism = std::getenv("LLAMA_CPP_PARALLELISM");
if (parallelism) {
return std::stoi(parallelism);
} else {
return 4;
}
}

static size_t N_CONCURRENT_REQUESTS = get_parallelism();
//
constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history.

struct Request {
Request(size_t request_id, rust::Slice<const uint32_t> input_token_ids) :
id(request_id),
tokens(input_token_ids.begin(), input_token_ids.end()) {
}

size_t id = -1;
llama_seq_id seq_id = -1;

std::vector<llama_token> tokens;
size_t i_batch = -1;
size_t n_past = 0;
};


template<class T>
using owned = std::unique_ptr<T, std::function<void(T*)>>;
Expand All @@ -21,78 +49,136 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx) :
model_(std::move(model)),
ctx_(std::move(ctx)) {
batch_ = llama_batch_init(N_BATCH, 0, 1);
batch_ = llama_batch_init(N_CTX * N_CONCURRENT_REQUESTS, 0, 1);
}

~TextInferenceEngineImpl() override {
~TextInferenceEngineImpl() {
llama_batch_free(batch_);
}

void start(rust::Slice<const uint32_t> input_token_ids) override {
auto* ctx = ctx_.get();
llama_reset_timings(ctx);
std::vector<llama_token> tokens_list(input_token_ids.begin(), input_token_ids.end());

for (size_t i = 0; i < tokens_list.size(); i += N_BATCH) {
const size_t size = std::min(N_BATCH, tokens_list.size() - i);
eval(tokens_list.data() + i, size, /* reset = */ i == 0);
}
void add_request(uint32_t request_id, rust::Slice<const uint32_t> input_token_ids) override {
pending_requests_.push_back(Request(request_id, input_token_ids));
}

uint32_t step() override {
const llama_token id = sample();
eval(const_cast<llama_token*>(&id), 1, /* reset = */ false);
return id;
void stop_request(uint32_t request_id) override {
stopped_requests_.insert(request_id);
}

void end() override {
llama_print_timings(ctx_.get());
}
rust::Vec<uint32_t> step() override {
auto* ctx = ctx_.get();
auto n_vocab = llama_n_vocab(llama_get_model(ctx));

uint32_t eos_token() const override {
return llama_token_eos(llama_get_model(ctx_.get()));
}
// Remove stopped requests.
if (!stopped_requests_.empty()) {
std::vector<Request> requests;
for (auto& request : requests_) {
if (stopped_requests_.count(request.id) > 0) {
// Release KV cache.
llama_kv_cache_seq_rm(ctx_.get(), request.id, -1, -1);
} else {
requests.emplace_back(request);
}
}

requests_ = requests;
}

private:
uint32_t sample() const {
auto* ctx = ctx_.get();
// Add pending requests.
while (pending_requests_.size() > 0 && requests_.size() < N_CONCURRENT_REQUESTS) {
Request request = std::move(pending_requests_.front());
pending_requests_.pop_front();

auto logits = llama_get_logits_ith(ctx, batch_.n_tokens - 1);
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Ignore stopped pending requests.
if (stopped_requests_.count(request.id) > 0) {
continue;
}

// Greedy sampling (always select the highest logit).
return std::distance(logits, std::max_element(logits, logits + n_vocab));
}
requests_.push_back(request);
}

// Clear stopped requests.
stopped_requests_.clear();

void eval(llama_token* data, size_t size, bool reset) {
if (reset) {
n_past_ = 0;
if (requests_.size() == 0) {
return {};
}

batch_.n_tokens = size;
for (size_t i = 0; i < size; ++i) {
batch_.token[i] = data[i];
batch_.pos[i] = n_past_ + i;
batch_.n_seq_id[i] = 1;
batch_.seq_id[i][0] = 0;
batch_.logits[i] = false;
// Clear the batch.
batch_.n_tokens = 0;

// Insert tokens from ongoing requests to batch.
for (auto& request : requests_) {
const size_t n_tokens = batch_.n_tokens;
for (size_t i = 0; i < request.tokens.size(); ++i) {
batch_.token[n_tokens + i] = request.tokens[i];
batch_.pos[n_tokens + i] = request.n_past + i;
batch_.n_seq_id[n_tokens + i] = 1;
batch_.seq_id[n_tokens + i][0] = request.id;
batch_.logits[n_tokens + i] = false;
}
batch_.n_tokens += request.tokens.size();

batch_.logits[batch_.n_tokens - 1] = true;
request.i_batch = batch_.n_tokens - 1;
}
batch_.logits[size - 1] = true;

auto* ctx = ctx_.get();
llama_kv_cache_tokens_rm(ctx, n_past_, -1);
if (llama_decode(ctx, batch_)) {
throw std::runtime_error("Failed to eval");
rust::Vec<uint32_t> result;
result.reserve(requests_.size() * 2);

// Decode tokens in chunks
for (size_t i = 0; i < static_cast<size_t>(batch_.n_tokens); i += N_BATCH) {
const int32_t n_tokens = std::min(N_BATCH, batch_.n_tokens - i);
llama_batch batch_view = {
n_tokens,
batch_.token + i,
nullptr,
batch_.pos + i,
batch_.n_seq_id + i,
batch_.seq_id + i,
batch_.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
throw std::runtime_error("Failed to eval");
}

for (auto& request : requests_) {
if ((request.i_batch < i) || (request.i_batch >= (i + n_tokens))) {
continue;
}

int32_t i_batch = request.i_batch - i;
auto logits = llama_get_logits_ith(ctx, i_batch);
auto next_token = std::distance(logits, std::max_element(logits, logits + n_vocab));

request.n_past += request.tokens.size();

request.tokens.clear();
request.tokens.push_back(next_token);

result.push_back(request.id);
result.push_back(next_token);
}
}

n_past_ += size;
return result;
}

uint32_t eos_token_id() const override {
return llama_token_eos(llama_get_model(ctx_.get()));
}

size_t n_past_;
private:
owned<llama_model> model_;
owned<llama_context> ctx_;

llama_batch batch_;

std::vector<Request> requests_;
std::deque<Request> pending_requests_;
std::unordered_set<uint32_t> stopped_requests_;
};

static int g_llama_cpp_log_level = 0;
Expand All @@ -117,6 +203,7 @@ struct BackendInitializer {
llama_backend_free();
}
};

} // namespace

std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path) {
Expand Down
Loading

0 comments on commit d2138b8

Please sign in to comment.