From 9dd95579b90edb218c736f4ab4e994293da97d5b Mon Sep 17 00:00:00 2001
From: Andreas Koepf
Date: Thu, 30 Nov 2023 11:46:06 +0100
Subject: [PATCH] Clean up the de-/tokenization impls

---
 src/detokenization.rs | 25 ++++++++++++++++---------
 src/lib.rs            |  2 +-
 src/tokenization.rs   | 36 ++++++++++++++++++------------------
 tests/integration.rs  |  4 ++--
 4 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/src/detokenization.rs b/src/detokenization.rs
index b1f4cd8..5e3e64d 100644
--- a/src/detokenization.rs
+++ b/src/detokenization.rs
@@ -4,19 +4,20 @@ use serde::{Deserialize, Serialize};
 /// Input for a [crate::Client::detokenize] request.
 pub struct TaskDetokenization<'a> {
     /// List of token ids which should be detokenized into text.
-    pub token_ids: &'a Vec<u32>,
+    pub token_ids: &'a [u32],
 }
 
+/// Body sent to the Aleph Alpha API on the POST `/detokenize` route
 #[derive(Serialize, Debug)]
-struct DetokenizationRequest<'a> {
+struct BodyDetokenization<'a> {
     /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
     pub model: &'a str,
     /// List of ids to detokenize.
-    pub token_ids: &'a Vec<u32>,
+    pub token_ids: &'a [u32],
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
-pub struct DetokenizationResponse {
+pub struct ResponseDetokenization {
     pub result: String,
 }
 
@@ -25,9 +26,17 @@ pub struct DetokenizationOutput {
     pub result: String,
 }
 
+impl From<ResponseDetokenization> for DetokenizationOutput {
+    fn from(response: ResponseDetokenization) -> Self {
+        Self {
+            result: response.result,
+        }
+    }
+}
+
 impl<'a> Task for TaskDetokenization<'a> {
     type Output = DetokenizationOutput;
-    type ResponseBody = DetokenizationResponse;
+    type ResponseBody = ResponseDetokenization;
 
     fn build_request(
         &self,
@@ -35,7 +44,7 @@ impl<'a> Task for TaskDetokenization<'a> {
         base: &str,
         model: &str,
     ) -> reqwest::RequestBuilder {
-        let body = DetokenizationRequest {
+        let body = BodyDetokenization {
             model,
             token_ids: &self.token_ids,
         };
@@ -43,8 +52,6 @@ impl<'a> Task for TaskDetokenization<'a> {
     }
 
     fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
-        DetokenizationOutput {
-            result: response.result,
-        }
+        DetokenizationOutput::from(response)
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 82601a2..6607391 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -222,7 +222,7 @@ impl Client {
 
     pub async fn tokenize(
         &self,
-        task: &TaskTokenization,
+        task: &TaskTokenization<'_>,
         model: &str,
         how: &How,
     ) -> Result<TokenizationOutput, Error> {
diff --git a/src/tokenization.rs b/src/tokenization.rs
index cf0c069..e4fc2a2 100644
--- a/src/tokenization.rs
+++ b/src/tokenization.rs
@@ -2,31 +2,31 @@ use crate::Task;
 use serde::{Deserialize, Serialize};
 
 /// Input for a [crate::Client::tokenize] request.
-pub struct TaskTokenization {
+pub struct TaskTokenization<'a> {
     /// The text prompt which should be converted into tokens
-    pub prompt: String,
+    pub prompt: &'a str,
 
     /// Specify `true` to return text-tokens.
     pub tokens: bool,
 
-    /// Specify `true to return numeric token-ids.
+    /// Specify `true` to return numeric token-ids.
     pub token_ids: bool,
 }
 
-impl From<&str> for TaskTokenization {
-    fn from(s: &str) -> TaskTokenization {
+impl<'a> From<&'a str> for TaskTokenization<'a> {
+    fn from(prompt: &'a str) -> TaskTokenization {
         TaskTokenization {
-            prompt: s.into(),
+            prompt,
             tokens: true,
             token_ids: true,
         }
     }
 }
 
-impl TaskTokenization {
-    pub fn create(prompt: String, tokens: bool, token_ids: bool) -> TaskTokenization {
+impl TaskTokenization<'_> {
+    pub fn new(prompt: &str, tokens: bool, token_ids: bool) -> TaskTokenization {
         TaskTokenization {
-            prompt: prompt,
+            prompt,
             tokens,
             token_ids,
         }
@@ -34,19 +34,19 @@ impl TaskTokenization {
 }
 
 #[derive(Serialize, Debug)]
-struct TokenizationRequest<'a> {
+struct BodyTokenization<'a> {
     /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
     pub model: &'a str,
     /// String to tokenize.
-    pub prompt: &'a String,
+    pub prompt: &'a str,
     /// Set this value to `true` to return text-tokens.
     pub tokens: bool,
-    /// Set this value to `true to return numeric token-ids.
+    /// Set this value to `true` to return numeric token-ids.
     pub token_ids: bool,
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
-pub struct TokenizationResponse {
+pub struct ResponseTokenization {
     pub tokens: Option<Vec<String>>,
     pub token_ids: Option<Vec<u32>>,
 }
@@ -57,8 +57,8 @@ pub struct TokenizationOutput {
     pub token_ids: Option<Vec<u32>>,
 }
 
-impl From<TokenizationResponse> for TokenizationOutput {
-    fn from(response: TokenizationResponse) -> Self {
+impl From<ResponseTokenization> for TokenizationOutput {
+    fn from(response: ResponseTokenization) -> Self {
         Self {
             tokens: response.tokens,
             token_ids: response.token_ids,
@@ -66,9 +66,9 @@ impl From<TokenizationResponse> for TokenizationOutput {
         }
     }
 }
 
-impl Task for TaskTokenization {
+impl Task for TaskTokenization<'_> {
     type Output = TokenizationOutput;
-    type ResponseBody = TokenizationResponse;
+    type ResponseBody = ResponseTokenization;
 
     fn build_request(
         &self,
@@ -76,7 +76,7 @@ impl Task for TaskTokenization {
         base: &str,
         model: &str,
     ) -> reqwest::RequestBuilder {
-        let body = TokenizationRequest {
+        let body = BodyTokenization {
             model,
             prompt: &self.prompt,
             tokens: self.tokens,
diff --git a/tests/integration.rs b/tests/integration.rs
index 7d86190..64c5fa7 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -386,8 +386,8 @@ async fn tokenization_with_luminous_base() {
     let client = Client::new(&AA_API_TOKEN).unwrap();
 
     // When
-    let task1 = TaskTokenization::create(input.to_owned(), false, true);
-    let task2 = TaskTokenization::create(input.to_owned(), true, false);
+    let task1 = TaskTokenization::new(input, false, true);
+    let task2 = TaskTokenization::new(input, true, false);
 
     let response1 = client
         .tokenize(&task1, "luminous-base", &How::default())