diff --git a/src/chat.rs b/src/chat.rs
index 4f4af8e..b4a29fd 100644
--- a/src/chat.rs
+++ b/src/chat.rs
@@ -4,20 +4,30 @@ use serde::{Deserialize, Serialize};
 
 use crate::Task;
 
-#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
-#[serde(rename_all = "snake_case")]
-pub enum Role {
-    System,
-    User,
-    Assistant,
-}
-
 #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
 pub struct Message<'a> {
-    pub role: Role,
+    pub role: Cow<'a, str>,
     pub content: Cow<'a, str>,
 }
 
+impl<'a> Message<'a> {
+    pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
+        Self {
+            role: role.into(),
+            content: content.into(),
+        }
+    }
+    pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
+        Self::new("user", content)
+    }
+    pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
+        Self::new("assistant", content)
+    }
+    pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
+        Self::new("system", content)
+    }
+}
+
 pub struct TaskChat<'a> {
     /// The list of messages comprising the conversation so far.
     pub messages: Vec<Message<'a>>,
@@ -42,12 +52,20 @@ impl<'a> TaskChat<'a> {
 
     /// Creates a new TaskChat containing one message with the given role and content.
     /// All optional TaskChat attributes are left unset.
-    pub fn new(role: Role, content: impl Into<Cow<'a, str>>) -> Self {
+    pub fn with_message(message: Message<'a>) -> Self {
+        TaskChat {
+            messages: vec![message],
+            maximum_tokens: None,
+            temperature: None,
+            top_p: None,
+        }
+    }
+
+    /// Creates a new TaskChat containing the given messages.
+    /// All optional TaskChat attributes are left unset.
+    pub fn with_messages(messages: Vec<Message<'a>>) -> Self {
         TaskChat {
-            messages: vec![Message {
-                role,
-                content: content.into(),
-            }],
+            messages,
             maximum_tokens: None,
             temperature: None,
             top_p: None,
@@ -55,11 +73,8 @@ impl<'a> TaskChat<'a> {
     }
 
     /// Pushes a new Message to this TaskChat.
-    pub fn append_message(mut self, role: Role, content: impl Into<Cow<'a, str>>) -> Self {
-        self.messages.push(Message {
-            role,
-            content: content.into(),
-        });
+    pub fn push_message(mut self, message: Message<'a>) -> Self {
+        self.messages.push(message);
         self
     }
 
diff --git a/src/lib.rs b/src/lib.rs
index 57575fe..edc0283 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -39,7 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
 use tokenizers::Tokenizer;
 
 pub use self::{
-    chat::{ChatOutput, Message, Role, TaskChat},
+    chat::{ChatOutput, Message, TaskChat},
     completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
     detokenization::{DetokenizationOutput, TaskDetokenization},
     explanation::{
@@ -192,7 +192,7 @@ impl Client {
     /// Send a chat message to a model.
     ///
     /// ```no_run
-    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Role};
+    /// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
     ///
     /// async fn chat() -> Result<(), Error> {
     ///     // Authenticate against API. Fetches token.
@@ -202,7 +202,8 @@ impl Client {
     ///     let model = "pharia-1-llm-7b-control";
     ///
     ///     // Create a chat task with a user message.
-    ///     let task = TaskChat::new(Role::User, "Hello, how are you?");
+    ///     let message = Message::user("Hello, how are you?");
+    ///     let task = TaskChat::with_message(message);
     ///
     ///     // Send the message to the model.
     ///     let response = client.chat(&task, model, &How::default()).await?;
@@ -212,7 +213,12 @@ impl Client {
     ///     Ok(())
     /// }
     /// ```
-    pub async fn chat(&self, task: &TaskChat<'_>, model: &str, how: &How) -> Result<ChatOutput, Error> {
+    pub async fn chat<'a>(
+        &'a self,
+        task: &'a TaskChat<'a>,
+        model: &'a str,
+        how: &'a How,
+    ) -> Result<ChatOutput, Error> {
         self.http_client
             .output_of(&task.with_model(model), how)
             .await
diff --git a/tests/integration.rs b/tests/integration.rs
index 7743c89..bb17fed 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -1,8 +1,8 @@
 use std::{fs::File, io::BufReader, sync::OnceLock};
 
 use aleph_alpha_client::{
-    cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
-    PromptGranularity, Role, Sampling, SemanticRepresentation, Stopping, Task,
+    cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Message, Modality,
+    Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
     TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
     TaskSemanticEmbedding, TaskTokenization, TextScore,
 };
@@ -21,7 +21,8 @@ fn api_token() -> &'static str {
 #[tokio::test]
 async fn chat_with_pharia_1_7b_base() {
     // When
-    let task = TaskChat::new(Role::System, "Instructions").append_message(Role::User, "Question");
+    let message = Message::user("Question");
+    let task = TaskChat::with_message(message);
    let model = "pharia-1-llm-7b-control";
    let client = Client::with_authentication(api_token()).unwrap();
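
Note for reviewers, not part of the patch: a minimal sketch of how call sites migrate from the removed Role enum to the Message constructors introduced in src/chat.rs. The helper build_conversation is hypothetical; the "system" and "user" role strings are the ones baked into Message::system and Message::user above.

use aleph_alpha_client::{Message, TaskChat};

// Hypothetical helper, shown only to illustrate the migration.
fn build_conversation() -> TaskChat<'static> {
    // Before: TaskChat::new(Role::System, "Instructions")
    //             .append_message(Role::User, "Question")
    // After: build the messages explicitly and hand them over in one go.
    TaskChat::with_messages(vec![
        Message::system("Instructions"),
        Message::user("Question"),
    ])
}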