From cf7d03e09523d3c6c7d3a46c078468973eeae546 Mon Sep 17 00:00:00 2001
From: Moritz Althaus
Date: Tue, 22 Oct 2024 14:00:04 +0200
Subject: [PATCH] feat: expose chat method on client

---
 src/chat.rs          | 12 ++++++------
 src/http.rs          | 27 ++++++++++++++++++---------
 src/lib.rs           | 29 +++++++++++++++++++++++++++++
 tests/integration.rs |  8 +-------
 4 files changed, 54 insertions(+), 22 deletions(-)

diff --git a/src/chat.rs b/src/chat.rs
index 360c74f..d2e537e 100644
--- a/src/chat.rs
+++ b/src/chat.rs
@@ -83,14 +83,14 @@ impl<'a> TaskChat<'a> {
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
-pub struct Choice<'a> {
-    pub message: Message<'a>,
+pub struct Choice {
+    pub message: Message<'static>,
     pub finish_reason: String,
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
-pub struct ChatResponse<'a> {
-    pub choices: Vec<Choice<'a>>,
+pub struct ChatResponse {
+    pub choices: Vec<Choice>,
 }
 #[derive(Serialize)]
 struct ChatBody<'a> {
@@ -125,9 +125,9 @@ impl<'a> ChatBody<'a> {
 }
 
 impl<'a> Task for TaskChat<'a> {
-    type Output = Choice<'a>;
+    type Output = Choice;
 
-    type ResponseBody = ChatResponse<'a>;
+    type ResponseBody = ChatResponse;
 
     fn build_request(
         &self,
diff --git a/src/http.rs b/src/http.rs
index f2c27e6..fd4b62a 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -165,17 +165,25 @@ impl HttpClient {
         auth_value
     }
 
-    pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> {
+    pub async fn tokenizer_by_model(
+        &self,
+        model: &str,
+        api_token: Option<String>,
+    ) -> Result<Tokenizer, Error> {
         let api_token = api_token
             .as_ref()
             .or(self.api_token.as_ref())
             .expect("API token needs to be set on client construction or per request");
-        let response = self.http.get(format!("{}/models/{model}/tokenizer", self.base))
-            .header(header::AUTHORIZATION, Self::header_from_token(api_token)).send().await?;
+        let response = self
+            .http
+            .get(format!("{}/models/{model}/tokenizer", self.base))
+            .header(header::AUTHORIZATION, Self::header_from_token(api_token))
+            .send()
+            .await?;
         let response = translate_http_error(response).await?;
         let bytes = response.bytes().await?;
-        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| {
-            Error::InvalidTokenizer { deserialization_error: e.to_string() }
+        let tokenizer = Tokenizer::from_bytes(bytes).map_err(|e| Error::InvalidTokenizer {
+            deserialization_error: e.to_string(),
         })?;
         Ok(tokenizer)
     }
@@ -251,10 +259,11 @@ pub enum Error {
     /// An error on the Http Protocol level.
     #[error("HTTP request failed with status code {}. Body:\n{}", status, body)]
     Http { status: u16, body: String },
-    #[error("Tokenizer could not be correctly deserialized. Caused by:\n{}", deserialization_error)]
-    InvalidTokenizer {
-        deserialization_error: String,
-    },
+    #[error(
+        "Tokenizer could not be correctly deserialized. Caused by:\n{}",
+        deserialization_error
+    )]
+    InvalidTokenizer { deserialization_error: String },
     /// Most likely either TLS errors creating the Client, or IO errors.
     #[error(transparent)]
     Other(#[from] reqwest::Error),
diff --git a/src/lib.rs b/src/lib.rs
index 044251f..a21728a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -34,6 +34,7 @@ mod semantic_embedding;
 mod tokenization;
 use std::time::Duration;
 
+use chat::Choice;
 use http::HttpClient;
 use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
 use tokenizers::Tokenizer;
@@ -190,6 +191,34 @@ impl Client {
             .await
     }
 
+    /// Send a chat message to a model.
+    /// ```no_run
+    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Role};
+    ///
+    /// async fn chat() -> Result<(), Error> {
+    ///     // Authenticate against API. Fetches token.
+    ///     let client = Client::with_authentication("AA_API_TOKEN")?;
+    ///
+    ///     // Name of a model that supports chat.
+    ///     let model = "pharia-1-llm-7b-control";
+    ///
+    ///     // Create a chat task with a user message.
+    ///     let task = TaskChat::new(Role::User, "Hello, how are you?");
+    ///
+    ///     // Send the message to the model.
+    ///     let response = client.chat(&task, model, &How::default()).await?;
+    ///
+    ///     // Print the model response.
+    ///     println!("{}", response.message.content);
+    ///     Ok(())
+    /// }
+    /// ```
+    pub async fn chat(&self, task: &TaskChat<'_>, model: &str, how: &How) -> Result<Choice, Error> {
+        self.http_client
+            .output_of(&task.with_model(model), how)
+            .await
+    }
+
     /// Returns an explanation given a prompt and a target (typically generated
     /// by a previous completion request). The explanation describes how individual parts
     /// of the prompt influenced the target.
diff --git a/tests/integration.rs b/tests/integration.rs
index 33a76b1..7743c89 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -25,12 +25,7 @@ async fn chat_with_pharia_1_7b_base() {
     let model = "pharia-1-llm-7b-control";
     let client = Client::with_authentication(api_token()).unwrap();
 
-    let response = client
-        .output_of(&task.with_model(model), &How::default())
-        .await
-        .unwrap();
-
-    eprintln!("{:?}", response.message);
+    let response = client.chat(&task, model, &How::default()).await.unwrap();
 
     // Then
     assert!(!response.message.content.is_empty())
@@ -557,4 +552,3 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
     // Then
     assert_eq!(128_000, tokenizer.get_vocab_size(true));
 }
-