Skip to content

Commit

Permalink
feat!: do not hide role behind enum
Browse files Browse the repository at this point in the history
  • Loading branch information
moldhouse committed Oct 24, 2024
1 parent 0408f6a commit a869887
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 26 deletions.
42 changes: 23 additions & 19 deletions src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,30 @@ use serde::{Deserialize, Serialize};

use crate::Task;

/// A single message in a chat conversation.
///
/// NOTE(review): this span contained diff residue — the removed `Role` enum and
/// the old `role: Role` field were still present next to their replacements,
/// which is invalid Rust (duplicate `role` field). Resolved to the post-image:
/// the role is a plain string so callers are not restricted to a closed enum.
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    /// Role of the sender, e.g. "system", "user" or "assistant".
    /// `Cow` lets callers pass either a `&'a str` or an owned `String`.
    pub role: Cow<'a, str>,
    /// Text content of the message.
    pub content: Cow<'a, str>,
}

impl<'a> Message<'a> {
pub fn new(role: impl Into<Cow<'a, str>>, content: impl Into<Cow<'a, str>>) -> Self {
Self {
role: role.into(),
content: content.into(),
}
}
pub fn user(content: impl Into<Cow<'a, str>>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<Cow<'a, str>>) -> Self {
Self::new("assistant", content)
}
pub fn system(content: impl Into<Cow<'a, str>>) -> Self {
Self::new("system", content)
}
}

pub struct TaskChat<'a> {
/// The list of messages comprising the conversation so far.
pub messages: Vec<Message<'a>>,
Expand All @@ -42,24 +52,18 @@ pub struct TaskChat<'a> {
impl<'a> TaskChat<'a> {
/// Creates a new TaskChat containing one message with the given role and content.
/// All optional TaskChat attributes are left unset.
pub fn new(role: Role, content: impl Into<Cow<'a, str>>) -> Self {
pub fn with_message(message: Message<'a>) -> Self {
TaskChat {
messages: vec![Message {
role,
content: content.into(),
}],
messages: vec![message],
maximum_tokens: None,
temperature: None,
top_p: None,
}
}

/// Pushes a new Message to this TaskChat.
pub fn append_message(mut self, role: Role, content: impl Into<Cow<'a, str>>) -> Self {
self.messages.push(Message {
role,
content: content.into(),
});
pub fn push_message(mut self, message: Message<'a>) -> Self {
self.messages.push(message);
self
}

Expand Down
14 changes: 10 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use tokenizers::Tokenizer;

pub use self::{
chat::{ChatOutput, Message, Role, TaskChat},
chat::{ChatOutput, Message, TaskChat},
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
detokenization::{DetokenizationOutput, TaskDetokenization},
explanation::{
Expand Down Expand Up @@ -192,7 +192,7 @@ impl Client {

/// Send a chat message to a model.
/// ```no_run
/// use aleph_alpha_client::{Client, How, TaskChat, Error, Role};
/// use aleph_alpha_client::{Client, How, TaskChat, Error, Message};
///
/// async fn chat() -> Result<(), Error> {
/// // Authenticate against API. Fetches token.
Expand All @@ -202,7 +202,8 @@ impl Client {
/// let model = "pharia-1-llm-7b-control";
///
/// // Create a chat task with a user message.
/// let task = TaskChat::new(Role::User, "Hello, how are you?");
/// let message = Message::user("Hello, how are you?");
/// let task = TaskChat::with_message(message);
///
/// // Send the message to the model.
/// let response = client.chat(&task, model, &How::default()).await?;
Expand All @@ -212,7 +213,12 @@ impl Client {
/// Ok(())
/// }
/// ```
pub async fn chat(&self, task: &TaskChat<'_>, model: &str, how: &How) -> Result<ChatOutput, Error> {
pub async fn chat<'a>(
&'a self,
task: &'a TaskChat<'a>,
model: &'a str,
how: &'a How,
) -> Result<ChatOutput, Error> {
self.http_client
.output_of(&task.with_model(model), how)
.await
Expand Down
7 changes: 4 additions & 3 deletions tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::{fs::File, io::BufReader, sync::OnceLock};

// NOTE(review): this span contained diff residue — pre-image import lines
// (still listing the removed `Role`) sat next to their post-image
// replacements, duplicating several names. Resolved to the post-image.
use aleph_alpha_client::{
    cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Message, Modality,
    Prompt, PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
    TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
    TaskSemanticEmbedding, TaskTokenization, TextScore,
};
Expand All @@ -21,7 +21,8 @@ fn api_token() -> &'static str {
#[tokio::test]
async fn chat_with_pharia_1_7b_base() {
// When
let task = TaskChat::new(Role::System, "Instructions").append_message(Role::User, "Question");
let message = Message::user("Question");
let task = TaskChat::with_message(message);

let model = "pharia-1-llm-7b-control";
let client = Client::with_authentication(api_token()).unwrap();
Expand Down

0 comments on commit a869887

Please sign in to comment.