Commit: feat: switch cuda backend to llama.cpp

wsxiaoys committed Oct 27, 2023
1 parent 308681e commit 90d6aa8
Showing 9 changed files with 36 additions and 88 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml

@@ -111,7 +111,7 @@ jobs:
       - run: bash ./ci/prepare_build_environment.sh

       - name: Build release binary
-        run: cargo build --no-default-features --release --target ${{ matrix.target }} --package tabby
+        run: cargo build --release --target ${{ matrix.target }} --package tabby

       - name: Rename release binary
         run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}
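Note: with the ctranslate2 backend removed and CUDA support moved behind an opt-in `cuda` cargo feature (see the Cargo.toml changes below), the release build presumably no longer needs `--no-default-features` to keep GPU-specific dependencies out of the portable binaries.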
1 change: 1 addition & 0 deletions CHANGELOG.md

@@ -7,6 +7,7 @@

 ## Fixes and Improvements

+* Switch cuda backend to llama.cpp: https://github.com/TabbyML/tabby/pull/TODO
 * Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638
 * add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

16 changes: 10 additions & 6 deletions Dockerfile
@@ -1,8 +1,12 @@
-FROM ghcr.io/opennmt/ctranslate2:3.20.0-ubuntu20.04-cuda11.2 as source
-FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
+ARG UBUNTU_VERSION=22.04
+# This needs to generally match the container host's environment.
+ARG CUDA_VERSION=11.7.1
+# Target the CUDA build image
+ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
+# Target the CUDA runtime image
+ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

-ENV CTRANSLATE2_ROOT=/opt/ctranslate2
-COPY --from=source $CTRANSLATE2_ROOT $CTRANSLATE2_ROOT
+FROM ${BASE_CUDA_DEV_CONTAINER} as build

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
@@ -30,10 +34,10 @@ RUN mkdir -p target

 RUN --mount=type=cache,target=/usr/local/cargo/registry \
     --mount=type=cache,target=/root/workspace/target \
-    cargo build --features link_shared --release && \
+    cargo build --features cuda --release && \
     cp target/release/tabby /opt/tabby/bin/

-FROM ghcr.io/opennmt/ctranslate2:3.20.0-ubuntu20.04-cuda11.2
+FROM ${BASE_CUDA_RUN_CONTAINER} as runtime

RUN apt-get update && \
apt-get install -y --no-install-recommends \
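With the base images parameterized, the CUDA toolchain can be pinned at image build time via standard Docker build arguments, e.g. `docker build --build-arg CUDA_VERSION=11.7.1 --build-arg UBUNTU_VERSION=22.04 .` (an illustrative invocation; the project's actual release pipeline is not shown in this commit). The `devel` image serves only the `build` stage, while the slimmer `runtime` image ships in the final stage.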
3 changes: 3 additions & 0 deletions crates/llama-cpp-bindings/Cargo.toml
@@ -3,6 +3,9 @@ name = "llama-cpp-bindings"
version = "0.5.0-dev"
edition = "2021"

+[features]
+cuda = []

[build-dependencies]
cxx-build = "1.0"
cmake = "0.1"
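The `cuda` feature is deliberately empty: it pulls in no dependencies and exists only as a flag the crate's build script reads (via `cfg!(feature = "cuda")` in `build.rs` below) to turn on llama.cpp's cuBLAS backend.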
20 changes: 10 additions & 10 deletions crates/llama-cpp-bindings/build.rs
@@ -1,25 +1,25 @@
 use cmake::Config;

 fn main() {
-    let mut config = Config::new("llama.cpp");
-    if cfg!(target_os = "macos") {
-        config.define("LLAMA_METAL", "ON");
-    }
-    let dst = config.build();
-
     println!("cargo:rerun-if-changed=cc/*.h");
     println!("cargo:rerun-if-changed=cc/*.cc");

-    println!("cargo:rustc-link-search=native={}/build", dst.display());
-    println!("cargo:rustc-link-lib=llama");
-    println!("cargo:rustc-link-lib=ggml_static");
-
+    let mut config = Config::new("llama.cpp");
     if cfg!(target_os = "macos") {
+        config.define("LLAMA_METAL", "ON");
         println!("cargo:rustc-link-lib=framework=Foundation");
         println!("cargo:rustc-link-lib=framework=Accelerate");
         println!("cargo:rustc-link-lib=framework=Metal");
         println!("cargo:rustc-link-lib=framework=MetalKit");
     }
+    if cfg!(feature = "cuda") {
+        config.define("LLAMA_CUBLAS", "ON");
+    }

+    let dst = config.build();
+    println!("cargo:rustc-link-search=native={}/build", dst.display());
+    println!("cargo:rustc-link-lib=llama");
+    println!("cargo:rustc-link-lib=ggml_static");

     cxx_build::bridge("src/lib.rs")
         .file("src/engine.cc")
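For reference, here is the build script as it reads after this change, reconstructed from the hunks above. The key point is the reordering: all backend toggles (Metal, cuBLAS) are now set on the `cmake::Config` before the single `config.build()` call, whereas previously the CUDA path had no hook at all. The tail of the file is collapsed in this view, so everything from `cxx_build::bridge` onward is abridged and the final builder call is an assumption:

```rust
use cmake::Config;

fn main() {
    println!("cargo:rerun-if-changed=cc/*.h");
    println!("cargo:rerun-if-changed=cc/*.cc");

    // Select llama.cpp backends before configuring the CMake build.
    let mut config = Config::new("llama.cpp");
    if cfg!(target_os = "macos") {
        config.define("LLAMA_METAL", "ON");
        println!("cargo:rustc-link-lib=framework=Foundation");
        println!("cargo:rustc-link-lib=framework=Accelerate");
        println!("cargo:rustc-link-lib=framework=Metal");
        println!("cargo:rustc-link-lib=framework=MetalKit");
    }
    if cfg!(feature = "cuda") {
        // Enables the cuBLAS (CUDA) backend in llama.cpp's CMake build.
        config.define("LLAMA_CUBLAS", "ON");
    }

    // Build once, then link the produced static libraries.
    let dst = config.build();
    println!("cargo:rustc-link-search=native={}/build", dst.display());
    println!("cargo:rustc-link-lib=llama");
    println!("cargo:rustc-link-lib=ggml_static");

    cxx_build::bridge("src/lib.rs")
        .file("src/engine.cc")
        // ...remaining builder calls are collapsed in this view; this compile
        // step is an assumed placeholder, not the commit's actual code.
        .compile("llama-cpp-bindings");
}
```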
8 changes: 3 additions & 5 deletions crates/tabby/Cargo.toml
@@ -3,6 +3,9 @@ name = "tabby"
version = "0.5.0-dev"
edition = "2021"

+[features]
+cuda = ["llama-cpp-bindings/cuda"]

[dependencies]
tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler" }
@@ -43,7 +46,6 @@ textdistance = "1.0.2"
regex.workspace = true
thiserror.workspace = true
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
-ctranslate2-bindings = { path = "../ctranslate2-bindings", optional = true }

[dependencies.uuid]
version = "1.3.3"
@@ -53,10 +55,6 @@ features = [
     "macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
 ]

-[features]
-link_shared = ["ctranslate2-bindings/link_shared"]
-link_cuda_static = ["ctranslate2-bindings"]

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

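The `cuda` feature on the `tabby` crate is pure forwarding: enabling it simply turns on `llama-cpp-bindings/cuda`, which is exactly what the Dockerfile's `cargo build --features cuda --release` does. The ctranslate2-specific `link_shared` and `link_cuda_static` features disappear together with the optional dependency they guarded.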
44 changes: 1 addition & 43 deletions crates/tabby/src/serve/engine.rs
@@ -13,7 +13,7 @@ pub fn create_engine(
     if args.device != super::Device::ExperimentalHttp {
         let model_dir = get_model_dir(model);
         let metadata = read_metadata(&model_dir);
-        let engine = create_local_engine(args, &model_dir, &metadata);
+        let engine = create_ggml_engine(&args.device, &model_dir);
         (
             engine,
             EngineInfo {
@@ -38,48 +38,6 @@ pub struct EngineInfo {
     pub chat_template: Option<String>,
 }

-#[cfg(not(any(feature = "link_shared", feature = "link_cuda_static")))]
-fn create_local_engine(
-    args: &crate::serve::ServeArgs,
-    model_dir: &ModelDir,
-    _metadata: &Metadata,
-) -> Box<dyn TextGeneration> {
-    create_ggml_engine(&args.device, model_dir)
-}
-
-#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
-fn create_local_engine(
-    args: &crate::serve::ServeArgs,
-    model_dir: &ModelDir,
-    metadata: &Metadata,
-) -> Box<dyn TextGeneration> {
-    if args.device.use_ggml_backend() {
-        create_ggml_engine(&args.device, model_dir)
-    } else {
-        create_ctranslate2_engine(args, model_dir, metadata)
-    }
-}
-
-#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
-fn create_ctranslate2_engine(
-    args: &crate::serve::ServeArgs,
-    model_dir: &ModelDir,
-    metadata: &Metadata,
-) -> Box<dyn TextGeneration> {
-    use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};
-
-    let device = format!("{}", args.device);
-    let options = CTranslate2EngineOptionsBuilder::default()
-        .model_path(model_dir.ctranslate2_dir())
-        .tokenizer_path(model_dir.tokenizer_file())
-        .device(device)
-        .model_type(metadata.auto_model.clone())
-        .device_indices(args.device_indices.clone())
-        .build()
-        .unwrap();
-    Box::new(CTranslate2Engine::create(options))
-}

 fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
     let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
         .model_path(model_dir.ggml_q8_0_v2_file())
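The surviving `create_ggml_engine` is only partially visible in this hunk. A minimal sketch of its likely shape, assuming the options builder exposes a GPU toggle wired to `Device::ggml_use_gpu()` (shown in the next file) and that `LlamaEngine::create` is the matching constructor; both names past `model_path` are assumptions, since the rest of the function is collapsed:

```rust
fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
    let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
        .model_path(model_dir.ggml_q8_0_v2_file())
        // Assumption: a GPU switch fed by Device::ggml_use_gpu(); the actual
        // builder calls after model_path are collapsed in this view.
        .use_gpu(device.ggml_use_gpu())
        .build()
        .unwrap();
    // Assumption: the constructor name; only the options builder is visible above.
    Box::new(llama_cpp_bindings::LlamaEngine::create(options))
}
```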
29 changes: 7 additions & 22 deletions crates/tabby/src/serve/mod.rs
@@ -74,7 +74,7 @@ pub enum Device {
     #[strum(serialize = "cpu")]
     Cpu,

-    #[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
+    #[cfg(feature = "cuda")]
     Cuda,

     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
@@ -86,24 +86,14 @@ pub enum Device {
}

impl Device {
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
-    fn use_ggml_backend(&self) -> bool {
-        *self == Device::Metal || *self == Device::Cpu
-    }
-
-    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
-    fn use_ggml_backend(&self) -> bool {
-        *self == Device::Cpu
-    }

     #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
     fn ggml_use_gpu(&self) -> bool {
         *self == Device::Metal
     }

-    #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
+    #[cfg(feature="cuda")]
     fn ggml_use_gpu(&self) -> bool {
-        false
+        *self == Device::Cuda
     }
 }

@@ -141,9 +131,9 @@ pub async fn main(config: &Config, args: &ServeArgs) {
     valid_args(args);

     if args.device != Device::ExperimentalHttp {
-        download_model(&args.model, &args.device).await;
+        download_model(&args.model).await;
         if let Some(chat_model) = &args.chat_model {
-            download_model(chat_model, &args.device).await;
+            download_model(chat_model).await;
         }
     } else {
         warn!("HTTP device is unstable and does not comply with semver expectations.")
@@ -285,15 +275,10 @@ fn start_heartbeat(args: &ServeArgs) {
     });
 }

-async fn download_model(model: &str, device: &Device) {
+async fn download_model(model: &str) {
     let downloader = Downloader::new(model, /* prefer_local_file= */ true);
     let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
-    let download_result = if device.use_ggml_backend() {
-        downloader.download_ggml_files().await
-    } else {
-        downloader.download_ctranslate2_files().await
-    };
-
+    let download_result = downloader.download_ggml_files().await;
     download_result.unwrap_or_else(handler);
 }

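Taken together, `Device` now resolves its GPU flag from two `cfg`-gated definitions of `ggml_use_gpu`: Metal on Apple Silicon, and CUDA when the new feature is enabled. A sketch of the resulting dispatch; the fallback arm for plain CPU builds is an assumption, since that part of the impl is collapsed in this view:

```rust
impl Device {
    // Apple Silicon builds: the GPU path is Metal.
    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
    fn ggml_use_gpu(&self) -> bool {
        *self == Device::Metal
    }

    // CUDA-enabled builds: the GPU path is Cuda.
    #[cfg(feature = "cuda")]
    fn ggml_use_gpu(&self) -> bool {
        *self == Device::Cuda
    }

    // Assumption: plain CPU builds (neither cfg above) presumably keep a
    // `false` fallback in the collapsed context; it is not shown in this commit.
    #[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
    fn ggml_use_gpu(&self) -> bool {
        false
    }
}
```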
