Add tokenize & detokenize to client, fix typos #8

Merged
merged 1 commit on Nov 30, 2023
6 changes: 3 additions & 3 deletions src/completion.rs
@@ -76,7 +76,7 @@ pub struct Stopping<'a> {
     /// List of strings which will stop generation if they are generated. Stop sequences are
     /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
     /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
-    /// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used
+    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
     /// as stop sequence in order not to have the model generate more questions but rather restrict
     /// text generation to the answers.
     pub stop_sequences: &'a [&'a str],
@@ -95,7 +95,7 @@ impl<'a> Stopping<'a> {
 /// Body send to the Aleph Alpha API on the POST `/completion` Route
 #[derive(Serialize, Debug)]
 struct BodyCompletion<'a> {
-    /// Name of the model tasked with completing the prompt. E.g. `luminus-base`.
+    /// Name of the model tasked with completing the prompt. E.g. `luminous-base"`.
     pub model: &'a str,
     /// Prompt to complete. The modalities supported depend on `model`.
     pub prompt: Prompt<'a>,
@@ -104,7 +104,7 @@ struct BodyCompletion<'a> {
     /// List of strings which will stop generation if they are generated. Stop sequences are
     /// helpful in structured texts. E.g.: In a question answering scenario a text may consist of
     /// lines starting with either "Question: " or "Answer: " (alternating). After producing an
-    /// answer, the model will be likely to generate "Question: ". "Question: " may therfore be used
+    /// answer, the model will be likely to generate "Question: ". "Question: " may therefore be used
     /// as stop sequence in order not to have the model generate more questions but rather restrict
     /// text generation to the answers.
     #[serde(skip_serializing_if = "<[_]>::is_empty")]
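The doc comment in the hunk above describes how stop sequences bound generation in structured Q/A text. Below is a standalone sketch of that truncation behavior; the helper is hypothetical, for illustration only, and not part of this crate:

```rust
// Illustration of the stop-sequence semantics documented above; this
// helper is hypothetical and does not call the aleph_alpha_client API.
fn truncate_at_stop_sequences(generated: &str, stop_sequences: &[&str]) -> String {
    // Cut at the earliest occurrence of any stop sequence.
    let cut = stop_sequences
        .iter()
        .filter_map(|stop| generated.find(stop))
        .min()
        .unwrap_or(generated.len());
    generated[..cut].to_string()
}

fn main() {
    // After producing an answer, the model tends to continue with the
    // next "Question: ", which the stop sequence cuts off.
    let raw = "Answer: 42\nQuestion: What is six times nine?";
    assert_eq!(truncate_at_stop_sequences(raw, &["Question: "]), "Answer: 42\n");
}
```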
57 changes: 57 additions & 0 deletions src/detokenization.rs
@@ -0,0 +1,57 @@
+use crate::Task;
+use serde::{Deserialize, Serialize};
+
+/// Input for a [crate::Client::detokenize] request.
+pub struct TaskDetokenization<'a> {
+    /// List of token ids which should be detokenized into text.
+    pub token_ids: &'a [u32],
+}
+
+/// Body send to the Aleph Alpha API on the POST `/detokenize` route
+#[derive(Serialize, Debug)]
+struct BodyDetokenization<'a> {
+    /// Name of the model tasked with completing the prompt. E.g. `luminous-base"`.
+    pub model: &'a str,
+    /// List of ids to detokenize.
+    pub token_ids: &'a [u32],
+}
+
+#[derive(Deserialize, Debug, PartialEq, Eq)]
+pub struct ResponseDetokenization {
+    pub result: String,
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct DetokenizationOutput {
+    pub result: String,
+}
+
+impl From<ResponseDetokenization> for DetokenizationOutput {
+    fn from(response: ResponseDetokenization) -> Self {
+        Self {
+            result: response.result,
+        }
+    }
+}
+
+impl<'a> Task for TaskDetokenization<'a> {
+    type Output = DetokenizationOutput;
+    type ResponseBody = ResponseDetokenization;
+
+    fn build_request(
+        &self,
+        client: &reqwest::Client,
+        base: &str,
+        model: &str,
+    ) -> reqwest::RequestBuilder {
+        let body = BodyDetokenization {
+            model,
+            token_ids: &self.token_ids,
+        };
+        client.post(format!("{base}/detokenize")).json(&body)
+    }
+
+    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
+        DetokenizationOutput::from(response)
+    }
+}
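For reference, a sketch of the wire format `BodyDetokenization` serializes to. The struct is private, so a check like this would have to live inside the module itself; treating `serde_json` as an available dev-dependency is an assumption:

```rust
use serde_json::json;

// Illustrative only: sketches the JSON body posted to /detokenize.
#[test]
fn body_detokenization_wire_format() {
    let body = BodyDetokenization {
        model: "luminous-base",
        token_ids: &[556, 48741, 247, 2983],
    };
    let expected = json!({
        "model": "luminous-base",
        "token_ids": [556, 48741, 247, 2983],
    });
    assert_eq!(serde_json::to_value(&body).unwrap(), expected);
}
```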
2 changes: 1 addition & 1 deletion src/explanation.rs
@@ -9,7 +9,7 @@ pub struct TaskExplanation<'a> {
     /// The target string that should be explained. The influence of individual parts
     /// of the prompt for generating this target string will be indicated in the response.
     pub target: &'a str,
-    /// Granularity paramaters for the explanation
+    /// Granularity parameters for the explanation
     pub granularity: Granularity,
 }
 
16 changes: 8 additions & 8 deletions src/http.rs
@@ -10,8 +10,8 @@ use crate::How;
 /// for the Aleph Alpha API to specify its result. Notably it includes the model(s) the job is
 /// executed on. This allows this trait to hold in the presence of services, which use more than one
 /// model and task type to achieve their result. On the other hand a bare [`crate::TaskCompletion`]
-/// can not implement this trait directly, since its result would depend on what model is choosen to
-/// execute it. You can remidy this by turning completion task into a job, calling
+/// can not implement this trait directly, since its result would depend on what model is chosen to
+/// execute it. You can remedy this by turning completion task into a job, calling
 /// [`Task::with_model`].
 pub trait Job {
     /// Output returned by [`crate::Client::output_of`]
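A sketch of the task-to-job step described in the comment above, using the `TaskTokenization` type added in this PR. It assumes the `Task` trait (which supplies `with_model`) is in scope and leaves the concrete job type to inference:

```rust
// Sketch only: assumes `TaskTokenization` and the `Task` trait are imported.
fn pin_tokenization_to_model() {
    // A bare task cannot implement `Job`: its output depends on which
    // model executes it.
    let task = TaskTokenization {
        prompt: "An apple a day",
        tokens: true,
        token_ids: true,
    };
    // `with_model` bundles the task with a model name, yielding a value
    // that can implement `Job` and be passed to `Client::output_of`.
    let _job = task.with_model("luminous-base");
}
```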
@@ -130,7 +130,7 @@ impl HttpClient {
         let query = if how.be_nice {
             [("nice", "true")].as_slice()
         } else {
-            // nice=false is default, so we just ommit it.
+            // nice=false is default, so we just omit it.
             [].as_slice()
         };
         let response = task
@@ -156,7 +156,7 @@ impl HttpClient {
 async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
     let status = response.status();
     if !status.is_success() {
-        // Store body in a variable, so we can use it, even if it is not an Error emmitted by
+        // Store body in a variable, so we can use it, even if it is not an Error emitted by
         // the API, but an intermediate Proxy like NGinx, so we can still forward the error
         // message.
         let body = response.text().await?;
@@ -174,14 +174,14 @@ async fn translate_http_error(response: reqwest::Response) -> Result<reqwest::Response, Error> {
     }
 }
 
-/// We are only interessted in the status codes of the API.
+/// We are only interested in the status codes of the API.
 #[derive(Deserialize)]
 struct ApiError<'a> {
     /// Unique string in capital letters emitted by the API to signal different kinds of errors in a
-    /// finer granualrity then the HTTP status codes alone would allow for.
+    /// finer granularity then the HTTP status codes alone would allow for.
     ///
     /// E.g. Differentiating between request rate limiting and parallel tasks limiting which both
-    /// are 429 (the former is emmited by NGinx though).
+    /// are 429 (the former is emitted by NGinx though).
     _code: Cow<'a, str>,
 }
 
@@ -204,7 +204,7 @@ pub enum Error {
     Busy,
     #[error("No response received within given timeout: {0:?}")]
     ClientTimeout(Duration),
-    /// An error on the Http Protocl level.
+    /// An error on the Http Protocol level.
     #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
     Http { status: u16, body: String },
     /// Most likely either TLS errors creating the Client, or IO errors.
76 changes: 75 additions & 1 deletion src/lib.rs
@@ -24,11 +24,13 @@
 //! ```
 
 mod completion;
+mod detokenization;
 mod explanation;
 mod http;
 mod image_preprocessing;
 mod prompt;
 mod semantic_embedding;
+mod tokenization;
 
 use std::time::Duration;
 
@@ -37,6 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};

 pub use self::{
     completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
+    detokenization::{DetokenizationOutput, TaskDetokenization},
     explanation::{
         Explanation, ExplanationOutput, Granularity, ImageScore, ItemExplanation,
         PromptGranularity, TaskExplanation, TextScore,
@@ -46,6 +49,7 @@ pub use self::{
     semantic_embedding::{
         SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
     },
+    tokenization::{TaskTokenization, TokenizationOutput},
 };
 
 /// Execute Jobs against the Aleph Alpha API
@@ -215,6 +219,76 @@ impl Client {
             .output_of(&task.with_model(model), how)
             .await
     }

+    /// Tokenize a prompt for a specific model.
+    ///
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, Error, How, TaskTokenization};
+    ///
+    /// async fn tokenize() -> Result<(), Error> {
+    ///     let client = Client::new(AA_API_TOKEN)?;
+    ///
+    ///     // Name of the model for which we want to tokenize text.
+    ///     let model = "luminous-base";
+    ///
+    ///     // Text prompt to be tokenized.
+    ///     let prompt = "An apple a day";
+    ///
+    ///     let task = TaskTokenization {
+    ///         prompt,
+    ///         tokens: true,      // return text-tokens
+    ///         token_ids: true,   // return numeric token-ids
+    ///     };
+    ///     let response = client.tokenize(&task, model, &How::default()).await?;
+    ///
+    ///     dbg!(&response);
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn tokenize(
[Inline review on `tokenize`]

Contributor: A minimal example with a comment why you want to call this would be nice. Not required though for me to merge the PR

Contributor (author): ok, I added docstring examples for tokenize()/detokenize().
+        &self,
+        task: &TaskTokenization<'_>,
+        model: &str,
+        how: &How,
+    ) -> Result<TokenizationOutput, Error> {
+        self.http_client
+            .output_of(&task.with_model(model), how)
+            .await
+    }
+
+    /// Detokenize a list of token ids into a string.
+    ///
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, Error, How, TaskDetokenization};
+    ///
+    /// async fn detokenize() -> Result<(), Error> {
+    ///     let client = Client::new(AA_API_TOKEN)?;
+    ///
+    ///     // Specify the name of the model whose tokenizer was used to generate the input token ids.
+    ///     let model = "luminous-base";
+    ///
+    ///     // Token ids to convert into text.
+    ///     let token_ids: Vec<u32> = vec![556, 48741, 247, 2983];
+    ///
+    ///     let task = TaskDetokenization {
+    ///         token_ids: &token_ids,
+    ///     };
+    ///     let response = client.detokenize(&task, model, &How::default()).await?;
+    ///
+    ///     dbg!(&response);
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn detokenize(
+        &self,
+        task: &TaskDetokenization<'_>,
+        model: &str,
+        how: &How,
+    ) -> Result<DetokenizationOutput, Error> {
+        self.http_client
+            .output_of(&task.with_model(model), how)
+            .await
+    }
 }
 
 /// Controls of how to execute a task
Expand Down Expand Up @@ -254,7 +328,7 @@ impl Default for How {
 ///     Client, Prompt, TaskSemanticEmbedding, cosine_similarity, SemanticRepresentation, How
 /// };
 ///
-/// async fn semanitc_search_with_luminous_base(client: &Client) {
+/// async fn semantic_search_with_luminous_base(client: &Client) {
 ///     // Given
 ///     let robot_fact = Prompt::from_text(
 ///         "A robot is a machine—especially one programmable by a computer—capable of carrying out a \
91 changes: 91 additions & 0 deletions src/tokenization.rs
@@ -0,0 +1,91 @@
+use crate::Task;
+use serde::{Deserialize, Serialize};
+
+/// Input for a [crate::Client::tokenize] request.
+pub struct TaskTokenization<'a> {
+    /// The text prompt which should be converted into tokens
+    pub prompt: &'a str,
+
+    /// Specify `true` to return text-tokens.
+    pub tokens: bool,
+
+    /// Specify `true` to return numeric token-ids.
+    pub token_ids: bool,
+}
+
+impl<'a> From<&'a str> for TaskTokenization<'a> {
+    fn from(prompt: &'a str) -> TaskTokenization {
+        TaskTokenization {
+            prompt,
+            tokens: true,
+            token_ids: true,
+        }
+    }
+}
+
+impl TaskTokenization<'_> {
+    pub fn new(prompt: &str, tokens: bool, token_ids: bool) -> TaskTokenization {
+        TaskTokenization {
+            prompt,
+            tokens,
+            token_ids,
+        }
+    }
+}
+
+#[derive(Serialize, Debug)]
+struct BodyTokenization<'a> {
+    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
+    pub model: &'a str,
+    /// String to tokenize.
+    pub prompt: &'a str,
+    /// Set this value to `true` to return text-tokens.
+    pub tokens: bool,
+    /// Set this value to `true` to return numeric token-ids.
+    pub token_ids: bool,
+}
+
+#[derive(Deserialize, Debug, PartialEq, Eq)]
+pub struct ResponseTokenization {
+    pub tokens: Option<Vec<String>>,
+    pub token_ids: Option<Vec<u32>>,
+}
+
+#[derive(Debug, PartialEq)]
+pub struct TokenizationOutput {
+    pub tokens: Option<Vec<String>>,
+    pub token_ids: Option<Vec<u32>>,
+}
+
+impl From<ResponseTokenization> for TokenizationOutput {
+    fn from(response: ResponseTokenization) -> Self {
+        Self {
+            tokens: response.tokens,
+            token_ids: response.token_ids,
+        }
+    }
+}
+
+impl Task for TaskTokenization<'_> {
+    type Output = TokenizationOutput;
+    type ResponseBody = ResponseTokenization;
+
+    fn build_request(
+        &self,
+        client: &reqwest::Client,
+        base: &str,
+        model: &str,
+    ) -> reqwest::RequestBuilder {
+        let body = BodyTokenization {
+            model,
+            prompt: &self.prompt,
+            tokens: self.tokens,
+            token_ids: self.token_ids,
+        };
+        client.post(format!("{base}/tokenize")).json(&body)
+    }
+
+    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
+        TokenizationOutput::from(response)
+    }
+}
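To round off the new module, a sketch of constructing the task and of the response shape `ResponseTokenization` accepts. The token strings and ids below are invented for illustration, and `serde_json` as a dev-dependency is an assumption:

```rust
use serde_json::json;

#[test]
fn tokenization_sketch() {
    // From<&str> is shorthand for requesting both token strings and ids.
    let explicit = TaskTokenization::new("An apple a day", true, true);
    let from_str = TaskTokenization::from("An apple a day");
    assert_eq!(explicit.prompt, from_str.prompt);

    // Invented payload in the documented shape; fields the API omits
    // (e.g. when `tokens` was false) deserialize as `None`.
    let raw = json!({
        "tokens": ["An", " apple", " a", " day"],
        "token_ids": [556, 48741, 247, 2983],
    });
    let response: ResponseTokenization = serde_json::from_value(raw).unwrap();
    assert_eq!(response.token_ids, Some(vec![556, 48741, 247, 2983]));
}
```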