diff --git a/Cargo.lock b/Cargo.lock index 97f9f660d162..63d94f5b83b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3185,6 +3185,7 @@ dependencies = [ "derive_builder", "futures", "regex", + "tabby-common", "tokenizers", ] diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 8cb1a62ac5a0..8b8d7ede05b3 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine { let decoding = self.decoding_factory.create_incremental_decoding( self.tokenizer.clone(), truncate_tokens(encoding.get_ids(), options.max_input_length), - options.stop_words, + options.language, ); let cancel = CancellationToken::new(); diff --git a/crates/http-api-bindings/src/fastchat.rs b/crates/http-api-bindings/src/fastchat.rs index f71e048ca6dc..a2bf97a73404 100644 --- a/crates/http-api-bindings/src/fastchat.rs +++ b/crates/http-api-bindings/src/fastchat.rs @@ -58,9 +58,6 @@ impl FastChatEngine { #[async_trait] impl TextGeneration for FastChatEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { - let _stop_sequences: Vec = - options.stop_words.iter().map(|x| x.to_string()).collect(); - let tokens: Vec<&str> = prompt.split("").collect(); let request = Request { model: self.model_name.to_owned(), diff --git a/crates/http-api-bindings/src/vertex_ai.rs b/crates/http-api-bindings/src/vertex_ai.rs index 1d74b59b41d8..89b79f6026ff 100644 --- a/crates/http-api-bindings/src/vertex_ai.rs +++ b/crates/http-api-bindings/src/vertex_ai.rs @@ -67,7 +67,8 @@ impl VertexAIEngine { impl TextGeneration for VertexAIEngine { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String { let stop_sequences: Vec = options - .stop_words + .language + .get_stop_words() .iter() .map(|x| x.to_string()) // vertex supports at most 5 stop sequence. diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs index 8e8e29426879..084280b62c95 100644 --- a/crates/llama-cpp-bindings/src/lib.rs +++ b/crates/llama-cpp-bindings/src/lib.rs @@ -75,7 +75,7 @@ impl TextGeneration for LlamaEngine { let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length); engine.as_mut().start(input_token_ids); - let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words); + let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language); let mut n_remains = options.max_decoding_length ; while n_remains > 0 { let Ok(next_token_id) = engine.as_mut().step() else { diff --git a/crates/tabby-common/assets/languages.toml b/crates/tabby-common/assets/languages.toml new file mode 100644 index 000000000000..b13fcd609d26 --- /dev/null +++ b/crates/tabby-common/assets/languages.toml @@ -0,0 +1,53 @@ +[[config]] +languages = ["python"] +line_comment = "#" +top_level_keywords = ["def", "from", "class", "import"] + +[[config]] +languages = ["rust"] +line_comment = "//" +top_level_keywords = [ + "fn", + "trait", + "impl", + "enum", + "pub", + "extern", + "static", + "trait", + "unsafe", + "use", +] + +[[config]] +languages = ["javascript", "typescript", "javascriptreact", "typescriptreact"] +line_comment = "//" +top_level_keywords = [ + "abstract", + "async", + "class", + "const", + "export", + "function", + "interface", + "module", + "package", + "type", + "var", + "enum", + "let", +] + +[[config]] +languages = ["go"] +line_comment = "//" +top_level_keywords = [ + "func", + "interface", + "struct", + "package", + "type", + "import", + "var", + "const", +] \ No newline at end of file diff --git a/crates/tabby-common/src/languages.rs b/crates/tabby-common/src/languages.rs new file mode 100644 index 000000000000..b5f9c8bdb2d0 --- /dev/null +++ b/crates/tabby-common/src/languages.rs @@ -0,0 +1,73 @@ +use lazy_static::lazy_static; +use serde::Deserialize; + +lazy_static! { + static ref DEFAULT: Vec<&'static str> = vec![ + "\n\n", + "\n\n ", + "\n\n ", + "\n\n ", + "\n\n ", + "\n\n ", + "\n\n ", + "\n\n ", + "\n\n\t", + "\n\n\t\t", + "\n\n\t\t\t", + "\n\n\t\t\t\t", + "\n\n\t\t\t\t\t", + "\n\n\t\t\t\t\t\t", + "\n\n\t\t\t\t\t\t\t", + ]; +} + +#[derive(Deserialize)] +struct ConfigList { + config: Vec, +} + +#[derive(Deserialize, Debug)] +pub struct Language { + languages: Vec, + top_level_keywords: Vec, + + pub line_comment: String, +} + +impl Language { + pub fn get_stop_words(&self) -> Vec { + let mut out = vec![]; + out.push(format!("\n{}", self.line_comment)); + for word in &self.top_level_keywords { + out.push(format!("\n{}", word)); + } + + for x in DEFAULT.iter() { + out.push((*x).to_owned()); + } + + out + } + + pub fn get_hashkey(&self) -> String { + self.languages[0].clone() + } +} + +lazy_static! { + static ref CONFIG: ConfigList = + serdeconv::from_toml_str(include_str!("../assets/languages.toml")).unwrap(); + pub static ref UNKNOWN_LANGUAGE: Language = Language { + languages: vec!["unknown".to_owned()], + line_comment: "".to_owned(), + top_level_keywords: vec![], + }; +} + +pub fn get_language(language: &str) -> &'static Language { + CONFIG + .config + .iter() + .find(|c| c.languages.iter().any(|x| x == language)) + .unwrap_or(&UNKNOWN_LANGUAGE) +} diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index ae52e5a171b1..3bbbd5b8b658 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -1,6 +1,7 @@ pub mod config; pub mod events; pub mod index; +pub mod languages; pub mod path; pub mod usage; diff --git a/crates/tabby-inference/Cargo.toml b/crates/tabby-inference/Cargo.toml index f3aa89eefa81..0959c00d2167 100644 --- a/crates/tabby-inference/Cargo.toml +++ b/crates/tabby-inference/Cargo.toml @@ -13,3 +13,4 @@ derive_builder = "0.12.0" futures = { workspace = true } regex.workspace = true tokenizers.workspace = true +tabby-common = { path = "../tabby-common" } \ No newline at end of file diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index 77bf1e2cbf08..158fab8bebb5 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -2,10 +2,11 @@ use std::sync::Arc; use dashmap::DashMap; use regex::Regex; +use tabby_common::languages::Language; use tokenizers::tokenizer::Tokenizer; pub struct DecodingFactory { - stop_regex_cache: DashMap<&'static [&'static str], Regex>, + stop_regex_cache: DashMap, } fn reverse(s: T) -> String @@ -28,32 +29,34 @@ impl DecodingFactory { &self, tokenizer: Arc, input_token_ids: &[u32], - stop_words: &'static [&'static str], + language: &'static Language, ) -> IncrementalDecoding { - IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids) + IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids) } - fn get_re(&self, stop_words: &'static [&'static str]) -> Option { + fn get_re(&self, language: &'static Language) -> Option { + let stop_words = language.get_stop_words(); if stop_words.is_empty() { None } else { - let mut re = self.stop_regex_cache.get(stop_words); + let hashkey = language.get_hashkey(); + let mut re = self.stop_regex_cache.get(&hashkey); if re.is_none() { self.stop_regex_cache - .insert(stop_words, create_stop_regex(stop_words)); - re = self.stop_regex_cache.get(stop_words); + .insert(hashkey.clone(), create_stop_regex(stop_words)); + re = self.stop_regex_cache.get(&hashkey); } re.map(|x| x.value().clone()) } } } -fn create_stop_regex(stop_words: &[&str]) -> Regex { +fn create_stop_regex(stop_words: Vec) -> Regex { // (?m) enables multi-line matching mode. // \A means absolute begins of string. let reversed_stop_words: Vec<_> = stop_words .iter() - .map(|x| regex::escape(&reverse(*x))) + .map(|x| regex::escape(&reverse(x))) .collect(); let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))"; Regex::new(®ex_string).expect("Failed to create regex") @@ -131,7 +134,12 @@ mod tests { #[test] fn test_it_works() { let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid"); - assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text)); - assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text)); + assert!(!create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]).is_match(&text)); + assert!(create_stop_regex(vec![ + "\n\n".to_owned(), + "\n\n ".to_owned(), + "\nvoid".to_owned() + ]) + .is_match(&text)); } } diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index 28ab134785ca..3c3990bf766c 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -3,6 +3,7 @@ pub mod decoding; use async_trait::async_trait; use derive_builder::Builder; use futures::stream::BoxStream; +use tabby_common::languages::Language; #[derive(Builder, Debug)] pub struct TextGenerationOptions { @@ -15,12 +16,10 @@ pub struct TextGenerationOptions { #[builder(default = "1.0")] pub sampling_temperature: f32, - #[builder(default = "&EMPTY_STOP_WORDS")] - pub stop_words: &'static [&'static str], + #[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")] + pub language: &'static Language, } -static EMPTY_STOP_WORDS: Vec<&'static str> = vec![]; - #[async_trait] pub trait TextGeneration: Sync + Send { async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String; diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 6dc5e204947d..af8307c6d223 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -1,4 +1,3 @@ -mod languages; mod prompt; use std::sync::Arc; @@ -6,12 +5,11 @@ use std::sync::Arc; use axum::{extract::State, Json}; use hyper::StatusCode; use serde::{Deserialize, Serialize}; -use tabby_common::events; +use tabby_common::{events, languages::get_language}; use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder}; use tracing::{debug, instrument}; use utoipa::ToSchema; -use self::languages::get_language; use super::search::IndexServer; #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] @@ -112,7 +110,7 @@ pub async fn completions( .max_input_length(1024 + 512) .max_decoding_length(128) .sampling_temperature(0.1) - .stop_words(get_language(&language).stop_words) + .language(get_language(&language)) .build() .unwrap(); diff --git a/crates/tabby/src/serve/completions/languages.rs b/crates/tabby/src/serve/completions/languages.rs deleted file mode 100644 index 8eebce9af76c..000000000000 --- a/crates/tabby/src/serve/completions/languages.rs +++ /dev/null @@ -1,116 +0,0 @@ -use lazy_static::lazy_static; - -pub struct Language { - pub stop_words: &'static [&'static str], - pub line_comment: &'static str, -} - -lazy_static! { - static ref DEFAULT: Vec<&'static str> = vec![ - "\n\n", - "\n\n ", - "\n\n ", - "\n\n ", - "\n\n ", - "\n\n ", - "\n\n ", - "\n\n ", - "\n\n\t", - "\n\n\t\t", - "\n\n\t\t\t", - "\n\n\t\t\t\t", - "\n\n\t\t\t\t\t", - "\n\n\t\t\t\t\t\t", - "\n\n\t\t\t\t\t\t\t", - ]; - static ref UNKONWN: Language = Language { - stop_words: &DEFAULT, - line_comment: "#" - }; - - /* Python */ - static ref PYTHON_STOP_WORDS: Vec<&'static str> = - vec!["\ndef", "\n#", "\nfrom", "\nclass", "\nimport"].with_default(); - static ref PYTHON: Language = Language { - stop_words: &PYTHON_STOP_WORDS, - line_comment: "#", - }; - - /* Rust */ - static ref RUST_STOP_WORDS: Vec<&'static str> = vec![ - "\n//", "\nfn", "\ntrait", "\nimpl", "\nenum", "\npub", "\nextern", "\nstatic", - "\ntrait", "\nunsafe", "\nuse" - ] - .with_default(); - static ref RUST: Language = Language { - stop_words: &RUST_STOP_WORDS, - line_comment: "//", - }; - - /* Javascript / Typescript */ - static ref JAVASCRIPT_TYPESCRIPT_STOP_WORDS: Vec<&'static str> = vec![ - "\n//", - "\nabstract", - "\nasync", - "\nclass", - "\nconst", - "\nexport", - "\nfunction", - "\ninterface", - "\nmodule", - "\npackage", - "\ntype", - "\nvar", - "\nenum", - "\nlet", - ] - .with_default(); - static ref JAVASCRIPT_TYPESCRIPT: Language = Language { - stop_words: &JAVASCRIPT_TYPESCRIPT_STOP_WORDS, - line_comment: "//", - }; - - /* Golang */ - static ref GO_STOP_WORDS: Vec<&'static str> = vec![ - "\n//", - "\nfunc", - "\ninterface", - "\nstruct", - "\npackage", - "\ntype", - "\nimport", - "\nvar", - "\nconst", - ] - .with_default(); - static ref GO: Language = Language { - stop_words: &GO_STOP_WORDS, - line_comment: "//", - }; -} - -pub fn get_language(language: &str) -> &'static Language { - if language == "python" { - &PYTHON - } else if language == "rust" { - &RUST - } else if language == "javascript" || language == "typescript" { - &JAVASCRIPT_TYPESCRIPT - } else if language == "go" { - &GO - } else { - &UNKONWN - } -} - -trait WithDefault { - fn with_default(self) -> Self; -} - -impl WithDefault for Vec<&'static str> { - fn with_default(mut self) -> Self { - let mut x = DEFAULT.clone(); - self.append(&mut x); - self - } -} diff --git a/crates/tabby/src/serve/completions/prompt.rs b/crates/tabby/src/serve/completions/prompt.rs index bab082d59412..6788716b40f7 100644 --- a/crates/tabby/src/serve/completions/prompt.rs +++ b/crates/tabby/src/serve/completions/prompt.rs @@ -3,14 +3,12 @@ use std::sync::Arc; use lazy_static::lazy_static; use regex::Regex; use strfmt::strfmt; +use tabby_common::languages::get_language; use textdistance::Algorithm; use tracing::warn; use super::{Segments, Snippet}; -use crate::serve::{ - completions::languages::get_language, - search::{IndexServer, IndexServerError}, -}; +use crate::serve::search::{IndexServer, IndexServerError}; static MAX_SNIPPETS_TO_FETCH: usize = 20; static MAX_SNIPPET_CHARS_IN_PROMPT: usize = 768; @@ -78,7 +76,7 @@ fn build_prefix(language: &str, prefix: &str, snippets: &[Snippet]) -> String { return prefix.to_owned(); } - let comment_char = get_language(language).line_comment; + let comment_char = &get_language(language).line_comment; let mut lines: Vec = vec![]; for (i, snippet) in snippets.iter().enumerate() {