feat: add chat feature #26

Merged · 1 commit · Oct 22, 2024
145 changes: 145 additions & 0 deletions src/chat.rs
@@ -0,0 +1,145 @@
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::Task;

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Role {
System,
User,
Assistant,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
pub role: Role,
pub content: Cow<'a, str>,
}

pub struct TaskChat<'a> {
/// The list of messages comprising the conversation so far.
pub messages: Vec<Message<'a>>,
/// The maximum number of tokens to be generated. Completion will terminate after the maximum
/// number of tokens is reached. Increase this value to allow for longer outputs. A text is
/// split into tokens; usually there are more tokens than words. The total number of tokens of
/// prompt and maximum_tokens depends on the model.
/// If maximum_tokens is set to None, no outside limit is imposed on the number of generated
/// tokens. The model will generate tokens until it produces one of the specified stop_sequences
/// or it reaches its technical limit, which usually is its context window.
pub maximum_tokens: Option<u32>,
/// A higher temperature encourages the model to produce less probable outputs ("be more
/// creative"). Values are expected to be between 0 and 1. Try high values for a more random
/// ("creative") response.
pub temperature: Option<f64>,
/// Introduces random sampling for generated tokens by randomly selecting the next token from
/// the smallest possible set of tokens whose cumulative probability exceeds the probability
/// top_p. Set to 0 to get the same behaviour as `None`.
pub top_p: Option<f64>,
}

impl<'a> TaskChat<'a> {
/// Creates a new TaskChat containing one message with the given role and content.
/// All optional TaskChat attributes are left unset.
pub fn new(role: Role, content: impl Into<Cow<'a, str>>) -> Self {
TaskChat {
messages: vec![Message {
role,
content: content.into(),
}],
maximum_tokens: None,
temperature: None,
top_p: None,
}
}

/// Pushes a new Message to this TaskChat.
pub fn append_message(mut self, role: Role, content: impl Into<Cow<'a, str>>) -> Self {
self.messages.push(Message {
role,
content: content.into(),
});
self
}

/// Sets the maximum_tokens attribute of this TaskChat.
pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
self.maximum_tokens = Some(maximum_tokens);
self
}

/// Sets the temperature attribute of this TaskChat.
pub fn with_temperature(mut self, temperature: f64) -> Self {
self.temperature = Some(temperature);
self
}

/// Sets the top_p attribute of this TaskChat.
pub fn with_top_p(mut self, top_p: f64) -> Self {
self.top_p = Some(top_p);
self
}
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Choice<'a> {
pub message: Message<'a>,
pub finish_reason: String,
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ChatResponse<'a> {
pub choices: Vec<Choice<'a>>,
}

#[derive(Serialize)]
struct ChatBody<'a> {
/// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
pub model: &'a str,
/// The list of messages comprising the conversation so far.
messages: &'a [Message<'a>],
/// Limits the number of tokens generated for the completion.
#[serde(skip_serializing_if = "Option::is_none")]
pub maximum_tokens: Option<u32>,
/// Controls the randomness of the model. Lower values will make the model more deterministic and higher values will make it more random.
/// Mathematically, the temperature is used to divide the logits before sampling. A temperature of 0 will always return the most likely token.
/// When no value is provided, the default value of 1 will be used.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// "nucleus" parameter to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. It specifies a probability threshold, below which all less likely tokens are filtered out.
/// When no value is provided, the default value of 1 will be used.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
}

impl<'a> ChatBody<'a> {
pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
Self {
model,
messages: &task.messages,
maximum_tokens: task.maximum_tokens,
temperature: task.temperature,
top_p: task.top_p,
}
}
}

impl<'a> Task for TaskChat<'a> {
type Output = Choice<'a>;

type ResponseBody = ChatResponse<'a>;

fn build_request(
&self,
client: &reqwest::Client,
base: &str,
model: &str,
) -> reqwest::RequestBuilder {
let body = ChatBody::new(model, self);
client.post(format!("{base}/chat/completions")).json(&body)
}

fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
response.choices.pop().unwrap()
}
}
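
For orientation, here is a minimal sketch of how the new chat task could be driven from application code. It is not part of this diff: `Client::with_authentication`, `Task::with_model`, `How::default`, and the model name `pharia-1-llm-7b-control` are taken from the integration test added below, while the API token, message texts, and sampling values are placeholders.

```rust
use aleph_alpha_client::{Client, How, Role, Task, TaskChat};

#[tokio::main]
async fn main() {
    // Placeholder token; a real Aleph Alpha API token is required.
    let client = Client::with_authentication("AA_API_TOKEN").unwrap();

    // Build the conversation with the builder methods introduced in src/chat.rs.
    let task = TaskChat::new(Role::System, "You are a helpful assistant.")
        .append_message(Role::User, "What is the capital of France?")
        .with_maximum_tokens(64)
        .with_temperature(0.7);

    // Same model as the integration test added in this PR.
    let choice = client
        .output_of(&task.with_model("pharia-1-llm-7b-control"), &How::default())
        .await
        .unwrap();

    // `choice.message.content` holds the assistant reply.
    println!("{}", choice.message.content);
}
```

Because the optional fields of `ChatBody` carry `#[serde(skip_serializing_if = "Option::is_none")]`, any parameter left unset by the builder is simply omitted from the JSON posted to `{base}/chat/completions`.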
9 changes: 7 additions & 2 deletions src/lib.rs
@@ -23,6 +23,7 @@
//! }
//! ```

mod chat;
mod completion;
mod detokenization;
mod explanation;
@@ -31,14 +32,14 @@ mod image_preprocessing;
mod prompt;
mod semantic_embedding;
mod tokenization;

use std::time::Duration;

use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use tokenizers::Tokenizer;

pub use self::{
chat::{Role, TaskChat},
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
detokenization::{DetokenizationOutput, TaskDetokenization},
explanation::{
@@ -305,7 +306,11 @@ impl Client {
.await
}

pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> {
pub async fn tokenizer_by_model(
&self,
model: &str,
api_token: Option<String>,
) -> Result<Tokenizer, Error> {
self.http_client.tokenizer_by_model(model, api_token).await
}
}
25 changes: 22 additions & 3 deletions tests/integration.rs
@@ -2,8 +2,8 @@ use std::{fs::File, io::BufReader, sync::OnceLock};

use aleph_alpha_client::{
cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
TaskBatchSemanticEmbedding, TaskCompletion, TaskDetokenization, TaskExplanation,
PromptGranularity, Role, Sampling, SemanticRepresentation, Stopping, Task,
TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
TaskSemanticEmbedding, TaskTokenization, TextScore,
};
use dotenv::dotenv;
@@ -18,6 +18,24 @@
})
}

#[tokio::test]
async fn chat_with_pharia_1_7b_base() {
// When
let task = TaskChat::new(Role::System, "Instructions").append_message(Role::User, "Question");

let model = "pharia-1-llm-7b-control";
let client = Client::with_authentication(api_token()).unwrap();
let response = client
.output_of(&task.with_model(model), &How::default())
.await
.unwrap();

eprintln!("{:?}", response.message);

// Then
assert!(!response.message.content.is_empty())
}

#[tokio::test]
async fn completion_with_luminous_base() {
// When
@@ -538,4 +556,5 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {

// Then
assert_eq!(128_000, tokenizer.get_vocab_size(true));
}
}
