Add tokenize & detokenize to client, fix typos
- implemented client code for the `/tokenize` & `/detokenize` endpoints
- added docstring examples
andreaskoepf authored and pacman82 committed Nov 30, 2023

1 parent ececcef commit 65d764b
Showing 8 changed files with 296 additions and 17 deletions.
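
For a quick look at what this commit adds, here is a minimal round-trip sketch based on the docstring examples in the diff below: tokenize a prompt, then detokenize the resulting ids. Reading the API token from an `AA_API_TOKEN` environment variable is an assumption made for the sketch, not something the commit prescribes.

```rust
use aleph_alpha_client::{Client, Error, How, TaskDetokenization, TaskTokenization};

async fn tokenize_roundtrip() -> Result<(), Error> {
    // Assumption: the API token is supplied via the AA_API_TOKEN environment variable.
    let api_token = std::env::var("AA_API_TOKEN").expect("AA_API_TOKEN must be set");
    let client = Client::new(&api_token)?;
    let model = "luminous-base";

    // Tokenize a text prompt into text tokens and numeric token ids.
    let task = TaskTokenization {
        prompt: "An apple a day",
        tokens: true,
        token_ids: true,
    };
    let tokenized = client.tokenize(&task, model, &How::default()).await?;
    let token_ids = tokenized.token_ids.expect("token ids were requested");

    // Detokenize the ids back into text.
    let task = TaskDetokenization {
        token_ids: &token_ids,
    };
    let detokenized = client.detokenize(&task, model, &How::default()).await?;
    dbg!(&detokenized.result);
    Ok(())
}
```
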
6 changes: 3 additions & 3 deletions src/completion.rs
@@ -76,7 +76,7 @@ pub struct Stopping<'a> {
/// List of strings which will stop generation if they are generated. Stop sequences are
/// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
/// lines starting with either "Question: " or "Answer: " (alternating). After producing an
/// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used
/// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
/// as stop sequence in order not to have the model generate more questions but rather restrict
/// text generation to the answers.
pub stop_sequences: &'a [&'a str],
@@ -95,7 +95,7 @@ impl<'a> Stopping<'a> {
/// Body send to the Aleph Alpha API on the POST `/completion` Route
#[derive(Serialize, Debug)]
struct BodyCompletion<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminus-base`.
/// Name of the model tasked with completing the prompt. E.g. `luminous-base"`.
pub model: &'a str,
/// Prompt to complete. The modalities supported depend on `model`.
pub prompt: Prompt<'a>,
@@ -104,7 +104,7 @@ struct BodyCompletion<'a> {
/// List of strings which will stop generation if they are generated. Stop sequences are
/// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
/// lines starting with either "Question: " or "Answer: " (alternating). After producing an
/// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used
/// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
/// as stop sequence in order not to have the model generate more questions but rather restrict
/// text generation to the answers.
#[serde(skip_serializing_if = "<[_]>::is_empty")]
57 changes: 57 additions & 0 deletions src/detokenization.rs
@@ -0,0 +1,57 @@
use crate::Task;
use serde::{Deserialize, Serialize};

/// Input for a [crate::Client::detokenize] request.
pub struct TaskDetokenization<'a> {
/// List of token ids which should be detokenized into text.
pub token_ids: &'a [u32],
}

/// Body send to the Aleph Alpha API on the POST `/detokenize` route
#[derive(Serialize, Debug)]
struct BodyDetokenization<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base"`.
pub model: &'a str,
/// List of ids to detokenize.
pub token_ids: &'a [u32],
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseDetokenization {
pub result: String,
}

#[derive(Debug, PartialEq, Eq)]
pub struct DetokenizationOutput {
pub result: String,
}

impl From<ResponseDetokenization> for DetokenizationOutput {
fn from(response: ResponseDetokenization) -> Self {
Self {
result: response.result,
}
}
}

impl<'a> Task for TaskDetokenization<'a> {
type Output = DetokenizationOutput;
type ResponseBody = ResponseDetokenization;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = BodyDetokenization {
model,
token_ids: &self.token_ids,
};
client.post(format!("{base}/detokenize")).json(&body)
}

fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
DetokenizationOutput::from(response)
}
}
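
For orientation only (not part of the commit): given the two fields of `BodyDetokenization` above, the serialized POST `/detokenize` payload should look roughly as follows. The `serde_json` crate is pulled in here purely to illustrate the shape.

```rust
use serde_json::json;

fn main() {
    // Sketch of the request body implied by BodyDetokenization's `model` and `token_ids` fields.
    let body = json!({
        "model": "luminous-base",
        "token_ids": [556, 48741, 247, 2983],
    });
    println!("{body}");
    // => {"model":"luminous-base","token_ids":[556,48741,247,2983]}
}
```
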
2 changes: 1 addition & 1 deletion src/explanation.rs
@@ -9,7 +9,7 @@ pub struct TaskExplanation<'a> {
/// The target string that should be explained. The influence of individual parts
/// of the prompt for generating this target string will be indicated in the response.
pub target: &'a str,
/// Granularity paramaters for the explanation
/// Granularity parameters for the explanation
pub granularity: Granularity,
}

16 changes: 8 additions & 8 deletions src/http.rs
@@ -10,8 +10,8 @@ use crate::How;
/// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
/// executed on. This allows this trait to hold in the presence of services, which use more than one
/// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`]
/// can not implement this trait directly, since its result would depend on what model is choosen to
/// execute it. You can remidy this by turning completion task into a job, calling
/// can not implement this trait directly, since its result would depend on what model is chosen to
/// execute it. You can remedy this by turning completion task into a job, calling
/// [`Task::with_model`].
pub trait Job {
/// Output returned by [`crate::Client::output_of`]
@@ -130,7 +130,7 @@ impl HttpClient {
let query = if how.be_nice {
[("nice", "true")].as_slice()
} else {
// nice=false is default, so we just ommit it.
// nice=false is default, so we just omit it.
[].as_slice()
};
let response = task
@@ -156,7 +156,7 @@ impl HttpClient {
async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
let status = response.status();
if !status.is_success() {
// Store body in a variable, so we can use it, even if it is not an Error emmitted by
// Store body in a variable, so we can use it, even if it is not an Error emitted by
// the API, but an intermediate Proxy like NGinx, so we can still forward the error
// message.
let body = response.text().await?;
@@ -174,14 +174,14 @@ async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Re
}
}

/// We are only interessted in the status codes of the API.
/// We are only interested in the status codes of the API.
#[derive(Deserialize)]
struct ApiError<'a> {
/// Unique string in capital letters emitted by the API to signal different kinds of errors in a
/// finer granualrity then the HTTP status codes alone would allow for.
/// finer granularity then the HTTP status codes alone would allow for.
///
/// E.g. Differentiating between request rate limiting and parallel tasks limiting which both
/// are 429 (the former is emmited by NGinx though).
/// are 429 (the former is emitted by NGinx though).
_code: Cow<'a, str>,
}

@@ -204,7 +204,7 @@ pub enum Error {
Busy,
#[error("No response received within given timeout: {0:?}")]
ClientTimeout(Duration),
/// An error on the Http Protocl level.
/// An error on the Http Protocol level.
#[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
Http { status: u16, body: String },
/// Most likely either TLS errors creating the Client, or IO errors.
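
Since these `Error` variants are what callers of the client ultimately handle, a small illustrative sketch of branching on them may help; the helper name and message strings are made up for the example and are not part of the crate.

```rust
use aleph_alpha_client::Error;

// Hypothetical helper: turn the error variants documented above into user-facing text.
fn describe(error: &Error) -> String {
    match error {
        Error::TooManyRequests => "the API rate-limited our requests; send them more slowly".to_string(),
        Error::Busy => "the API is currently busy with other tasks".to_string(),
        Error::ClientTimeout(timeout) => format!("no response within {timeout:?}"),
        Error::Http { status, body } => format!("HTTP {status}: {body}"),
        other => format!("other error: {other}"),
    }
}
```
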
76 changes: 75 additions & 1 deletion src/lib.rs
@@ -24,11 +24,13 @@
//! ```
mod completion;
mod detokenization;
mod explanation;
mod http;
mod image_preprocessing;
mod prompt;
mod semantic_embedding;
mod tokenization;

use std::time::Duration;

@@ -37,6 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};

pub use self::{
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
detokenization::{DetokenizationOutput, TaskDetokenization},
explanation::{
Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
PromptGranularity, TaskExplanation, TextScore,
@@ -46,6 +49,7 @@ pub use self::{
semantic_embedding::{
SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
},
tokenization::{TaskTokenization, TokenizationOutput},
};

/// Execute Jobs against the Aleph Alpha API
@@ -215,6 +219,76 @@ impl Client {
.output_of(&task.with_model(model), how)
.await
}

/// Tokenize a prompt for a specific model.
///
/// ```no_run
/// use aleph_alpha_client::{Client, Error, How, TaskTokenization};
///
/// async fn tokenize() -> Result<(), Error> {
/// let client = Client::new(AA_API_TOKEN)?;
///
/// // Name of the model for which we want to tokenize text.
/// let model = "luminous-base";
///
/// // Text prompt to be tokenized.
/// let prompt = "An apple a day";
///
/// let task = TaskTokenization {
/// prompt,
/// tokens: true, // return text-tokens
/// token_ids: true, // return numeric token-ids
/// };
/// let response = client.tokenize(&task, model, &How::default()).await?;
///
/// dbg!(&response);
/// Ok(())
/// }
/// ```
pub async fn tokenize(
&self,
task: &TaskTokenization<'_>,
model: &str,
how: &How,
) -> Result<TokenizationOutput, Error> {
self.http_client
.output_of(&task.with_model(model), how)
.await
}

/// Detokenize a list of token ids into a string.
///
/// ```no_run
/// use aleph_alpha_client::{Client, Error, How, TaskDetokenization};
///
/// async fn detokenize() -> Result<(), Error> {
/// let client = Client::new(AA_API_TOKEN)?;
///
/// // Specify the name of the model whose tokenizer was used to generate the input token ids.
/// let model = "luminous-base";
///
/// // Token ids to convert into text.
/// let token_ids: Vec<u32> = vec![556, 48741, 247, 2983];
///
/// let task = TaskDetokenization {
/// token_ids: &token_ids,
/// };
/// let response = client.detokenize(&task, model, &How::default()).await?;
///
/// dbg!(&response);
/// Ok(())
/// }
/// ```
pub async fn detokenize(
&self,
task: &TaskDetokenization<'_>,
model: &str,
how: &How,
) -> Result<DetokenizationOutput, Error> {
self.http_client
.output_of(&task.with_model(model), how)
.await
}
}

/// Controls of how to execute a task
@@ -254,7 +328,7 @@ impl Default for How {
/// Client, Prompt, TaskSemanticEmbedding, cosine_similarity, SemanticRepresentation, How
/// };
///
/// async fn semanitc_search_with_luminous_base(client: &Client) {
/// async fn semantic_search_with_luminous_base(client: &Client) {
/// // Given
/// let robot_fact = Prompt::from_text(
/// "A robot is a machine—especially one programmable by a computer—capable of carrying out a \
91 changes: 91 additions & 0 deletions src/tokenization.rs
@@ -0,0 +1,91 @@
use crate::Task;
use serde::{Deserialize, Serialize};

/// Input for a [crate::Client::tokenize] request.
pub struct TaskTokenization<'a> {
/// The text prompt which should be converted into tokens
pub prompt: &'a str,

/// Specify `true` to return text-tokens.
pub tokens: bool,

/// Specify `true` to return numeric token-ids.
pub token_ids: bool,
}

impl<'a> From<&'a str> for TaskTokenization<'a> {
fn from(prompt: &'a str) -> TaskTokenization {
TaskTokenization {
prompt,
tokens: true,
token_ids: true,
}
}
}

impl TaskTokenization<'_> {
pub fn new(prompt: &str, tokens: bool, token_ids: bool) -> TaskTokenization {
TaskTokenization {
prompt,
tokens,
token_ids,
}
}
}

#[derive(Serialize, Debug)]
struct BodyTokenization<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// String to tokenize.
pub prompt: &'a str,
/// Set this value to `true` to return text-tokens.
pub tokens: bool,
/// Set this value to `true` to return numeric token-ids.
pub token_ids: bool,
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ResponseTokenization {
pub tokens: Option<Vec<String>>,
pub token_ids: Option<Vec<u32>>,
}

#[derive(Debug, PartialEq)]
pub struct TokenizationOutput {
pub tokens: Option<Vec<String>>,
pub token_ids: Option<Vec<u32>>,
}

impl From<ResponseTokenization> for TokenizationOutput {
fn from(response: ResponseTokenization) -> Self {
Self {
tokens: response.tokens,
token_ids: response.token_ids,
}
}
}

impl Task for TaskTokenization<'_> {
type Output = TokenizationOutput;
type ResponseBody = ResponseTokenization;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = BodyTokenization {
model,
prompt: &self.prompt,
tokens: self.tokens,
token_ids: self.token_ids,
};
client.post(format!("{base}/tokenize")).json(&body)
}

fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
TokenizationOutput::from(response)
}
}
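
A brief usage sketch of the two construction paths shown above, explicit `new()` versus the `From<&str>` conversion; the assertions only illustrate the defaults.

```rust
use aleph_alpha_client::TaskTokenization;

fn main() {
    // Explicit construction: ask only for numeric token ids, not text tokens.
    let ids_only = TaskTokenization::new("Hello, World!", false, true);
    assert!(!ids_only.tokens && ids_only.token_ids);

    // Conversion from &str: both text tokens and token ids are requested.
    let both = TaskTokenization::from("Hello, World!");
    assert!(both.tokens && both.token_ids);
}
```
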
59 changes: 58 additions & 1 deletion tests/integration.rs
@@ -3,7 +3,8 @@ use std::{fs::File, io::BufReader};
use aleph_alpha_client::{
cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
TaskBatchSemanticEmbedding, TaskCompletion, TaskExplanation, TaskSemanticEmbedding, TextScore,
TaskBatchSemanticEmbedding, TaskCompletion, TaskDetokenization, TaskExplanation,
TaskSemanticEmbedding, TaskTokenization, TextScore,
};
use dotenv::dotenv;
use image::ImageFormat;
@@ -376,3 +377,59 @@ async fn batch_semanitc_embed_with_luminous_base() {
// There should be 2 embeddings
assert_eq!(embeddings.len(), 2);
}

#[tokio::test]
async fn tokenization_with_luminous_base() {
// Given
let input = "Hello, World!";

let client = Client::new(&AA_API_TOKEN).unwrap();

// When
let task1 = TaskTokenization::new(input, false, true);
let task2 = TaskTokenization::new(input, true, false);

let response1 = client
.tokenize(&task1, "luminous-base", &How::default())
.await
.unwrap();

let response2 = client
.tokenize(&task2, "luminous-base", &How::default())
.await
.unwrap();

// Then
assert_eq!(response1.tokens, None);
assert_eq!(response1.token_ids, Some(vec![49222, 15, 5390, 4]));

assert_eq!(response2.token_ids, None);
assert_eq!(
response2.tokens,
Some(
vec!["ĠHello", ",", "ĠWorld", "!"]
.into_iter()
.map(str::to_owned)
.collect()
)
);
}

#[tokio::test]
async fn detokenization_with_luminous_base() {
// Given
let input = vec![49222, 15, 5390, 4];

let client = Client::new(&AA_API_TOKEN).unwrap();

// When
let task = TaskDetokenization { token_ids: &input };

let response = client
.detokenize(&task, "luminous-base", &How::default())
.await
.unwrap();

// Then
assert!(response.result.contains("Hello, World!"));
}
6 changes: 3 additions & 3 deletions tests/unit.rs
@@ -45,10 +45,10 @@ async fn completion_with_luminous_base() {
assert_eq!("\n", actual)
}

/// If we open too many requests at once, we may trigger rate limmiting. We want this scenario to be
/// If we open too many requests at once, we may trigger rate limiting. We want this scenario to be
/// easily detectible by the user, so he/she/it can start sending requests slower.
#[tokio::test]
async fn detect_rate_limmiting() {
async fn detect_rate_limiting() {
// Given

// Start a background HTTP server on a random local part
@@ -84,7 +84,7 @@ async fn detect_rate_limmiting() {
assert!(matches!(error, Error::TooManyRequests));
}

/// Even if we do not open too many requests at once ourselfes, the API may just be busy. We also
/// Even if we do not open too many requests at once ourselves, the API may just be busy. We also
/// want this scenario to be easily detectable by users.
#[tokio::test]
async fn detect_queue_full() {
