From 2fe2762cdffdd37e532a23ed43aa714165e0c847 Mon Sep 17 00:00:00 2001
From: Lei Wen
Date: Fri, 3 Nov 2023 13:44:43 +0800
Subject: [PATCH] fix: http-api-bindings also needs to read the model prompt
 definition

If the model's prompt template is not used, the output will not be as
expected.

Signed-off-by: Lei Wen
---
 crates/http-api-bindings/README.md        | 14 ++------------
 crates/http-api-bindings/src/fastchat.rs  |  9 ++-------
 crates/http-api-bindings/src/lib.rs       |  5 +++--
 crates/http-api-bindings/src/vertex_ai.rs |  4 ----
 crates/tabby/src/serve/engine.rs          | 11 +++--------
 5 files changed, 10 insertions(+), 33 deletions(-)

diff --git a/crates/http-api-bindings/README.md b/crates/http-api-bindings/README.md
index 8664710a2bc0..3e6870e5b0ba 100644
--- a/crates/http-api-bindings/README.md
+++ b/crates/http-api-bindings/README.md
@@ -1,14 +1,3 @@
-## Examples
-
-```bash
-export MODEL_ID="code-gecko"
-export PROJECT_ID="$(gcloud config get project)"
-export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
-export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
-
-cargo run --example simple
-```
-
 ## Usage
 
 ```bash
@@ -16,6 +5,7 @@ export MODEL_ID="code-gecko"
 export PROJECT_ID="$(gcloud config get project)"
 export API_ENDPOINT="https://us-central1-aiplatform.googleapis.com/v1/projects/${PROJECT_ID}/locations/us-central1/publishers/google/models/${MODEL_ID}:predict"
 export AUTHORIZATION="Bearer $(gcloud auth print-access-token)"
+export TABBY_CONFIG=/the_dir_where_your_model_located/tabby.json
 
-cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\"}"
+cargo run serve --device experimental-http --model "{\"kind\": \"vertex-ai\", \"api_endpoint\": \"$API_ENDPOINT\", \"authorization\": \"$AUTHORIZATION\", \"tabby_config\": \"$TABBY_CONFIG\"}"
 ```
diff --git a/crates/http-api-bindings/src/fastchat.rs b/crates/http-api-bindings/src/fastchat.rs
index fe0839dc8c64..2a4992c0d6e1 100644
--- a/crates/http-api-bindings/src/fastchat.rs
+++ b/crates/http-api-bindings/src/fastchat.rs
@@ -8,7 +8,7 @@ use tabby_inference::{helpers, TextGeneration, TextGenerationOptions};
 #[derive(Serialize)]
 struct Request {
     model: String,
-    prompt: Vec<String>,
+    prompt: String,
     max_tokens: usize,
     temperature: f32,
 }
@@ -49,19 +49,14 @@ impl FastChatEngine {
             client,
         }
     }
-
-    pub fn prompt_template() -> String {
-        "{prefix}<MID>{suffix}".to_owned()
-    }
 }
 
 #[async_trait]
 impl TextGeneration for FastChatEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let tokens: Vec<&str> = prompt.split("<MID>").collect();
         let request = Request {
             model: self.model_name.to_owned(),
-            prompt: vec![tokens[0].to_owned()],
+            prompt: prompt.to_string(),
             max_tokens: options.max_decoding_length,
             temperature: options.sampling_temperature,
         };
diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs
index fc743cada0ed..cadf247d9e9b 100644
--- a/crates/http-api-bindings/src/lib.rs
+++ b/crates/http-api-bindings/src/lib.rs
@@ -9,6 +9,7 @@ use vertex_ai::VertexAIEngine;
 pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
     let params = serde_json::from_str(model).expect("Failed to parse model string");
     let kind = get_param(&params, "kind");
+    let metafile = get_param(&params, "tabby_config");
     if kind == "vertex-ai" {
         let api_endpoint = get_param(&params, "api_endpoint");
         let authorization = get_param(&params, "authorization");
@@ -16,7 +17,7 @@ pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
             api_endpoint.as_str(),
             authorization.as_str(),
         ));
-        (engine, VertexAIEngine::prompt_template())
+        (engine, metafile)
     } else if kind == "fastchat" {
         let model_name = get_param(&params, "model_name");
         let api_endpoint = get_param(&params, "api_endpoint");
@@ -26,7 +27,7 @@ pub fn create(model: &str) -> (Box<dyn TextGeneration>, String) {
             model_name.as_str(),
             authorization.as_str(),
         ));
-        (engine, FastChatEngine::prompt_template())
+        (engine, metafile)
     } else {
         panic!("Only vertex_ai and fastchat are supported for http backend");
     }
diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs
index 89b79f6026ff..6e7809f3ca7a 100644
--- a/crates/http-api-bindings/src/vertex_ai.rs
+++ b/crates/http-api-bindings/src/vertex_ai.rs
@@ -57,10 +57,6 @@ impl VertexAIEngine {
             client,
         }
     }
-
-    pub fn prompt_template() -> String {
-        "{prefix}<MID>{suffix}".to_owned()
-    }
 }
 
 #[async_trait]
diff --git a/crates/tabby/src/serve/engine.rs b/crates/tabby/src/serve/engine.rs
index 7ee767960d7f..54d3c4d4430d 100644
--- a/crates/tabby/src/serve/engine.rs
+++ b/crates/tabby/src/serve/engine.rs
@@ -33,14 +33,9 @@ pub async fn create_engine(
             )
         }
     } else {
-        let (engine, prompt_template) = http_api_bindings::create(model_id);
-        (
-            engine,
-            EngineInfo {
-                prompt_template: Some(prompt_template),
-                chat_template: None,
-            },
-        )
+        let (engine, metafile) = http_api_bindings::create(model_id);
+        let engine_info = EngineInfo::read(PathBuf::from(metafile));
+        (engine, engine_info)
     }
 }
 
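For reference, the metafile passed via "tabby_config" is the tabby.json that
EngineInfo::read loads, so the http backend reuses the same prompt definition a
locally served model would. A minimal sketch of such a file, assuming
EngineInfo only needs the prompt_template (and optionally chat_template) fields
seen in the engine.rs hunk above; the CodeLlama-style FIM template is purely
illustrative:

    {
      "prompt_template": "<PRE> {prefix} <SUF>{suffix} <MID>"
    }

For chat-capable models a chat_template entry can be supplied in the same file.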