Skip to content

Commit

Permalink
Clean up the de-/tokenization impls
Browse files Browse the repository at this point in the history
  • Loading branch information
andreaskoepf committed Nov 30, 2023
1 parent beb9d21 commit a08bb32
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 29 deletions.
25 changes: 16 additions & 9 deletions src/detokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@ use serde::{Deserialize, Serialize};
/// Input for a [crate::Client::detokenize] request.
pub struct TaskDetokenization<'a> {
/// List of token ids which should be detokenized into text.
pub token_ids: &'a Vec<u32>,
pub token_ids: &'a [u32],
}

/// Body sent to the Aleph Alpha API on the POST `/detokenize` route
#[derive(Serialize, Debug)]
struct DetokenizationRequest<'a> {
struct BodyDetokenization<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// List of ids to detokenize.
pub token_ids: &'a Vec<u32>,
pub token_ids: &'a [u32],
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct DetokenizationResponse {
pub struct ResponseDetokenization {
pub result: String,
}

Expand All @@ -25,26 +26,32 @@ pub struct DetokenizationOutput {
pub result: String,
}

impl From<ResponseDetokenization> for DetokenizationOutput {
fn from(response: ResponseDetokenization) -> Self {
Self {
result: response.result,
}
}
}

impl<'a> Task for TaskDetokenization<'a> {
type Output = DetokenizationOutput;
type ResponseBody = DetokenizationResponse;
type ResponseBody = ResponseDetokenization;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = DetokenizationRequest {
let body = BodyDetokenization {
model,
token_ids: &self.token_ids,
};
client.post(format!("{base}/detokenize")).json(&body)
}

fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
DetokenizationOutput {
result: response.result,
}
DetokenizationOutput::from(response)
}
}
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl Client {

pub async fn tokenize(
&self,
task: &TaskTokenization,
task: &TaskTokenization<'_>,
model: &str,
how: &How,
) -> Result<TokenizationOutput, Error> {
Expand Down
34 changes: 17 additions & 17 deletions src/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use crate::Task;
use serde::{Deserialize, Serialize};

/// Input for a [crate::Client::tokenize] request.
pub struct TaskTokenization {
pub struct TaskTokenization<'a> {
/// The text prompt which should be converted into tokens
pub prompt: String,
pub prompt: &'a str,

/// Specify `true` to return text-tokens.
pub tokens: bool,
Expand All @@ -13,40 +13,40 @@ pub struct TaskTokenization {
pub token_ids: bool,
}

impl From<&str> for TaskTokenization {
fn from(s: &str) -> TaskTokenization {
impl<'a> From<&'a str> for TaskTokenization<'a> {
fn from(prompt: &'a str) -> TaskTokenization {
TaskTokenization {
prompt: s.into(),
prompt,
tokens: true,
token_ids: true,
}
}
}

impl TaskTokenization {
pub fn create(prompt: String, tokens: bool, token_ids: bool) -> TaskTokenization {
impl TaskTokenization<'_> {
pub fn new(prompt: &str, tokens: bool, token_ids: bool) -> TaskTokenization {
TaskTokenization {
prompt: prompt,
prompt,
tokens,
token_ids,
}
}
}

#[derive(Serialize, Debug)]
struct TokenizationRequest<'a> {
struct BodyTokenization<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// String to tokenize.
pub prompt: &'a String,
pub prompt: &'a str,
/// Set this value to `true` to return text-tokens.
pub tokens: bool,
/// Set this value to `true to return numeric token-ids.
/// Set this value to `true` to return numeric token-ids.
pub token_ids: bool,
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct TokenizationResponse {
pub struct ResponseTokenization {
pub tokens: Option<Vec<String>>,
pub token_ids: Option<Vec<u32>>,
}
Expand All @@ -57,26 +57,26 @@ pub struct TokenizationOutput {
pub token_ids: Option<Vec<u32>>,
}

impl From<TokenizationResponse> for TokenizationOutput {
fn from(response: TokenizationResponse) -> Self {
impl From<ResponseTokenization> for TokenizationOutput {
fn from(response: ResponseTokenization) -> Self {
Self {
tokens: response.tokens,
token_ids: response.token_ids,
}
}
}

impl Task for TaskTokenization {
impl Task for TaskTokenization<'_> {
type Output = TokenizationOutput;
type ResponseBody = TokenizationResponse;
type ResponseBody = ResponseTokenization;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = TokenizationRequest {
let body = BodyTokenization {
model,
prompt: &self.prompt,
tokens: self.tokens,
Expand Down
4 changes: 2 additions & 2 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ async fn tokenization_with_luminous_base() {
let client = Client::new(&AA_API_TOKEN).unwrap();

// When
let task1 = TaskTokenization::create(input.to_owned(), false, true);
let task2 = TaskTokenization::create(input.to_owned(), true, false);
let task1 = TaskTokenization::new(input, false, true);
let task2 = TaskTokenization::new(input, true, false);

let response1 = client
.tokenize(&task1, "luminous-base", &How::default())
Expand Down

0 comments on commit a08bb32

Please sign in to comment.