Skip to content

Commit

Permalink
Support fetching tokenizers
Browse files Browse the repository at this point in the history
  • Loading branch information
pacman82 committed Sep 5, 2024
1 parent e5e56e2 commit 75e0b76
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 1 deletion.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "aleph-alpha-client"
version = "0.11.0"
version = "0.12.0"
edition = "2021"
description = "Interact with large language models provided by the Aleph Alpha API in Rust code"
license = "MIT"
Expand All @@ -18,6 +18,7 @@ reqwest = { version = "0.12.3", features = ["json"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
thiserror = "1.0.58"
tokenizers = { version = "0.20.0", default-features = false, features = ["onig", "esaxx_fast"] }

[dev-dependencies]
dotenv = "0.15.0"
Expand Down
4 changes: 4 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

## 0.12.0

* Add `Client::tokenizer_by_model` to fetch the Tokenizer for a given model name

## 0.11.0

* Add `with_maximum_tokens` method to `Prompt`
Expand Down
20 changes: 20 additions & 0 deletions src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::{borrow::Cow, time::Duration};
use reqwest::{header, ClientBuilder, RequestBuilder, StatusCode};
use serde::Deserialize;
use thiserror::Error as ThisError;
use tokenizers::Tokenizer;

use crate::How;

Expand Down Expand Up @@ -163,6 +164,21 @@ impl HttpClient {
auth_value.set_sensitive(true);
auth_value
}

pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String> ) -> Result<Tokenizer, Error> {
let api_token = api_token
.as_ref()
.or(self.api_token.as_ref())
.expect("API token needs to be set on client construction or per request");
let response = self.http.get(format!("{}/models/{model}/tokenizer", self.base))
.header(header::AUTHORIZATION, Self::header_from_token(api_token)).send().await?;
let response = translate_http_error(response).await?;
let bytes = response.bytes().await?;
let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
Error::InvalidTokenizer { deserialization_error: e.to_string() }
})?;
Ok(tokenizer)
}
}

async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
Expand Down Expand Up @@ -235,6 +251,10 @@ pub enum Error {
/// An error on the Http Protocol level.
#[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
Http { status: u16, body: String },
#[error("Tokenizer could not be correctly deserialized. Caused by:\n{}", deserialization_error)]
InvalidTokenizer {
deserialization_error: String,
},
/// Most likely either TLS errors creating the Client, or IO errors.
#[error(transparent)]
Other(#[from] reqwest::Error),
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use std::time::Duration;

use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use tokenizers::Tokenizer;

pub use self::{
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
Expand Down Expand Up @@ -303,6 +304,10 @@ impl Client {
.output_of(&task.with_model(model), how)
.await
}

/// Fetch the [`Tokenizer`] used by the given model from the API.
///
/// Thin wrapper delegating to [`HttpClient::tokenizer_by_model`].
///
/// * `model` - Name of the model whose tokenizer should be fetched.
/// * `api_token` - Optional per-request token; when `None`, the token the client was
///   constructed with is used. The underlying HTTP client panics if neither is set.
pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> {
    self.http_client.tokenizer_by_model(model, api_token).await
}
}

/// Controls of how to execute a task
Expand Down
15 changes: 15 additions & 0 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,18 @@ async fn detokenization_with_luminous_base() {
// Then
assert!(response.result.contains("Hello, World!"));
}

#[tokio::test]
async fn fetch_tokenizer_for_pharia_1_llm_7b() {
    // Given a client authenticated against the live API
    let client = Client::with_authentication(api_token()).unwrap();

    // When the tokenizer for the model is fetched (no per-request token override)
    let model_name = "Pharia-1-LLM-7B-control";
    let tokenizer = client.tokenizer_by_model(model_name, None).await.unwrap();

    // Then its vocabulary has the expected size (including added tokens)
    let vocab_size = tokenizer.get_vocab_size(true);
    assert_eq!(128_000, vocab_size);
}

0 comments on commit 75e0b76

Please sign in to comment.