From 963e2a67445ee60d2914989e728988cb3c1972e5 Mon Sep 17 00:00:00 2001
From: Moritz Althaus
Date: Thu, 24 Oct 2024 23:20:13 +0200
Subject: [PATCH] feat: add chat completion task

---
 src/chat.rs          |  18 ++++++-
 src/http.rs          |  15 +++---
 src/lib.rs           |  24 ++++++++--
 src/stream.rs        | 110 +++++++++++++++++++++++++++++++++++++++----
 tests/integration.rs |  44 ++++++++++++++---
 5 files changed, 183 insertions(+), 28 deletions(-)

diff --git a/src/chat.rs b/src/chat.rs
index 030383d..694721e 100644
--- a/src/chat.rs
+++ b/src/chat.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;
 
 use serde::{Deserialize, Serialize};
 
-use crate::Task;
+use crate::{stream::TaskStreamChat, Task};
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
 pub struct Message<'a> {
@@ -95,6 +95,11 @@ impl<'a> TaskChat<'a> {
         self.top_p = Some(top_p);
         self
     }
+
+    /// Creates a wrapper `TaskStreamChat` for this `TaskChat`.
+    pub fn with_streaming(self) -> TaskStreamChat<'a> {
+        TaskStreamChat { task: self }
+    }
 }
 
 #[derive(Deserialize, Debug, PartialEq, Eq)]
@@ -109,7 +114,7 @@ pub struct ResponseChat {
 }
 
 #[derive(Serialize)]
-struct ChatBody<'a> {
+pub struct ChatBody<'a> {
     /// Name of the model tasked with completing the prompt. E.g. `luminous-base"`.
     pub model: &'a str,
     /// The list of messages comprising the conversation so far.
@@ -126,6 +131,9 @@ pub struct ChatBody<'a> {
     /// When no value is provided, the default value of 1 will be used.
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f64>,
+    /// Whether to stream the response or not.
+    #[serde(skip_serializing_if = "std::ops::Not::not")]
+    pub stream: bool,
 }
 
 impl<'a> ChatBody<'a> {
@@ -136,8 +144,14 @@ impl<'a> ChatBody<'a> {
             maximum_tokens: task.maximum_tokens,
             temperature: task.temperature,
             top_p: task.top_p,
+            stream: false,
         }
     }
+
+    pub fn with_streaming(mut self) -> Self {
+        self.stream = true;
+        self
+    }
 }
 
 impl<'a> Task for TaskChat<'a> {
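
A note on the `stream` field added above: `skip_serializing_if = "std::ops::Not::not"` serializes `stream` only when it is `true` (for a `&bool`, `Not::not` returns `true` exactly when the field is `false`), so the request body of existing non-streaming calls stays byte-for-byte unchanged. A minimal, self-contained sketch of the pattern — the `Body` struct and `main` below are hypothetical illustrations, not part of this patch, and assume `serde` (with the `derive` feature) and `serde_json` as dependencies:

    use serde::Serialize;

    // Hypothetical stand-in for `ChatBody`; only the `stream` flag matters here.
    #[derive(Serialize)]
    struct Body<'a> {
        model: &'a str,
        // serde skips the field when `Not::not(&stream)` is true, i.e. when
        // `stream` is false, leaving non-streaming request bodies unchanged.
        #[serde(skip_serializing_if = "std::ops::Not::not")]
        stream: bool,
    }

    fn main() {
        let plain = Body { model: "pharia-1-llm-7b-control", stream: false };
        let streaming = Body { model: "pharia-1-llm-7b-control", stream: true };
        // {"model":"pharia-1-llm-7b-control"}
        println!("{}", serde_json::to_string(&plain).unwrap());
        // {"model":"pharia-1-llm-7b-control","stream":true}
        println!("{}", serde_json::to_string(&streaming).unwrap());
    }
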
diff --git a/src/http.rs b/src/http.rs
index ab5d6c4..73f77cb 100644
--- a/src/http.rs
+++ b/src/http.rs
@@ -18,7 +18,7 @@ use crate::{stream::parse_stream_event, How};
 /// [`Task::with_model`].
 pub trait Job {
     /// Output returned by [`crate::Client::output_of`]
-    type Output;
+    type Output: Send;
 
     /// Expected answer of the Aleph Alpha API
     type ResponseBody: for<'de> Deserialize<'de> + Send;
@@ -35,7 +35,7 @@ pub trait Job {
 /// can be executed.
 pub trait Task {
     /// Output returned by [`crate::Client::output_of`]
-    type Output;
+    type Output: Send;
 
     /// Expected answer of the Aleph Alpha API
     type ResponseBody: for<'de> Deserialize<'de> + Send;
@@ -77,7 +77,7 @@ where
         self.task.build_request(client, base, self.model)
     }
 
-    fn body_to_output(response: Self::ResponseBody) -> Self::Output {
+    fn body_to_output(response: T::ResponseBody) -> T::Output {
         T::body_to_output(response)
     }
 }
@@ -169,21 +169,22 @@ impl HttpClient {
         &self,
         task: &T,
         how: &How,
-    ) -> Result<mpsc::Receiver<Result<T::ResponseBody, Error>>, Error>
+    ) -> Result<mpsc::Receiver<Result<T::Output, Error>>, Error>
     where
-        T::ResponseBody: Send + 'static,
+        T::Output: 'static,
     {
         let response = self.request(task, how).await?;
         let mut stream = response.bytes_stream();
 
-        let (tx, rx) = mpsc::channel::<Result<T::ResponseBody, Error>>(100);
+        let (tx, rx) = mpsc::channel::<Result<T::Output, Error>>(100);
         tokio::spawn(async move {
             while let Some(item) = stream.next().await {
                 match item {
                     Ok(bytes) => {
                         let events = parse_stream_event::<T::ResponseBody>(bytes.as_ref());
                         for event in events {
-                            tx.send(event).await.unwrap();
+                            let output = event.map(|b| T::body_to_output(b));
+                            tx.send(output).await.unwrap();
                         }
                     }
                     Err(e) => {
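
The hunk above fixes the streaming architecture: the HTTP response is consumed on a background Tokio task, each parsed SSE body is converted with `T::body_to_output` before sending (hence the new `Output: Send` bound), and a bounded `mpsc` channel of 100 items provides backpressure. A minimal sketch of that forwarding pattern, with a hypothetical `u32` payload standing in for `T::Output` — not the crate's code:

    use tokio::sync::mpsc;

    #[tokio::main]
    async fn main() {
        // Bounded channel: `send` waits when 100 items are buffered, so a
        // slow consumer applies backpressure to the producer task.
        let (tx, mut rx) = mpsc::channel::<Result<u32, String>>(100);

        tokio::spawn(async move {
            // Stand-ins for the byte frames arriving from the HTTP response.
            for raw in ["1", "2", "oops"] {
                // Stand-in for `parse_stream_event` plus `T::body_to_output`:
                // convert each frame to the output type before sending.
                let output = raw.parse::<u32>().map_err(|e| e.to_string());
                if tx.send(output).await.is_err() {
                    break; // receiver was dropped
                }
            }
        });

        while let Some(item) = rx.recv().await {
            println!("{item:?}"); // Ok(1), Ok(2), Err("invalid digit ...")
        }
    }
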
diff --git a/src/lib.rs b/src/lib.rs
index 9ccff2d..d829895 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -37,6 +37,7 @@ use std::time::Duration;
 
 use http::HttpClient;
 use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
+use stream::ChatStreamChunk;
 use tokenizers::Tokenizer;
 use tokio::sync::mpsc;
 
@@ -53,7 +54,10 @@ pub use self::{
     semantic_embedding::{
         SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
     },
-    stream::{CompletionSummary, Event, StreamChunk, StreamSummary, TaskStreamCompletion},
+    stream::{
+        ChatEvent, CompletionEvent, CompletionSummary, StreamChunk, StreamSummary, TaskStreamChat,
+        TaskStreamCompletion,
+    },
     tokenization::{TaskTokenization, TokenizationOutput},
 };
 
@@ -197,7 +201,7 @@ impl Client {
     /// Stream the response as a series of events.
     ///
     /// ```no_run
-    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error, Event};
+    /// use aleph_alpha_client::{Client, How, TaskCompletion, Error, CompletionEvent};
     /// async fn print_stream_completion() -> Result<(), Error> {
     ///     // Authenticate against API. Fetches token.
     ///     let client = Client::with_authentication("AA_API_TOKEN")?;
@@ -213,7 +217,7 @@ impl Client {
     ///     // Retrieve stream from API
     ///     let mut response = client.stream_completion(&task, model, &How::default()).await?;
     ///     while let Some(Ok(event)) = response.recv().await {
-    ///         if let Event::StreamChunk(chunk) = event {
+    ///         if let CompletionEvent::StreamChunk(chunk) = event {
     ///             println!("{}", chunk.completion);
     ///         }
     ///     }
@@ -225,7 +229,7 @@ impl Client {
         task: &TaskStreamCompletion<'_>,
         model: &str,
         how: &How,
-    ) -> Result<mpsc::Receiver<Result<Event, Error>>, Error> {
+    ) -> Result<mpsc::Receiver<Result<CompletionEvent, Error>>, Error> {
         self.http_client
             .stream_output_of(&task.with_model(model), how)
             .await
@@ -265,6 +269,18 @@ impl Client {
             .await
     }
 
+    /// Send a chat message to a model. Stream the response as a series of events.
+    pub async fn stream_chat(
+        &self,
+        task: &TaskStreamChat<'_>,
+        model: &str,
+        how: &How,
+    ) -> Result<mpsc::Receiver<Result<ChatStreamChunk, Error>>, Error> {
+        self.http_client
+            .stream_output_of(&task.with_model(model), how)
+            .await
+    }
+
     /// Returns an explanation given a prompt and a target (typically generated
     /// by a previous completion request). The explanation describes how individual parts
     /// of the prompt influenced the target.
diff --git a/src/stream.rs b/src/stream.rs
index dcd6b07..839b174 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -1,4 +1,6 @@
-use crate::{completion::BodyCompletion, http::Task, Error, TaskCompletion};
+use crate::{
+    chat::ChatBody, completion::BodyCompletion, http::Task, Error, TaskChat, TaskCompletion,
+};
 use serde::Deserialize;
 
 /// Describes a chunk of a completion stream
@@ -40,20 +42,21 @@ pub struct CompletionSummary {
 #[derive(Deserialize, Debug)]
 #[serde(tag = "type")]
 #[serde(rename_all = "snake_case")]
-pub enum Event {
+pub enum CompletionEvent {
     StreamChunk(StreamChunk),
     StreamSummary(StreamSummary),
     CompletionSummary(CompletionSummary),
 }
 
+/// Wrap a completion task and support streaming.
 pub struct TaskStreamCompletion<'a> {
     pub task: TaskCompletion<'a>,
 }
 
 impl Task for TaskStreamCompletion<'_> {
-    type Output = Event;
+    type Output = CompletionEvent;
 
-    type ResponseBody = Event;
+    type ResponseBody = CompletionEvent;
 
     fn build_request(
         &self,
@@ -70,6 +73,65 @@ impl Task for TaskStreamCompletion<'_> {
     }
 }
 
+#[derive(Deserialize, Debug)]
+pub struct Message {
+    /// The role of the current chat completion. Will be assistant for the first chunk of every
+    /// completion stream and missing for the remaining chunks.
+    pub role: Option<String>,
+    /// The content of the current chat completion. Will be empty for the first chunk of every
+    /// completion stream and non-empty for the remaining chunks.
+    pub content: String,
+}
+
+/// One chunk of a chat completion stream.
+#[derive(Deserialize, Debug)]
+pub struct ChatStreamChunk {
+    /// The reason the model stopped generating tokens.
+    /// The value is only set in the last chunk of a completion and null otherwise.
+    pub finish_reason: Option<String>,
+    /// Chat completion chunk generated by the model when streaming is enabled.
+    pub delta: Message,
+}
+
+/// Event received from a chat completion stream. As the crate does not support multiple
+/// chat completions, there will always be exactly one choice item.
+#[derive(Deserialize, Debug)]
+pub struct ChatEvent {
+    pub choices: Vec<ChatStreamChunk>,
+}
+
+/// Wrap a chat task and support streaming.
+pub struct TaskStreamChat<'a> {
+    pub task: TaskChat<'a>,
+}
+
+impl<'a> Task for TaskStreamChat<'a> {
+    type Output = ChatStreamChunk;
+
+    type ResponseBody = ChatEvent;
+
+    fn build_request(
+        &self,
+        client: &reqwest::Client,
+        base: &str,
+        model: &str,
+    ) -> reqwest::RequestBuilder {
+        let body = ChatBody::new(model, &self.task).with_streaming();
+        client.post(format!("{base}/chat/completions")).json(&body)
+    }
+
+    fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
+        // We always expect there to be exactly one choice, as the `n` parameter is not
+        // supported by this crate.
+        response
+            .choices
+            .pop()
+            .expect("There must always be at least one choice")
+    }
+}
+
+/// Take a byte slice (of an SSE) and parse it into a provided response body.
+/// Each SSE event is expected to contain one or multiple JSON bodies prefixed by `data: `.
 pub fn parse_stream_event<ResponseBody>(bytes: &[u8]) -> Vec<Result<ResponseBody, Error>>
 where
     ResponseBody: for<'de> Deserialize<'de>,
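
The doc comment above describes the `data: ` framing without showing the parser itself (its body is unchanged by this patch). For orientation, a standalone sketch of that framing — `Chunk` and `parse_frames` are hypothetical illustrations, not the crate's `parse_stream_event`, and assume `serde` and `serde_json` as dependencies:

    use serde::Deserialize;

    // Hypothetical payload type; the real code is generic over `ResponseBody`.
    #[derive(Deserialize, Debug)]
    struct Chunk {
        completion: String,
    }

    fn parse_frames(bytes: &[u8]) -> Vec<Result<Chunk, serde_json::Error>> {
        String::from_utf8_lossy(bytes)
            .split("\n\n") // SSE events are separated by a blank line
            .filter_map(|event| event.strip_prefix("data: ")) // keep JSON payloads only
            .map(serde_json::from_str)
            .collect()
    }

    fn main() {
        let bytes = b"data: {\"completion\":\"Hello\"}\n\ndata: {\"completion\":\" world\"}\n\n";
        for chunk in parse_frames(bytes) {
            println!("{chunk:?}");
        }
    }
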
@@ -96,12 +158,12 @@ mod tests {
         let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";
 
         // When they are parsed
-        let events = parse_stream_event::<Event>(bytes);
+        let events = parse_stream_event::<CompletionEvent>(bytes);
         let event = events.first().unwrap().as_ref().unwrap();
 
         // Then the event is a stream chunk
         match event {
-            Event::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
+            CompletionEvent::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
             _ => panic!("Expected a stream chunk"),
         }
     }
@@ -112,18 +174,48 @@ mod tests {
         let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";
 
         // When they are parsed
-        let events = parse_stream_event::<Event>(bytes);
+        let events = parse_stream_event::<CompletionEvent>(bytes);
 
         // Then the first event is a stream summary and the last event is a completion summary
         let first = events.first().unwrap().as_ref().unwrap();
         match first {
-            Event::StreamSummary(summary) => assert_eq!(summary.finish_reason, "maximum_tokens"),
+            CompletionEvent::StreamSummary(summary) => {
+                assert_eq!(summary.finish_reason, "maximum_tokens")
+            }
             _ => panic!("Expected a completion summary"),
         }
         let second = events.last().unwrap().as_ref().unwrap();
         match second {
-            Event::CompletionSummary(summary) => assert_eq!(summary.num_tokens_generated, 7),
+            CompletionEvent::CompletionSummary(summary) => {
+                assert_eq!(summary.num_tokens_generated, 7)
+            }
             _ => panic!("Expected a completion summary"),
         }
     }
+
+    #[test]
+    fn chat_stream_chunk_event_is_parsed() {
+        // Given some bytes
+        let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";
+
+        // When they are parsed
+        let events = parse_stream_event::<ChatEvent>(bytes);
+        let event = events.first().unwrap().as_ref().unwrap();
+
+        // Then the event is a chat stream chunk
+        assert_eq!(event.choices[0].delta.role.as_ref().unwrap(), "assistant");
+    }
+
+    #[test]
+    fn chat_stream_chunk_without_role_is_parsed() {
+        // Given some bytes without a role
+        let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";
+
+        // When they are parsed
+        let events = parse_stream_event::<ChatEvent>(bytes);
+        let event = events.first().unwrap().as_ref().unwrap();
+
+        // Then the event is a chat stream chunk
+        assert_eq!(event.choices[0].delta.content, "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.");
+    }
 }
diff --git a/tests/integration.rs b/tests/integration.rs
index f1a1d96..6f676d4 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -1,8 +1,8 @@
 use std::{fs::File, io::BufReader, sync::OnceLock};
 
 use aleph_alpha_client::{
-    cosine_similarity, Client, Event, Granularity, How, ImageScore, ItemExplanation, Message,
-    Modality, Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
+    cosine_similarity, Client, CompletionEvent, Granularity, How, ImageScore, ItemExplanation,
+    Message, Modality, Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
     TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
     TaskSemanticEmbedding, TaskTokenization, TextScore,
 };
@@ -555,7 +555,7 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
 }
 
 #[tokio::test]
-async fn test_streaming_completion() {
+async fn stream_completion() {
     // Given a streaming completion task
     let client = Client::with_authentication(api_token()).unwrap();
     let task = TaskCompletion::from_text("")
@@ -575,10 +575,42 @@ async fn test_streaming_completion() {
 
     // Then there are at least one chunk, one summary and one completion summary
     assert!(events.len() >= 3);
-    assert!(matches!(events[events.len() - 3], Event::StreamChunk(_)));
-    assert!(matches!(events[events.len() - 2], Event::StreamSummary(_)));
+    assert!(matches!(
+        events[events.len() - 3],
+        CompletionEvent::StreamChunk(_)
+    ));
+    assert!(matches!(
+        events[events.len() - 2],
+        CompletionEvent::StreamSummary(_)
+    ));
     assert!(matches!(
         events[events.len() - 1],
-        Event::CompletionSummary(_)
+        CompletionEvent::CompletionSummary(_)
     ));
 }
+
+#[tokio::test]
+async fn stream_chat_with_pharia_1_llm_7b() {
+    // Given a streaming chat task
+    let client = Client::with_authentication(api_token()).unwrap();
+    let message = Message::user("Hello,");
+    let task = TaskChat::with_messages(vec![message])
+        .with_maximum_tokens(7)
+        .with_streaming();
+
+    // When the events are streamed and collected
+    let mut rx = client
+        .stream_chat(&task, "pharia-1-llm-7b-control", &How::default())
+        .await
+        .unwrap();
+
+    let mut events = Vec::new();
+    while let Some(Ok(event)) = rx.recv().await {
+        events.push(event);
+    }
+
+    // Then there are at least two chunks, with the second one having no role
+    assert!(events.len() >= 2);
+    assert_eq!(events[0].delta.role.as_ref().unwrap(), "assistant");
+    assert_eq!(events[1].delta.role, None);
+}
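
Taken together, the new API surface can be exercised much like the integration test above. A usage sketch (not part of the patch) that accumulates the streamed deltas into one reply; the token, model name, and prompt are examples:

    use aleph_alpha_client::{Client, Error, How, Message, TaskChat};

    // Stream a chat reply and concatenate the delta contents.
    async fn collect_reply() -> Result<String, Error> {
        let client = Client::with_authentication("AA_API_TOKEN")?;
        let task = TaskChat::with_messages(vec![Message::user("Hello,")])
            .with_maximum_tokens(64)
            .with_streaming();

        let mut rx = client
            .stream_chat(&task, "pharia-1-llm-7b-control", &How::default())
            .await?;

        let mut reply = String::new();
        while let Some(chunk) = rx.recv().await {
            let chunk = chunk?; // each item is a Result<ChatStreamChunk, Error>
            reply.push_str(&chunk.delta.content);
            if chunk.finish_reason.is_some() {
                break; // the final chunk carries the finish reason
            }
        }
        Ok(reply)
    }
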