From 5c2150d825d2af42b1f19a2689f767eb87cdf65a Mon Sep 17 00:00:00 2001 From: Andreas Hartel Date: Fri, 4 Aug 2023 15:41:07 +0200 Subject: [PATCH] add batch semantic embedding task --- src/lib.rs | 15 +++++++++-- src/semantic_embedding.rs | 53 +++++++++++++++++++++++++++++++++++++-- tests/integration.rs | 39 ++++++++++++++++++++++++++-- 3 files changed, 101 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8135c95..c42927a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,7 +33,7 @@ mod semantic_embedding; use std::time::Duration; use http::HttpClient; -use semantic_embedding::SemanticEmbeddingOutput; +use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput}; pub use self::{ completion::{CompletionOutput, Sampling, Stopping, TaskCompletion}, @@ -43,7 +43,9 @@ pub use self::{ }, http::{Error, Job, Task}, prompt::{Modality, Prompt}, - semantic_embedding::{SemanticRepresentation, TaskSemanticEmbedding}, + semantic_embedding::{ + SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding, + }, }; /// Execute Jobs against the Aleph Alpha API @@ -122,6 +124,15 @@ impl Client { self.http_client.output_of(task, how).await } + /// A batch of embeddings trying to capture the semantic meaning of a text. + pub async fn batch_semantic_embedding( + &self, + task: &TaskBatchSemanticEmbedding<'_>, + how: &How, + ) -> Result { + self.http_client.output_of(task, how).await + } + /// Instruct a model served by the aleph alpha API to continue writing a piece of text (or /// multimodal document). 
/// diff --git a/src/semantic_embedding.rs b/src/semantic_embedding.rs index 5b7739f..c64501f 100644 --- a/src/semantic_embedding.rs +++ b/src/semantic_embedding.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use crate::{http::Task, Job, Prompt}; @@ -42,13 +43,14 @@ pub struct TaskSemanticEmbedding<'a> { } /// Appends model and hosting to the bare task +/// T stands for TaskSemanticEmbedding or TaskBatchSemanticEmbedding #[derive(Serialize, Debug)] -struct RequestBody<'a> { +struct RequestBody<'a, T: Serialize + Debug> { /// Currently semantic embedding still requires a model parameter, even though "luminous-base" /// is the only model to support it. This makes Semantic embedding both a Service and a Method. model: &'a str, #[serde(flatten)] - semantic_embedding_task: &'a TaskSemanticEmbedding<'a>, + semantic_embedding_task: &'a T, } /// Heap allocated embedding. Can hold full embeddings or compressed ones @@ -96,3 +98,50 @@ impl Job for TaskSemanticEmbedding<'_> { response } } + +/// Create embeddings for multiple prompts +#[derive(Serialize, Debug)] +pub struct TaskBatchSemanticEmbedding<'a> { + /// The prompts (usually text) to be embedded. + pub prompts: Vec>, + /// Semantic representation to embed the prompts with. This parameter is governed by the specific + /// use case in mind. + pub representation: SemanticRepresentation, + /// Default behaviour is to return the full embedding, but you can optionally request an + /// embedding compressed to a smaller set of dimensions. A size of `128` is supported for every + /// model. + /// + /// The 128 size is expected to have a small drop in accuracy performance (4-6%), with the + /// benefit of being much smaller, which makes comparing these embeddings much faster for use + /// cases where speed is critical. + /// + /// The 128 size can also perform better if you are embedding short texts or documents. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub compress_to_size: Option, +} + +/// Heap allocated vec of embeddings. Can hold full embeddings or compressed ones +#[derive(Deserialize)] +pub struct BatchSemanticEmbeddingOutput { + pub embeddings: Vec>, +} + +impl Job for TaskBatchSemanticEmbedding<'_> { + type Output = BatchSemanticEmbeddingOutput; + type ResponseBody = BatchSemanticEmbeddingOutput; + + fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder { + let model = "luminous-base"; + let body = RequestBody { + model, + semantic_embedding_task: self, + }; + client + .post(format!("{base}/batch_semantic_embed")) + .json(&body) + } + + fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output { + response + } +} diff --git a/tests/integration.rs b/tests/integration.rs index 228222b..9972bec 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -2,8 +2,8 @@ use std::{fs::File, io::BufReader}; use aleph_alpha_client::{ cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt, - PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task, TaskCompletion, - TaskExplanation, TaskSemanticEmbedding, TextScore, + PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task, + TaskBatchSemanticEmbedding, TaskCompletion, TaskExplanation, TaskSemanticEmbedding, TextScore, }; use dotenv::dotenv; use image::ImageFormat; @@ -341,3 +341,38 @@ async fn answer_should_continue() { assert!(response.completion.starts_with(" Says.")); assert!(response.completion.len() > " Says.".len()); } + +#[tokio::test] +async fn batch_semanitc_embed_with_luminous_base() { + // Given + let robot_fact = Prompt::from_text( + "A robot is a machine—especially one programmable by a computer—capable of carrying out a \ + complex series of actions automatically.", + ); + let pizza_fact = Prompt::from_text( + "Pizza (Italian: [ˈpittsa], Neapolitan: [ˈpittsə]) is a dish of Italian origin 
consisting \ + of a usually round, flat base of leavened wheat-based dough topped with tomatoes, cheese, \ + and often various other ingredients (such as various types of sausage, anchovies, \ + mushrooms, onions, olives, vegetables, meat, ham, etc.), which is then baked at a high \ + temperature, traditionally in a wood-fired oven.", + ); + + let client = Client::new(&AA_API_TOKEN).unwrap(); + + // When + let embedding_task = TaskBatchSemanticEmbedding { + prompts: vec![robot_fact, pizza_fact], + representation: SemanticRepresentation::Document, + compress_to_size: Some(128), + }; + + let embeddings = client + .batch_semantic_embedding(&embedding_task, &How::default()) + .await + .unwrap() + .embeddings; + + // Then + // There should be 2 embeddings + assert_eq!(embeddings.len(), 2); +}