feat: expose chat method on client

moldhouse committed Oct 22, 2024
1 parent 959338a commit cf7d03e
Showing 4 changed files with 54 additions and 22 deletions.
12 changes: 6 additions & 6 deletions src/chat.rs
@@ -83,14 +83,14 @@ impl<'a> TaskChat<'a> {
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
-pub struct Choice<'a> {
-    pub message: Message<'a>,
+pub struct Choice {
+    pub message: Message<'static>,
     pub finish_reason: String,
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
-pub struct ChatResponse<'a> {
-    pub choices: Vec<Choice<'a>>,
+pub struct ChatResponse {
+    pub choices: Vec<Choice>,
 }
 #[derive(Serialize)]
 struct ChatBody<'a> {
@@ -125,9 +125,9 @@ impl<'a> ChatBody<'a> {
 }
 
 impl<'a> Task for TaskChat<'a> {
-    type Output = Choice<'a>;
+    type Output = Choice;
 
-    type ResponseBody = ChatResponse<'a>;
+    type ResponseBody = ChatResponse;
 
     fn build_request(
         &self,
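
Dropping the lifetime parameter means Choice and ChatResponse now own their data, so a parsed choice can outlive the HTTP response buffer it was deserialized from. A minimal sketch of what this enables, assuming the types from src/chat.rs above are in scope (the helper itself is hypothetical, not part of this commit):

// Hypothetical helper: because Choice no longer borrows from the response
// buffer, it can be moved out of ChatResponse and returned freely.
fn into_first_choice(response: ChatResponse) -> Option<Choice> {
    response.choices.into_iter().next()
}
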
27 changes: 18 additions & 9 deletions src/http.rs
@@ -165,17 +165,25 @@ impl HttpClient {
         auth_value
     }
 
-    pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String> ) -> Result<Tokenizer, Error> {
+    pub async fn tokenizer_by_model(
+        &self,
+        model: &str,
+        api_token: Option<String>,
+    ) -> Result<Tokenizer, Error> {
         let api_token = api_token
             .as_ref()
             .or(self.api_token.as_ref())
             .expect("API token needs to be set on client construction or per request");
-        let response = self.http.get(format!("{}/models/{model}/tokenizer", self.base))
-            .header(header::AUTHORIZATION, Self::header_from_token(api_token)).send().await?;
+        let response = self
+            .http
+            .get(format!("{}/models/{model}/tokenizer", self.base))
+            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
+            .send()
+            .await?;
         let response = translate_http_error(response).await?;
         let bytes = response.bytes().await?;
-        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
-            Error::InvalidTokenizer { deserialization_error: e.to_string() }
+        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
+            deserialization_error: e.to_string(),
         })?;
         Ok(tokenizer)
     }
@@ -251,10 +259,11 @@ pub enum Error {
     /// An error on the Http Protocol level.
     #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
     Http { status: u16, body: String },
-    #[error("Tokenizer could not be correctly deserialized. Caused by:\n{}", deserialization_error)]
-    InvalidTokenizer {
-        deserialization_error: String,
-    },
+    #[error(
+        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
+        deserialization_error
+    )]
+    InvalidTokenizer { deserialization_error: String },
     /// Most likely either TLS errors creating the Client, or IO errors.
     #[error(transparent)]
     Other(#[from] reqwest::Error),
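
Since tokenizer_by_model falls back to the token stored on the client, callers that authenticated at construction can pass None per request. A usage sketch, assuming the public Client forwards to HttpClient::tokenizer_by_model with the same name and signature (that re-export is an assumption; it is not shown in this diff):

use aleph_alpha_client::{Client, Error};

// Hypothetical usage: fetch the tokenizer for a model and report its
// vocabulary size. Passing None for the per-request token relies on the
// token supplied at client construction.
async fn vocab_size(model: &str) -> Result<usize, Error> {
    let client = Client::with_authentication("AA_API_TOKEN")?;
    let tokenizer = client.tokenizer_by_model(model, None).await?;
    Ok(tokenizer.get_vocab_size(true))
}
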
29 changes: 29 additions & 0 deletions src/lib.rs
@@ -34,6 +34,7 @@ mod semantic_embedding;
 mod tokenization;
 use std::time::Duration;
 
+use chat::Choice;
 use http::HttpClient;
 use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
 use tokenizers::Tokenizer;
@@ -190,6 +191,34 @@ impl Client {
             .await
     }
 
+    /// Send a chat message to a model.
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, Error, How, Role, TaskChat};
+    ///
+    /// async fn chat() -> Result<(), Error> {
+    ///     // Authenticate against API. Fetches token.
+    ///     let client = Client::with_authentication("AA_API_TOKEN")?;
+    ///
+    ///     // Name of a model that supports chat.
+    ///     let model = "pharia-1-llm-7b-control";
+    ///
+    ///     // Create a chat task with a user message.
+    ///     let task = TaskChat::new(Role::User, "Hello, how are you?");
+    ///
+    ///     // Send the message to the model.
+    ///     let response = client.chat(&task, model, &How::default()).await?;
+    ///
+    ///     // Print the model response.
+    ///     println!("{}", response.message.content);
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn chat(&self, task: &TaskChat<'_>, model: &str, how: &How) -> Result<Choice, Error> {
+        self.http_client
+            .output_of(&task.with_model(model), how)
+            .await
+    }
+
     /// Returns an explanation given a prompt and a target (typically generated
     /// by a previous completion request). The explanation describes how individual parts
     /// of the prompt influenced the target.
8 changes: 1 addition & 7 deletions tests/integration.rs
@@ -25,12 +25,7 @@ async fn chat_with_pharia_1_7b_base() {
 
     let model = "pharia-1-llm-7b-control";
     let client = Client::with_authentication(api_token()).unwrap();
-    let response = client
-        .output_of(&task.with_model(model), &How::default())
-        .await
-        .unwrap();
-
-    eprintln!("{:?}", response.message);
+    let response = client.chat(&task, model, &How::default()).await.unwrap();
 
     // Then
     assert!(!response.message.content.is_empty())
@@ -557,4 +552,3 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
     // Then
     assert_eq!(128_000, tokenizer.get_vocab_size(true));
 }
-