
Commit

feat: add chat completion task
moldhouse committed Oct 24, 2024
1 parent a2ad4fb commit 963e2a6
Showing 5 changed files with 183 additions and 28 deletions.
18 changes: 16 additions & 2 deletions src/chat.rs
@@ -2,7 +2,7 @@ use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::Task;
use crate::{stream::TaskStreamChat, Task};

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
@@ -95,6 +95,11 @@ impl<'a> TaskChat<'a> {
self.top_p = Some(top_p);
self
}

/// Creates a `TaskStreamChat` wrapper for this `TaskChat`.
pub fn with_streaming(self) -> TaskStreamChat<'a> {
TaskStreamChat { task: self }
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
@@ -109,7 +114,7 @@ pub struct ResponseChat
}

#[derive(Serialize)]
struct ChatBody<'a> {
pub struct ChatBody<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// The list of messages comprising the conversation so far.
@@ -126,6 +131,9 @@ struct ChatBody<'a> {
/// When no value is provided, the default value of 1 will be used.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Whether to stream the response or not.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}

impl<'a> ChatBody<'a> {
@@ -136,8 +144,14 @@ impl<'a> ChatBody<'a> {
maximum_tokens: task.maximum_tokens,
temperature: task.temperature,
top_p: task.top_p,
stream: false,
}
}

pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}

impl<'a> Task for TaskChat<'a> {
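
Aside: the new `stream` flag on `ChatBody` uses serde's `skip_serializing_if` with `std::ops::Not::not`, so the field is serialized only when it is `true` and non-streaming request bodies stay byte-for-byte unchanged. A minimal standalone sketch of that behavior (the `Body` struct here is a simplified stand-in, not the crate's `ChatBody`):

use serde::Serialize;

#[derive(Serialize)]
struct Body<'a> {
    model: &'a str,
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    stream: bool,
}

fn main() {
    // `stream: false` is skipped entirely, so non-streaming requests are unaffected.
    let body = Body { model: "luminous-base", stream: false };
    assert_eq!(serde_json::to_string(&body).unwrap(), r#"{"model":"luminous-base"}"#);

    // `stream: true` is emitted, as produced by `ChatBody::with_streaming`.
    let body = Body { model: "luminous-base", stream: true };
    assert_eq!(
        serde_json::to_string(&body).unwrap(),
        r#"{"model":"luminous-base","stream":true}"#
    );
}
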
15 changes: 8 additions & 7 deletions src/http.rs
@@ -18,7 +18,7 @@ use crate::{stream::parse_stream_event, How};
/// [`Task::with_model`].
pub trait Job {
/// Output returned by [`crate::Client::output_of`]
type Output;
type Output: Send;

/// Expected answer of the Aleph Alpha API
type ResponseBody: for<'de> Deserialize<'de> + Send;
@@ -35,7 +35,7 @@
/// can be executed.
pub trait Task {
/// Output returned by [`crate::Client::output_of`]
type Output;
type Output: Send;

/// Expected answer of the Aleph Alpha API
type ResponseBody: for<'de> Deserialize<'de> + Send;
@@ -77,7 +77,7 @@
self.task.build_request(client, base, self.model)
}

fn body_to_output(response: Self::ResponseBody) -> Self::Output {
fn body_to_output(response: T::ResponseBody) -> T::Output {
T::body_to_output(response)
}
}
@@ -169,21 +169,22 @@ impl HttpClient {
&self,
task: &T,
how: &How,
) -> Result<mpsc::Receiver<Result<T::ResponseBody, Error>>, Error>
) -> Result<mpsc::Receiver<Result<T::Output, Error>>, Error>
where
T::ResponseBody: Send + 'static,
T::Output: 'static,
{
let response = self.request(task, how).await?;
let mut stream = response.bytes_stream();

let (tx, rx) = mpsc::channel::<Result<T::ResponseBody, Error>>(100);
let (tx, rx) = mpsc::channel::<Result<T::Output, Error>>(100);
tokio::spawn(async move {
while let Some(item) = stream.next().await {
match item {
Ok(bytes) => {
let events = parse_stream_event::<T::ResponseBody>(bytes.as_ref());
for event in events {
tx.send(event).await.unwrap();
let output = event.map(|b| T::body_to_output(b));
tx.send(output).await.unwrap();
}
}
Err(e) => {
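
The reworked `stream_output_of` above moves the body-to-output conversion into the spawned producer task, so the receiver half of the channel only ever sees finished `T::Output` values. A standalone sketch of that channel pattern (simplified stand-in types; assumes tokio with the `rt`, `macros`, and `sync` features):

use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Result<String, String>>(100);

    tokio::spawn(async move {
        for raw in ["stream_chunk", "stream_summary"] {
            // Mirrors `event.map(|b| T::body_to_output(b))` from the diff:
            // each raw event is converted before it is sent down the channel.
            let output: Result<String, String> = Ok(raw.to_uppercase());
            tx.send(output).await.unwrap();
        }
        // `tx` is dropped here, which closes the channel and ends the loop below.
    });

    while let Some(item) = rx.recv().await {
        match item {
            Ok(output) => println!("{output}"),
            Err(e) => eprintln!("stream error: {e}"),
        }
    }
}
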
24 changes: 20 additions & 4 deletions src/lib.rs
@@ -37,6 +37,7 @@ use std::time::Duration;

use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use stream::ChatStreamChunk;
use tokenizers::Tokenizer;
use tokio::sync::mpsc;

@@ -53,7 +54,10 @@ pub use self::{
semantic_embedding::{
SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
},
stream::{CompletionSummary, Event, StreamChunk, StreamSummary, TaskStreamCompletion},
stream::{
ChatEvent, CompletionEvent, CompletionSummary, StreamChunk, StreamSummary, TaskStreamChat,
TaskStreamCompletion,
},
tokenization::{TaskTokenization, TokenizationOutput},
};

@@ -197,7 +201,7 @@ impl Client {
/// Stream the response as a series of events.
///
/// ```no_run
/// use aleph_alpha_client::{Client, How, TaskCompletion, Error, Event};
/// use aleph_alpha_client::{Client, How, TaskCompletion, Error, CompletionEvent};
/// async fn print_stream_completion() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
/// let client = Client::with_authentication("AA_API_TOKEN")?;
@@ -213,7 +217,7 @@
/// // Retrieve stream from API
/// let mut response = client.stream_completion(&task, model, &How::default()).await?;
/// while let Some(Ok(event)) = response.recv().await {
/// if let Event::StreamChunk(chunk) = event {
/// if let CompletionEvent::StreamChunk(chunk) = event {
/// println!("{}", chunk.completion);
/// }
/// }
@@ -225,7 +229,7 @@
task: &TaskStreamCompletion<'_>,
model: &str,
how: &How,
) -> Result<mpsc::Receiver<Result<Event, Error>>, Error> {
) -> Result<mpsc::Receiver<Result<CompletionEvent, Error>>, Error> {
self.http_client
.stream_output_of(&task.with_model(model), how)
.await
@@ -265,6 +269,18 @@ impl Client {
.await
}

/// Send a chat message to a model. Stream the response as a series of events.
pub async fn stream_chat(
&self,
task: &TaskStreamChat<'_>,
model: &str,
how: &How,
) -> Result<mpsc::Receiver<Result<ChatStreamChunk, Error>>, Error> {
self.http_client
.stream_output_of(&task.with_model(model), how)
.await
}

/// Returns an explanation given a prompt and a target (typically generated
/// by a previous completion request). The explanation describes how individual parts
/// of the prompt influenced the target.
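
For symmetry with the `stream_completion` doc example above, a usage sketch for the new `stream_chat` method. The `TaskChat::with_message` constructor shown here is hypothetical (it is not part of this diff); `with_streaming`, `stream_chat`, and `delta.content` are taken from the commit:

use aleph_alpha_client::{Client, Error, How, TaskChat};

async fn print_stream_chat() -> Result<(), Error> {
    // Authenticate against API. Fetches token.
    let client = Client::with_authentication("AA_API_TOKEN")?;
    let model = "pharia-1-llm-7b-control";

    // Hypothetical constructor; the actual `TaskChat` builder lives outside this diff.
    let task = TaskChat::with_message("user", "Hello!").with_streaming();

    // Retrieve stream from API.
    let mut response = client.stream_chat(&task, model, &How::default()).await?;
    while let Some(Ok(chunk)) = response.recv().await {
        println!("{}", chunk.delta.content);
    }
    Ok(())
}
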
110 changes: 101 additions & 9 deletions src/stream.rs
@@ -1,4 +1,6 @@
use crate::{completion::BodyCompletion, http::Task, Error, TaskCompletion};
use crate::{
chat::ChatBody, completion::BodyCompletion, http::Task, Error, TaskChat, TaskCompletion,
};
use serde::Deserialize;

/// Describes a chunk of a completion stream
@@ -40,20 +42,21 @@ pub struct CompletionSummary {
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Event {
pub enum CompletionEvent {
StreamChunk(StreamChunk),
StreamSummary(StreamSummary),
CompletionSummary(CompletionSummary),
}

/// Wrap a completion task and support streaming.
pub struct TaskStreamCompletion<'a> {
pub task: TaskCompletion<'a>,
}

impl Task for TaskStreamCompletion<'_> {
type Output = Event;
type Output = CompletionEvent;

type ResponseBody = Event;
type ResponseBody = CompletionEvent;

fn build_request(
&self,
@@ -70,6 +73,65 @@ impl Task for TaskStreamCompletion<'_> {
}
}

#[derive(Deserialize, Debug)]
pub struct Message {
/// The role of the message. Will be `assistant` for the first chunk of every
/// completion stream and missing for the remaining chunks.
pub role: Option<String>,
/// The content of the current chat completion. Will be empty for the first chunk of every
/// completion stream and non-empty for the remaining chunks.
pub content: String,
}

/// One chunk of a chat completion stream.
#[derive(Deserialize, Debug)]
pub struct ChatStreamChunk {
/// The reason the model stopped generating tokens.
/// The value is only set in the last chunk of a completion and null otherwise.
pub finish_reason: Option<String>,
/// Chat completion chunk generated by the model when streaming is enabled.
pub delta: Message,
}

/// Event received from a chat completion stream. As the crate does not support multiple
/// chat completions, there will always be exactly one choice item.
#[derive(Deserialize, Debug)]
pub struct ChatEvent {
pub choices: Vec<ChatStreamChunk>,
}

/// Wrap a chat task and support streaming.
pub struct TaskStreamChat<'a> {
pub task: TaskChat<'a>,
}

impl<'a> Task for TaskStreamChat<'a> {
type Output = ChatStreamChunk;

type ResponseBody = ChatEvent;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = ChatBody::new(model, &self.task).with_streaming();
client.post(format!("{base}/chat/completions")).json(&body)
}

fn body_to_output(mut response: Self::ResponseBody) -> Self::Output {
// We always expect there to be exactly one choice, as the `n` parameter is not
// supported by this crate.
response
.choices
.pop()
.expect("There must always be at least one choice")
}
}

/// Take a byte slice (of an SSE) and parse it into a provided response body.
/// Each SSE event is expected to contain one or multiple JSON bodies prefixed by `data: `.
pub fn parse_stream_event<ResponseBody>(bytes: &[u8]) -> Vec<Result<ResponseBody, Error>>
where
ResponseBody: for<'de> Deserialize<'de>,
@@ -96,12 +158,12 @@ mod tests {
let bytes = b"data: {\"type\":\"stream_chunk\",\"index\":0,\"completion\":\" The New York Times, May 15\"}\n\n";

// When they are parsed
let events = parse_stream_event::<Event>(bytes);
let events = parse_stream_event::<CompletionEvent>(bytes);
let event = events.first().unwrap().as_ref().unwrap();

// Then the event is a stream chunk
match event {
Event::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
CompletionEvent::StreamChunk(chunk) => assert_eq!(chunk.index, 0),
_ => panic!("Expected a stream chunk"),
}
}
@@ -112,18 +174,48 @@
let bytes = b"data: {\"type\":\"stream_summary\",\"index\":0,\"model_version\":\"2022-04\",\"finish_reason\":\"maximum_tokens\"}\n\ndata: {\"type\":\"completion_summary\",\"num_tokens_prompt_total\":1,\"num_tokens_generated\":7}\n\n";

// When they are parsed
let events = parse_stream_event::<Event>(bytes);
let events = parse_stream_event::<CompletionEvent>(bytes);

// Then the first event is a stream summary and the last event is a completion summary
let first = events.first().unwrap().as_ref().unwrap();
match first {
Event::StreamSummary(summary) => assert_eq!(summary.finish_reason, "maximum_tokens"),
CompletionEvent::StreamSummary(summary) => {
assert_eq!(summary.finish_reason, "maximum_tokens")
}
_ => panic!("Expected a completion summary"),
}
let second = events.last().unwrap().as_ref().unwrap();
match second {
Event::CompletionSummary(summary) => assert_eq!(summary.num_tokens_generated, 7),
CompletionEvent::CompletionSummary(summary) => {
assert_eq!(summary.num_tokens_generated, 7)
}
_ => panic!("Expected a completion summary"),
}
}

#[test]
fn chat_stream_chunk_event_is_parsed() {
// Given some bytes
let bytes = b"data: {\"id\":\"831e41b4-2382-4b08-990e-0a3859967f43\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":\"\"},\"logprobs\":null}],\"created\":1729782822,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

// When they are parsed
let events = parse_stream_event::<ChatEvent>(bytes);
let event = events.first().unwrap().as_ref().unwrap();

// Then the event is a chat stream chunk
assert_eq!(event.choices[0].delta.role.as_ref().unwrap(), "assistant");
}

#[test]
fn chat_stream_chunk_without_role_is_parsed() {
// Given some bytes without a role
let bytes = b"data: {\"id\":\"a3ceca7f-32b2-4a6c-89e7-bc8eb5327f76\",\"choices\":[{\"finish_reason\":null,\"index\":0,\"delta\":{\"content\":\"Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.\"},\"logprobs\":null}],\"created\":1729784197,\"model\":\"pharia-1-llm-7b-control\",\"system_fingerprint\":null,\"object\":\"chat.completion.chunk\",\"usage\":null}\n\n";

// When they are parsed
let events = parse_stream_event::<ChatEvent>(bytes);
let event = events.first().unwrap().as_ref().unwrap();

// Then the event is a chat stream chunk
assert_eq!(event.choices[0].delta.content, "Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.");
}
}
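
As the doc comment on `parse_stream_event` states, each SSE event carries one or more JSON bodies prefixed by `data: ` and separated by blank lines. A standalone sketch of that framing (simplified: it drops malformed payloads instead of surfacing them as `Result`s the way the crate does):

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Chunk {
    completion: String,
}

/// Split an SSE byte buffer on blank lines and parse each `data: ` payload.
fn parse(bytes: &[u8]) -> Vec<Chunk> {
    std::str::from_utf8(bytes)
        .unwrap_or_default()
        .split("\n\n")
        .filter_map(|event| event.strip_prefix("data: "))
        .filter_map(|json| serde_json::from_str(json).ok())
        .collect()
}

fn main() {
    let bytes = b"data: {\"completion\":\"Hello\"}\n\ndata: {\"completion\":\", world\"}\n\n";
    let chunks = parse(bytes);
    assert_eq!(chunks.len(), 2);
    assert_eq!(chunks[0].completion, "Hello");
}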