Skip to content

Commit

Permalink
refactor: extract language configuration into individual toml file
Browse files Browse the repository at this point in the history
  • Loading branch information
wsxiaoys committed Oct 15, 2023
1 parent 9d6a9a6 commit 4672836
Show file tree
Hide file tree
Showing 14 changed files with 147 additions and 146 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/ctranslate2-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
let decoding = self.decoding_factory.create_incremental_decoding(
self.tokenizer.clone(),
truncate_tokens(encoding.get_ids(), options.max_input_length),
options.stop_words,
options.language,
);

let cancel = CancellationToken::new();
Expand Down
3 changes: 0 additions & 3 deletions crates/http-api-bindings/src/fastchat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ impl FastChatEngine {
#[async_trait]
impl TextGeneration for FastChatEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let _stop_sequences: Vec<String> =
options.stop_words.iter().map(|x| x.to_string()).collect();

let tokens: Vec<&str> = prompt.split("<MID>").collect();
let request = Request {
model: self.model_name.to_owned(),
Expand Down
3 changes: 2 additions & 1 deletion crates/http-api-bindings/src/vertex_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ impl VertexAIEngine {
impl TextGeneration for VertexAIEngine {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
let stop_sequences: Vec<String> = options
.stop_words
.language
.get_stop_words()
.iter()
.map(|x| x.to_string())
// vertex supports at most 5 stop sequence.
Expand Down
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl TextGeneration for LlamaEngine {

let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
engine.as_mut().start(input_token_ids);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
let mut n_remains = options.max_decoding_length ;
while n_remains > 0 {
let Ok(next_token_id) = engine.as_mut().step() else {
Expand Down
40 changes: 40 additions & 0 deletions crates/tabby-common/assets/languages.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[[config]]
languages = ["python"]
line_comment = "#"
top_level_keywords = ["def", "from", "class", "import"]

[[config]]
languages = ["rust"]
line_comment = "//"
top_level_keywords = [
"fn",
"trait",
"impl",
"enum",
"pub",
"extern",
"static",
"trait",
"unsafe",
"use",
]

[[config]]
languages = ["javascript", "typescript", "javascriptreact", "typescriptreact"]
line_comment = "//"
top_level_keywords = [
"//",
"abstract",
"async",
"class",
"const",
"export",
"function",
"interface",
"module",
"package",
"type",
"var",
"enum",
"let",
]
73 changes: 73 additions & 0 deletions crates/tabby-common/src/languages.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use lazy_static::lazy_static;
use serde::Deserialize;

lazy_static! {
static ref DEFAULT: Vec<&'static str> = vec![
"\n\n",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n ",
"\n\n\t",
"\n\n\t\t",
"\n\n\t\t\t",
"\n\n\t\t\t\t",
"\n\n\t\t\t\t\t",
"\n\n\t\t\t\t\t\t",
"\n\n\t\t\t\t\t\t\t",
];
}

#[derive(Deserialize)]
struct ConfigList {
config: Vec<Language>,
}

#[derive(Deserialize, Debug)]
pub struct Language {
languages: Vec<String>,
top_level_keywords: Vec<String>,

pub line_comment: String,
}

impl Language {
pub fn get_stop_words(&self) -> Vec<String> {
let mut out = vec![];
out.push(format!("\n{}", self.line_comment));
for word in &self.top_level_keywords {
out.push(format!("\n{}", word));
}

for x in DEFAULT.iter() {
out.push((*x).to_owned());
}

out
}

pub fn get_hashkey(&self) -> String {
self.languages[0].clone()
}
}

lazy_static! {
static ref CONFIG: ConfigList =
serdeconv::from_toml_str(include_str!("../assets/languages.toml")).unwrap();
pub static ref UNKNOWN_LANGUAGE: Language = Language {
languages: vec!["unknown".to_owned()],
line_comment: "".to_owned(),
top_level_keywords: vec![],
};
}

pub fn get_language(language: &str) -> &'static Language {
CONFIG
.config
.iter()
.find(|c| c.languages.iter().any(|x| x == language))
.unwrap_or(&UNKNOWN_LANGUAGE)
}
1 change: 1 addition & 0 deletions crates/tabby-common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod config;
pub mod events;
pub mod index;
pub mod languages;
pub mod path;
pub mod usage;

Expand Down
1 change: 1 addition & 0 deletions crates/tabby-inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ derive_builder = "0.12.0"
futures = { workspace = true }
regex.workspace = true
tokenizers.workspace = true
tabby-common = { path = "../tabby-common" }
30 changes: 19 additions & 11 deletions crates/tabby-inference/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use std::sync::Arc;

use dashmap::DashMap;
use regex::Regex;
use tabby_common::languages::Language;
use tokenizers::tokenizer::Tokenizer;

pub struct DecodingFactory {
stop_regex_cache: DashMap<&'static [&'static str], Regex>,
stop_regex_cache: DashMap<String, Regex>,
}

fn reverse<T>(s: T) -> String
Expand All @@ -28,32 +29,34 @@ impl DecodingFactory {
&self,
tokenizer: Arc<Tokenizer>,
input_token_ids: &[u32],
stop_words: &'static [&'static str],
language: &'static Language,
) -> IncrementalDecoding {
IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
}

fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
fn get_re(&self, language: &'static Language) -> Option<Regex> {
let stop_words = language.get_stop_words();
if stop_words.is_empty() {
None
} else {
let mut re = self.stop_regex_cache.get(stop_words);
let hashkey = language.get_hashkey();
let mut re = self.stop_regex_cache.get(&hashkey);
if re.is_none() {
self.stop_regex_cache
.insert(stop_words, create_stop_regex(stop_words));
re = self.stop_regex_cache.get(stop_words);
.insert(hashkey.clone(), create_stop_regex(stop_words));
re = self.stop_regex_cache.get(&hashkey);
}
re.map(|x| x.value().clone())
}
}
}

fn create_stop_regex(stop_words: &[&str]) -> Regex {
fn create_stop_regex(stop_words: Vec<String>) -> Regex {
// (?m) enables multi-line matching mode.
// \A means absolute begins of string.
let reversed_stop_words: Vec<_> = stop_words
.iter()
.map(|x| regex::escape(&reverse(*x)))
.map(|x| regex::escape(&reverse(x)))
.collect();
let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
Regex::new(&regex_string).expect("Failed to create regex")
Expand Down Expand Up @@ -131,7 +134,12 @@ mod tests {
#[test]
fn test_it_works() {
let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text));
assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text));
assert!(!create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]).is_match(&text));
assert!(create_stop_regex(vec![
"\n\n".to_owned(),
"\n\n ".to_owned(),
"\nvoid".to_owned()
])
.is_match(&text));
}
}
7 changes: 3 additions & 4 deletions crates/tabby-inference/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod decoding;
use async_trait::async_trait;
use derive_builder::Builder;
use futures::stream::BoxStream;
use tabby_common::languages::Language;

#[derive(Builder, Debug)]
pub struct TextGenerationOptions {
Expand All @@ -15,12 +16,10 @@ pub struct TextGenerationOptions {
#[builder(default = "1.0")]
pub sampling_temperature: f32,

#[builder(default = "&EMPTY_STOP_WORDS")]
pub stop_words: &'static [&'static str],
#[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
pub language: &'static Language,
}

static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];

#[async_trait]
pub trait TextGeneration: Sync + Send {
async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
Expand Down
6 changes: 2 additions & 4 deletions crates/tabby/src/serve/completions.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
mod languages;
mod prompt;

use std::sync::Arc;

use axum::{extract::State, Json};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use tabby_common::events;
use tabby_common::{events, languages::get_language};
use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
use tracing::{debug, instrument};
use utoipa::ToSchema;

use self::languages::get_language;
use super::search::IndexServer;

#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
Expand Down Expand Up @@ -112,7 +110,7 @@ pub async fn completions(
.max_input_length(1024 + 512)
.max_decoding_length(128)
.sampling_temperature(0.1)
.stop_words(get_language(&language).stop_words)
.language(get_language(&language))
.build()
.unwrap();

Expand Down
Loading

0 comments on commit 4672836

Please sign in to comment.