refactor: extract language configuration into individual toml file #564

Merged · 2 commits · Oct 16, 2023
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered by default.)

2 changes: 1 addition & 1 deletion crates/ctranslate2-bindings/src/lib.rs
@@ -124,7 +124,7 @@ impl TextGeneration for CTranslate2Engine {
         let decoding = self.decoding_factory.create_incremental_decoding(
             self.tokenizer.clone(),
             truncate_tokens(encoding.get_ids(), options.max_input_length),
-            options.stop_words,
+            options.language,
         );

         let cancel = CancellationToken::new();
3 changes: 0 additions & 3 deletions crates/http-api-bindings/src/fastchat.rs
@@ -58,9 +58,6 @@ impl FastChatEngine {
 #[async_trait]
 impl TextGeneration for FastChatEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
-        let _stop_sequences: Vec<String> =
-            options.stop_words.iter().map(|x| x.to_string()).collect();
-
         let tokens: Vec<&str> = prompt.split("<MID>").collect();
         let request = Request {
             model: self.model_name.to_owned(),
3 changes: 2 additions & 1 deletion crates/http-api-bindings/src/vertex_ai.rs
@@ -67,7 +67,8 @@ impl VertexAIEngine {
 impl TextGeneration for VertexAIEngine {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String {
         let stop_sequences: Vec<String> = options
-            .stop_words
+            .language
+            .get_stop_words()
             .iter()
             .map(|x| x.to_string())
             // vertex supports at most 5 stop sequence.
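The hunk is truncated here, but the comment records a provider constraint: Vertex AI accepts only a small number of stop sequences, so the elided remainder of the chain has to cap the list. A minimal sketch of that pattern (the helper name and the hard limit of 5 are illustrative, not taken from this PR):

    // Hypothetical helper: keep at most `limit` stop sequences for providers
    // (such as Vertex AI) that reject longer lists.
    fn cap_stop_sequences(stop_words: Vec<String>, limit: usize) -> Vec<String> {
        stop_words.into_iter().take(limit).collect()
    }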
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/src/lib.rs
@@ -75,7 +75,7 @@ impl TextGeneration for LlamaEngine {

         let input_token_ids = truncate_tokens(encoding.get_ids(), options.max_input_length);
         engine.as_mut().start(input_token_ids);
-        let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.stop_words);
+        let mut decoding = self.decoding_factory.create_incremental_decoding(self.tokenizer.clone(), input_token_ids, options.language);
         let mut n_remains = options.max_decoding_length;
         while n_remains > 0 {
             let Ok(next_token_id) = engine.as_mut().step() else {
53 changes: 53 additions & 0 deletions crates/tabby-common/assets/languages.toml
@@ -0,0 +1,53 @@
[[config]]
languages = ["python"]
line_comment = "#"
top_level_keywords = ["def", "from", "class", "import"]

[[config]]
languages = ["rust"]
line_comment = "//"
top_level_keywords = [
  "fn",
  "trait",
  "impl",
  "enum",
  "pub",
  "extern",
  "static",
  "trait",
  "unsafe",
  "use",
]

[[config]]
languages = ["javascript", "typescript", "javascriptreact", "typescriptreact"]
line_comment = "//"
top_level_keywords = [
  "abstract",
  "async",
  "class",
  "const",
  "export",
  "function",
  "interface",
  "module",
  "package",
  "type",
  "var",
  "enum",
  "let",
]

[[config]]
languages = ["go"]
line_comment = "//"
top_level_keywords = [
  "func",
  "interface",
  "struct",
  "package",
  "type",
  "import",
  "var",
  "const",
]
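With the configuration externalized like this, supporting a new language becomes a data-only change. A hypothetical entry for Ruby (not part of this PR, keywords chosen for illustration) would follow the same shape:

    [[config]]
    languages = ["ruby"]
    line_comment = "#"
    top_level_keywords = ["def", "class", "module", "require"]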
73 changes: 73 additions & 0 deletions crates/tabby-common/src/languages.rs
@@ -0,0 +1,73 @@
use lazy_static::lazy_static;
use serde::Deserialize;

lazy_static! {
    static ref DEFAULT: Vec<&'static str> = vec![
        "\n\n",
        "\n\n  ",
        "\n\n    ",
        "\n\n      ",
        "\n\n        ",
        "\n\n          ",
        "\n\n            ",
        "\n\n              ",
        "\n\n\t",
        "\n\n\t\t",
        "\n\n\t\t\t",
        "\n\n\t\t\t\t",
        "\n\n\t\t\t\t\t",
        "\n\n\t\t\t\t\t\t",
        "\n\n\t\t\t\t\t\t\t",
    ];
}

#[derive(Deserialize)]
struct ConfigList {
    config: Vec<Language>,
}

#[derive(Deserialize, Debug)]
pub struct Language {
    languages: Vec<String>,
    top_level_keywords: Vec<String>,

    pub line_comment: String,
}

impl Language {
    pub fn get_stop_words(&self) -> Vec<String> {
        let mut out = vec![];
        out.push(format!("\n{}", self.line_comment));
        for word in &self.top_level_keywords {
            out.push(format!("\n{}", word));
        }

        for x in DEFAULT.iter() {
            out.push((*x).to_owned());
        }

        out
    }

    pub fn get_hashkey(&self) -> String {
        self.languages[0].clone()
    }
}

lazy_static! {
    static ref CONFIG: ConfigList =
        serdeconv::from_toml_str(include_str!("../assets/languages.toml")).unwrap();
    pub static ref UNKNOWN_LANGUAGE: Language = Language {
        languages: vec!["unknown".to_owned()],
        line_comment: "".to_owned(),
        top_level_keywords: vec![],
    };
}

pub fn get_language(language: &str) -> &'static Language {
    CONFIG
        .config
        .iter()
        .find(|c| c.languages.iter().any(|x| x == language))
        .unwrap_or(&UNKNOWN_LANGUAGE)
}
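For orientation, here is a small hypothetical caller of this module. The asserted values follow from the languages.toml contents above; the snippet itself is a sketch, not part of the PR:

    use tabby_common::languages::get_language;

    fn main() {
        // Known language: resolved from the languages.toml embedded at compile time.
        let rust = get_language("rust");
        assert_eq!(rust.line_comment, "//");
        // Stop words are "\n" + line comment, "\n" + each top-level keyword,
        // plus the indentation-based defaults shared by all languages.
        assert!(rust.get_stop_words().contains(&"\nfn".to_owned()));

        // Unknown language: falls back to UNKNOWN_LANGUAGE instead of panicking.
        let unknown = get_language("cobol");
        assert!(!unknown.get_stop_words().is_empty());
    }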
1 change: 1 addition & 0 deletions crates/tabby-common/src/lib.rs
@@ -1,6 +1,7 @@
 pub mod config;
 pub mod events;
 pub mod index;
+pub mod languages;
 pub mod path;
 pub mod usage;
1 change: 1 addition & 0 deletions crates/tabby-inference/Cargo.toml
@@ -13,3 +13,4 @@ derive_builder = "0.12.0"
 futures = { workspace = true }
 regex.workspace = true
 tokenizers.workspace = true
+tabby-common = { path = "../tabby-common" }
30 changes: 19 additions & 11 deletions crates/tabby-inference/src/decoding.rs
@@ -2,10 +2,11 @@ use std::sync::Arc;

 use dashmap::DashMap;
 use regex::Regex;
+use tabby_common::languages::Language;
 use tokenizers::tokenizer::Tokenizer;

 pub struct DecodingFactory {
-    stop_regex_cache: DashMap<&'static [&'static str], Regex>,
+    stop_regex_cache: DashMap<String, Regex>,
 }

 fn reverse<T>(s: T) -> String
@@ -28,32 +29,34 @@ impl DecodingFactory {
         &self,
         tokenizer: Arc<Tokenizer>,
         input_token_ids: &[u32],
-        stop_words: &'static [&'static str],
+        language: &'static Language,
     ) -> IncrementalDecoding {
-        IncrementalDecoding::new(tokenizer, self.get_re(stop_words), input_token_ids)
+        IncrementalDecoding::new(tokenizer, self.get_re(language), input_token_ids)
     }

-    fn get_re(&self, stop_words: &'static [&'static str]) -> Option<Regex> {
+    fn get_re(&self, language: &'static Language) -> Option<Regex> {
+        let stop_words = language.get_stop_words();
         if stop_words.is_empty() {
             None
         } else {
-            let mut re = self.stop_regex_cache.get(stop_words);
+            let hashkey = language.get_hashkey();
+            let mut re = self.stop_regex_cache.get(&hashkey);
             if re.is_none() {
                 self.stop_regex_cache
-                    .insert(stop_words, create_stop_regex(stop_words));
-                re = self.stop_regex_cache.get(stop_words);
+                    .insert(hashkey.clone(), create_stop_regex(stop_words));
+                re = self.stop_regex_cache.get(&hashkey);
             }
             re.map(|x| x.value().clone())
         }
     }
 }

-fn create_stop_regex(stop_words: &[&str]) -> Regex {
+fn create_stop_regex(stop_words: Vec<String>) -> Regex {
     // (?m) enables multi-line matching mode.
     // \A means absolute begins of string.
     let reversed_stop_words: Vec<_> = stop_words
         .iter()
-        .map(|x| regex::escape(&reverse(*x)))
+        .map(|x| regex::escape(&reverse(x)))
         .collect();
     let regex_string = r"(?m)\A".to_owned() + "((" + &reversed_stop_words.join(")|(") + "))";
     Regex::new(&regex_string).expect("Failed to create regex")
@@ -131,7 +134,12 @@ mod tests {
     #[test]
     fn test_it_works() {
         let text = reverse("void write_u32(std::uint32_t val) const {\n write_raw(&val, sizeof(val));\n }\n\n ~llama_file() {\n if (fp) {\n std::fclose(fp);\n }\n }\n};\n\nvoid");
-        assert!(!create_stop_regex(&["\n\n", "\n\n "]).is_match(&text));
-        assert!(create_stop_regex(&["\n\n", "\n\n ", "\nvoid"]).is_match(&text));
+        assert!(!create_stop_regex(vec!["\n\n".to_owned(), "\n\n ".to_owned()]).is_match(&text));
+        assert!(create_stop_regex(vec![
+            "\n\n".to_owned(),
+            "\n\n ".to_owned(),
+            "\nvoid".to_owned()
+        ])
+        .is_match(&text));
     }
 }
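The reversal trick above deserves a note: the incremental decoder keeps the generated text reversed, so checking whether the output ends with a stop word becomes an anchored match at the start of the reversed string, which is cheap to re-test after every new token. A self-contained sketch of the idea using only the regex crate (function names are illustrative, not this crate's API):

    use regex::Regex;

    fn reverse(s: &str) -> String {
        s.chars().rev().collect()
    }

    // Match any stop word at the *end* of `output` by anchoring a reversed
    // pattern at the *start* of the reversed output.
    fn hit_stop_word(output: &str, stop_words: &[&str]) -> bool {
        let alternatives: Vec<String> = stop_words
            .iter()
            .map(|w| regex::escape(&reverse(w)))
            .collect();
        let pattern = r"(?m)\A((".to_owned() + &alternatives.join(")|(") + "))";
        Regex::new(&pattern).unwrap().is_match(&reverse(output))
    }

    fn main() {
        assert!(hit_stop_word("fn main() {}\n\nfn", &["\nfn"]));
        assert!(!hit_stop_word("fn main() {}", &["\nfn"]));
    }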
7 changes: 3 additions & 4 deletions crates/tabby-inference/src/lib.rs
@@ -3,6 +3,7 @@ pub mod decoding;
 use async_trait::async_trait;
 use derive_builder::Builder;
 use futures::stream::BoxStream;
+use tabby_common::languages::Language;

 #[derive(Builder, Debug)]
 pub struct TextGenerationOptions {
@@ -15,12 +16,10 @@ pub struct TextGenerationOptions {
     #[builder(default = "1.0")]
     pub sampling_temperature: f32,

-    #[builder(default = "&EMPTY_STOP_WORDS")]
-    pub stop_words: &'static [&'static str],
+    #[builder(default = "&tabby_common::languages::UNKNOWN_LANGUAGE")]
+    pub language: &'static Language,
 }

-static EMPTY_STOP_WORDS: Vec<&'static str> = vec![];
-
 #[async_trait]
 pub trait TextGeneration: Sync + Send {
     async fn generate(&self, prompt: &str, options: TextGenerationOptions) -> String;
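Because the builder supplies a default, call sites that never set language-specific stop words keep compiling unchanged. A hypothetical caller, with field values chosen to mirror completions.rs below rather than taken from this hunk:

    use tabby_inference::TextGenerationOptionsBuilder;

    fn main() {
        // `language` is omitted, so it defaults to UNKNOWN_LANGUAGE, whose
        // stop words are just "\n" plus the indentation-based defaults.
        let options = TextGenerationOptionsBuilder::default()
            .max_input_length(1024)
            .max_decoding_length(128)
            .sampling_temperature(0.1)
            .build()
            .unwrap();
        assert!(!options.language.get_stop_words().is_empty());
    }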
6 changes: 2 additions & 4 deletions crates/tabby/src/serve/completions.rs
@@ -1,17 +1,15 @@
-mod languages;
 mod prompt;

 use std::sync::Arc;

 use axum::{extract::State, Json};
 use hyper::StatusCode;
 use serde::{Deserialize, Serialize};
-use tabby_common::events;
+use tabby_common::{events, languages::get_language};
 use tabby_inference::{TextGeneration, TextGenerationOptionsBuilder};
 use tracing::{debug, instrument};
 use utoipa::ToSchema;

-use self::languages::get_language;
 use super::search::IndexServer;

 #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
@@ -112,7 +110,7 @@ pub async fn completions(
         .max_input_length(1024 + 512)
         .max_decoding_length(128)
         .sampling_temperature(0.1)
-        .stop_words(get_language(&language).stop_words)
+        .language(get_language(&language))
         .build()
         .unwrap();