Skip to content

Commit

Permalink
add batch semantic embedding task
Browse files Browse the repository at this point in the history
  • Loading branch information
ahartel committed Aug 4, 2023
1 parent 9b597c7 commit 5c2150d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 6 deletions.
15 changes: 13 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ mod semantic_embedding;
use std::time::Duration;

use http::HttpClient;
use semantic_embedding::SemanticEmbeddingOutput;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};

pub use self::{
completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
Expand All @@ -43,7 +43,9 @@ pub use self::{
},
http::{Error, Job, Task},
prompt::{Modality, Prompt},
semantic_embedding::{SemanticRepresentation, TaskSemanticEmbedding},
semantic_embedding::{
SemanticRepresentation, TaskBatchSemanticEmbedding, TaskSemanticEmbedding,
},
};

/// Execute Jobs against the Aleph Alpha API
Expand Down Expand Up @@ -122,6 +124,15 @@ impl Client {
self.http_client.output_of(task, how).await
}

/// A batch of embeddings trying to capture the semantic meaning of several texts at once.
///
/// Sends all prompts contained in `task` to the Aleph Alpha API in a single request and
/// returns one [`BatchSemanticEmbeddingOutput`] holding the embeddings for the whole batch.
/// Delegates the actual request/response handling to the underlying HTTP client.
///
/// # Errors
///
/// Returns [`Error`] if the request to the API fails (network, authentication, or a
/// non-success response).
pub async fn batch_semantic_embedding(
&self,
task: &TaskBatchSemanticEmbedding<'_>,
how: &How,
) -> Result<BatchSemanticEmbeddingOutput, Error> {
self.http_client.output_of(task, how).await
}

/// Instruct a model served by the aleph alpha API to continue writing a piece of text (or
/// multimodal document).
///
Expand Down
53 changes: 51 additions & 2 deletions src/semantic_embedding.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
use std::fmt::Debug;

use crate::{http::Task, Job, Prompt};

Expand Down Expand Up @@ -42,13 +43,14 @@ pub struct TaskSemanticEmbedding<'a> {
}

/// Appends model and hosting to the bare task
/// T stands for TaskSemanticEmbedding or TaskBatchSemanticEmbedding
#[derive(Serialize, Debug)]
struct RequestBody<'a> {
struct RequestBody<'a, T: Serialize + Debug> {
/// Currently semantic embedding still requires a model parameter, even though "luminous-base"
/// is the only model to support it. This makes Semantic embedding both a Service and a Method.
model: &'a str,
#[serde(flatten)]
semantic_embedding_task: &'a TaskSemanticEmbedding<'a>,
semantic_embedding_task: &'a T,
}

/// Heap allocated embedding. Can hold full embeddings or compressed ones
Expand Down Expand Up @@ -96,3 +98,50 @@ impl Job for TaskSemanticEmbedding<'_> {
response
}
}

/// Create embeddings for multiple prompts in a single request.
///
/// Unlike [`TaskSemanticEmbedding`], which embeds a single prompt, this task submits a
/// whole batch of prompts at once and the response carries one embedding per prompt.
#[derive(Serialize, Debug)]
pub struct TaskBatchSemanticEmbedding<'a> {
/// The prompts (usually text) to be embedded.
pub prompts: Vec<Prompt<'a>>,
/// Semantic representation to embed the prompts with. This parameter is governed by the
/// specific use case in mind.
pub representation: SemanticRepresentation,
/// Default behaviour is to return the full embedding, but you can optionally request an
/// embedding compressed to a smaller set of dimensions. A size of `128` is supported for every
/// model.
///
/// The 128 size is expected to have a small drop in accuracy performance (4-6%), with the
/// benefit of being much smaller, which makes comparing these embeddings much faster for use
/// cases where speed is critical.
///
/// The 128 size can also perform better if you are embedding short texts or documents.
// Omitted from the request body entirely when `None`, so the API applies its default.
#[serde(skip_serializing_if = "Option::is_none")]
pub compress_to_size: Option<u32>,
}

/// Heap allocated vec of embeddings. Can hold full embeddings or compressed ones.
///
/// Deserialized response body of [`TaskBatchSemanticEmbedding`].
#[derive(Deserialize)]
pub struct BatchSemanticEmbeddingOutput {
// One inner vector per prompt submitted in the batch.
pub embeddings: Vec<Vec<f32>>,
}

impl Job for TaskBatchSemanticEmbedding<'_> {
    type Output = BatchSemanticEmbeddingOutput;
    type ResponseBody = BatchSemanticEmbeddingOutput;

    /// Builds the POST request for the batch embedding endpoint, wrapping the task in a
    /// [`RequestBody`] so the required `model` field is included.
    fn build_request(&self, client: &reqwest::Client, base: &str) -> reqwest::RequestBuilder {
        // "luminous-base" is the only model currently accepted for semantic embeddings,
        // so it is hard-wired into the request envelope here.
        let body = RequestBody {
            model: "luminous-base",
            semantic_embedding_task: self,
        };
        let endpoint = format!("{base}/batch_semantic_embed");
        client.post(endpoint).json(&body)
    }

    /// The response body already is the output type, so this is the identity.
    fn body_to_output(&self, response: Self::ResponseBody) -> Self::Output {
        response
    }
}
39 changes: 37 additions & 2 deletions tests/integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use std::{fs::File, io::BufReader};

use aleph_alpha_client::{
cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task, TaskCompletion,
TaskExplanation, TaskSemanticEmbedding, TextScore,
PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
TaskBatchSemanticEmbedding, TaskCompletion, TaskExplanation, TaskSemanticEmbedding, TextScore,
};
use dotenv::dotenv;
use image::ImageFormat;
Expand Down Expand Up @@ -341,3 +341,38 @@ async fn answer_should_continue() {
assert!(response.completion.starts_with(" Says."));
assert!(response.completion.len() > " Says.".len());
}

// Renamed from `batch_semanitc_embed_with_luminous_base` — fixed "semanitc" typo; test
// functions have no callers, so the rename is safe.
//
// Verifies that the batch embedding endpoint returns exactly one embedding per prompt.
#[tokio::test]
async fn batch_semantic_embed_with_luminous_base() {
    // Given: two prompts with distinct semantic content.
    let robot_fact = Prompt::from_text(
        "A robot is a machine—especially one programmable by a computer—capable of carrying out a \
        complex series of actions automatically.",
    );
    let pizza_fact = Prompt::from_text(
        "Pizza (Italian: [ˈpittsa], Neapolitan: [ˈpittsə]) is a dish of Italian origin consisting \
        of a usually round, flat base of leavened wheat-based dough topped with tomatoes, cheese, \
        and often various other ingredients (such as various types of sausage, anchovies, \
        mushrooms, onions, olives, vegetables, meat, ham, etc.), which is then baked at a high \
        temperature, traditionally in a wood-fired oven.",
    );

    let client = Client::new(&AA_API_TOKEN).unwrap();

    // When: both prompts are embedded with a single batch request.
    let embedding_task = TaskBatchSemanticEmbedding {
        prompts: vec![robot_fact, pizza_fact],
        representation: SemanticRepresentation::Document,
        compress_to_size: Some(128),
    };

    let embeddings = client
        .batch_semantic_embedding(&embedding_task, &How::default())
        .await
        .unwrap()
        .embeddings;

    // Then: the API must return one embedding per submitted prompt.
    assert_eq!(embeddings.len(), 2);
}

0 comments on commit 5c2150d

Please sign in to comment.