diff --git a/src/completion.rs b/src/completion.rs index 600d6d4..3b68ac2 100644 --- a/src/completion.rs +++ b/src/completion.rs @@ -76,7 +76,7 @@ pub struct Stopping<'a> { /// List of strings which will stop generation if they are generated. Stop sequences are /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of /// lines starting with either "Question: " or "Answer: " (alternating). After producing an - /// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used + /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used /// as stop sequence in order not to have the model generate more questions but rather restrict /// text generation to the answers. pub stop_sequences: &'a [&'a str], @@ -95,7 +95,7 @@ impl<'a> Stopping<'a> { /// Body send to the Aleph Alpha API on the POST `/completion` Route #[derive(Serialize, Debug)] struct BodyCompletion<'a> { - /// Name of the model tasked with completing the prompt. E.g. `luminus-base`. + /// Name of the model tasked with completing the prompt. E.g. `luminous-base"`. pub model: &'a str, /// Prompt to complete. The modalities supported depend on `model`. pub prompt: Prompt<'a>, @@ -104,7 +104,7 @@ struct BodyCompletion<'a> { /// List of strings which will stop generation if they are generated. Stop sequences are /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of /// lines starting with either "Question: " or "Answer: " (alternating). After producing an - /// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used + /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used /// as stop sequence in order not to have the model generate more questions but rather restrict /// text generation to the answers. #[serde(skip_serializing_if = "<[_]>::is_empty")] diff --git a/src/detokenization.rs b/src/detokenization.rs new file mode 100644 index 0000000..5e3e64d --- /dev/null +++ b/src/detokenization.rs @@ -0,0 +1,57 @@ +use crate::Task; +use serde::{Deserialize, Serialize}; + +/// Input for a [crate::Client::detokenize] request. +pub struct TaskDetokenization<'a> { + /// List of token ids which should be detokenized into text. + pub token_ids: &'a [u32], +} + +/// Body send to the Aleph Alpha API on the POST `/detokenize` route +#[derive(Serialize, Debug)] +struct BodyDetokenization<'a> { + /// Name of the model tasked with completing the prompt. E.g. `luminous-base"`. + pub model: &'a str, + /// List of ids to detokenize. + pub token_ids: &'a [u32], +} + +#[derive(Deserialize, Debug, PartialEq, Eq)] +pub struct ResponseDetokenization { + pub result: String, +} + +#[derive(Debug, PartialEq, Eq)] +pub struct DetokenizationOutput { + pub result: String, +} + +impl From for DetokenizationOutput { + fn from(response: ResponseDetokenization) -> Self { + Self { + result: response.result, + } + } +} + +impl<'a> Task for TaskDetokenization<'a> { + type Output = DetokenizationOutput; + type ResponseBody = ResponseDetokenization; + + fn build_request( + &self, + client: &reqwest::Client, + base: &str, + model: &str, + ) -> reqwest::RequestBuilder { + let body = BodyDetokenization { + model, + token_ids: &self.token_ids, + }; + client.post(format!("{base}/detokenize")).json(&body) + } + + fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output { + DetokenizationOutput::from(response) + } +} diff --git a/src/explanation.rs b/src/explanation.rs index 3779455..7720a10 100644 --- a/src/explanation.rs +++ b/src/explanation.rs @@ -9,7 +9,7 @@ pub struct TaskExplanation<'a> { /// The target string that should be explained. The influence of individual parts /// of the prompt for generating this target string will be indicated in the response. pub target: &'a str, - /// Granularity paramaters for the explanation + /// Granularity parameters for the explanation pub granularity: Granularity, } diff --git a/src/http.rs b/src/http.rs index f8edfbd..1f91928 100644 --- a/src/http.rs +++ b/src/http.rs @@ -10,8 +10,8 @@ use crate::How; /// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is /// executed on. This allows this trait to hold in the presence of services, which use more than one /// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`] -/// can not implement this trait directly, since its result would depend on what model is choosen to -/// execute it. You can remidy this by turning completion task into a job, calling +/// can not implement this trait directly, since its result would depend on what model is chosen to +/// execute it. You can remedy this by turning completion task into a job, calling /// [`Task::with_model`]. pub trait Job { /// Output returned by [`crate::Client::output_of`] @@ -130,7 +130,7 @@ impl HttpClient { let query = if how.be_nice { [("nice", "true")].as_slice() } else { - // nice=false is default, so we just ommit it. + // nice=false is default, so we just omit it. [].as_slice() }; let response = task @@ -156,7 +156,7 @@ impl HttpClient { async fn translate_http_error(response: reqwest::Response) -> Result { let status = response.status(); if !status.is_success() { - // Store body in a variable, so we can use it, even if it is not an Error emmitted by + // Store body in a variable, so we can use it, even if it is not an Error emitted by // the API, but an intermediate Proxy like NGinx, so we can still forward the error // message. let body = response.text().await?; @@ -174,14 +174,14 @@ async fn translate_http_error(response: reqwest::Response) -> Result { /// Unique string in capital letters emitted by the API to signal different kinds of errors in a - /// finer granualrity then the HTTP status codes alone would allow for. + /// finer granularity then the HTTP status codes alone would allow for. /// /// E.g. Differentiating between request rate limiting and parallel tasks limiting which both - /// are 429 (the former is emmited by NGinx though). + /// are 429 (the former is emitted by NGinx though). _code: Cow<'a, str>, } @@ -204,7 +204,7 @@ pub enum Error { Busy, #[error("No response received within given timeout: {0:?}")] ClientTimeout(Duration), - /// An error on the Http Protocl level. + /// An error on the Http Protocol level. #[error("HTTP request failed with status code {}. Body:\n{}", status, body)] Http { status: u16, body: String }, /// Most likely either TLS errors creating the Client, or IO errors. diff --git a/src/lib.rs b/src/lib.rs index c42927a..fe5d8d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,11 +24,13 @@ //! ``` mod completion; +mod detokenization; mod explanation; mod http; mod image_preprocessing; mod prompt; mod semantic_embedding; +mod tokenization; use std::time::Duration; @@ -37,6 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput}; pub use self::{ completion::{CompletionOutput, Sampling, Stopping, TaskCompletion}, + detokenization::{DetokenizationOutput, TaskDetokenization}, explanation::{ Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation, PromptGranularity, TaskExplanation, TextScore, @@ -46,6 +49,7 @@ pub use self::{ semantic_embedding::{ SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding, }, + tokenization::{TaskTokenization, TokenizationOutput}, }; /// Execute Jobs against the Aleph Alpha API @@ -215,6 +219,76 @@ impl Client { .output_of(&task.with_model(model), how) .await } + + /// Tokenize a prompt for a specific model. + /// + /// ```no_run + /// use aleph_alpha_client::{Client, Error, How, TaskTokenization}; + /// + /// async fn tokenize() -> Result<(), Error> { + /// let client = Client::new(AA_API_TOKEN)?; + /// + /// // Name of the model for which we want to tokenize text. + /// let model = "luminous-base"; + /// + /// // Text prompt to be tokenized. + /// let prompt = "An apple a day"; + /// + /// let task = TaskTokenization { + /// prompt, + /// tokens: true, // return text-tokens + /// token_ids: true, // return numeric token-ids + /// }; + /// let respones = client.tokenize(&task, model, &How::default()).await?; + /// + /// dbg!(&respones); + /// Ok(()) + /// } + /// ``` + pub async fn tokenize( + &self, + task: &TaskTokenization<'_>, + model: &str, + how: &How, + ) -> Result { + self.http_client + .output_of(&task.with_model(model), how) + .await + } + + /// Detokenize a list of token ids into a string. + /// + /// ```no_run + /// use aleph_alpha_client::{Client, Error, How, TaskDetokenization}; + /// + /// async fn detokenize() -> Result<(), Error> { + /// let client = Client::new(AA_API_TOKEN)?; + /// + /// // Specify the name of the model whose tokenizer was used to generate the input token ids. + /// let model = "luminous-base"; + /// + /// // Token ids to convert into text. + /// let token_ids: Vec = vec![556, 48741, 247, 2983]; + /// + /// let task = TaskDetokenization { + /// token_ids: &token_ids, + /// }; + /// let respones = client.detokenize(&task, model, &How::default()).await?; + /// + /// dbg!(&respones); + /// Ok(()) + /// } + /// ``` + pub async fn detokenize( + &self, + task: &TaskDetokenization<'_>, + model: &str, + how: &How, + ) -> Result { + self.http_client + .output_of(&task.with_model(model), how) + .await + } } /// Controls of how to execute a task @@ -254,7 +328,7 @@ impl Default for How { /// Client, Prompt, TaskSemanticEmbedding, cosine_similarity, SemanticRepresentation, How /// }; /// -/// async fn semanitc_search_with_luminous_base(client: &Client) { +/// async fn semantic_search_with_luminous_base(client: &Client) { /// // Given /// let robot_fact = Prompt::from_text( /// "A robot is a machine—especially one programmable by a computer—capable of carrying out a \ diff --git a/src/tokenization.rs b/src/tokenization.rs new file mode 100644 index 0000000..e4fc2a2 --- /dev/null +++ b/src/tokenization.rs @@ -0,0 +1,91 @@ +use crate::Task; +use serde::{Deserialize, Serialize}; + +/// Input for a [crate::Client::tokenize] request. +pub struct TaskTokenization<'a> { + /// The text prompt which should be converted into tokens + pub prompt: &'a str, + + /// Specify `true` to return text-tokens. + pub tokens: bool, + + /// Specify `true` to return numeric token-ids. + pub token_ids: bool, +} + +impl<'a> From<&'a str> for TaskTokenization<'a> { + fn from(prompt: &'a str) -> TaskTokenization { + TaskTokenization { + prompt, + tokens: true, + token_ids: true, + } + } +} + +impl TaskTokenization<'_> { + pub fn new(prompt: &str, tokens: bool, token_ids: bool) -> TaskTokenization { + TaskTokenization { + prompt, + tokens, + token_ids, + } + } +} + +#[derive(Serialize, Debug)] +struct BodyTokenization<'a> { + /// Name of the model tasked with completing the prompt. E.g. `luminous-base`. + pub model: &'a str, + /// String to tokenize. + pub prompt: &'a str, + /// Set this value to `true` to return text-tokens. + pub tokens: bool, + /// Set this value to `true` to return numeric token-ids. + pub token_ids: bool, +} + +#[derive(Deserialize, Debug, PartialEq, Eq)] +pub struct ResponseTokenization { + pub tokens: Option>, + pub token_ids: Option>, +} + +#[derive(Debug, PartialEq)] +pub struct TokenizationOutput { + pub tokens: Option>, + pub token_ids: Option>, +} + +impl From for TokenizationOutput { + fn from(response: ResponseTokenization) -> Self { + Self { + tokens: response.tokens, + token_ids: response.token_ids, + } + } +} + +impl Task for TaskTokenization<'_> { + type Output = TokenizationOutput; + type ResponseBody = ResponseTokenization; + + fn build_request( + &self, + client: &reqwest::Client, + base: &str, + model: &str, + ) -> reqwest::RequestBuilder { + let body = BodyTokenization { + model, + prompt: &self.prompt, + tokens: self.tokens, + token_ids: self.token_ids, + }; + client.post(format!("{base}/tokenize")).json(&body) + } + + fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output { + TokenizationOutput::from(response) + } +} diff --git a/tests/integration.rs b/tests/integration.rs index 9972bec..64c5fa7 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -3,7 +3,8 @@ use std::{fs::File, io::BufReader}; use aleph_alpha_client::{ cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task, - TaskBatchSemanticEmbedding, TaskCompletion, TaskExplanation, TaskSemanticEmbedding, TextScore, + TaskBatchSemanticEmbedding, TaskCompletion, TaskDetokenization, TaskExplanation, + TaskSemanticEmbedding, TaskTokenization, TextScore, }; use dotenv::dotenv; use image::ImageFormat; @@ -376,3 +377,59 @@ async fn batch_semanitc_embed_with_luminous_base() { // There should be 2 embeddings assert_eq!(embeddings.len(), 2); } + +#[tokio::test] +async fn tokenization_with_luminous_base() { + // Given + let input = "Hello, World!"; + + let client = Client::new(&AA_API_TOKEN).unwrap(); + + // When + let task1 = TaskTokenization::new(input, false, true); + let task2 = TaskTokenization::new(input, true, false); + + let response1 = client + .tokenize(&task1, "luminous-base", &How::default()) + .await + .unwrap(); + + let response2 = client + .tokenize(&task2, "luminous-base", &How::default()) + .await + .unwrap(); + + // Then + assert_eq!(response1.tokens, None); + assert_eq!(response1.token_ids, Some(vec![49222, 15, 5390, 4])); + + assert_eq!(response2.token_ids, None); + assert_eq!( + response2.tokens, + Some( + vec!["ĠHello", ",", "ĠWorld", "!"] + .into_iter() + .map(str::to_owned) + .collect() + ) + ); +} + +#[tokio::test] +async fn detokenization_with_luminous_base() { + // Given + let input = vec![49222, 15, 5390, 4]; + + let client = Client::new(&AA_API_TOKEN).unwrap(); + + // When + let task = TaskDetokenization { token_ids: &input }; + + let response = client + .detokenize(&task, "luminous-base", &How::default()) + .await + .unwrap(); + + // Then + assert!(response.result.contains("Hello, World!")); +} diff --git a/tests/unit.rs b/tests/unit.rs index 8fa8f86..28a9725 100644 --- a/tests/unit.rs +++ b/tests/unit.rs @@ -45,10 +45,10 @@ async fn completion_with_luminous_base() { assert_eq!("\n", actual) } -/// If we open too many requests at once, we may trigger rate limmiting. We want this scenario to be +/// If we open too many requests at once, we may trigger rate limiting. We want this scenario to be /// easily detectible by the user, so he/she/it can start sending requests slower. #[tokio::test] -async fn detect_rate_limmiting() { +async fn detect_rate_limiting() { // Given // Start a background HTTP server on a random local part @@ -84,7 +84,7 @@ async fn detect_rate_limmiting() { assert!(matches!(error, Error::TooManyRequests)); } -/// Even if we do not open too many requests at once ourselfes, the API may just be busy. We also +/// Even if we do not open too many requests at once ourselves, the API may just be busy. We also /// want this scenario to be easily detectable by users. #[tokio::test] async fn detect_queue_full() {