Skip to content

Commit

Permalink
Added batch embed gen async/blocking
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Oct 17, 2023
1 parent e0d2c79 commit 5cbe11e
Showing 1 changed file with 108 additions and 53 deletions.
161 changes: 108 additions & 53 deletions shinkai-libs/shinkai-vector-resources/src/embedding_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ pub trait EmbeddingGenerator: Sync {
self.generate_embedding_blocking(input_string, "")
}

/// Generates embeddings from the given list of input strings and ids.
/// Note: This is a blocking method.
///
/// `input_strings` and `ids` are parallel lists; presumably the id at index
/// `i` is assigned to the embedding produced from the input string at index
/// `i` (TODO: confirm against each implementor).
///
/// # Errors
/// Returns a `VectorResourceError` when the implementor fails to generate
/// the embeddings.
fn generate_embeddings_blocking(
&self,
input_strings: &Vec<String>,
ids: &Vec<String>,
) -> Result<Vec<Embedding>, VectorResourceError>;

/// Convenience wrapper around `generate_embeddings_blocking` that pairs every
/// input string with the default (empty string) id.
/// Note: This is a blocking method.
fn generate_embeddings_blocking_default(
    &self,
    input_strings: &Vec<String>,
) -> Result<Vec<Embedding>, VectorResourceError> {
    // Build exactly one empty id per input so both lists have the same length.
    let default_ids: Vec<String> = input_strings.iter().map(|_| String::new()).collect();
    self.generate_embeddings_blocking(input_strings, &default_ids)
}

/// Generates an embedding from the given input string, and assigns the
/// provided id.
///
/// This is the async counterpart of `generate_embedding_blocking`.
///
/// # Errors
/// Returns a `VectorResourceError` when the implementor fails to generate
/// the embedding.
async fn generate_embedding(&self, input_string: &str, id: &str) -> Result<Embedding, VectorResourceError>;
Expand All @@ -44,22 +62,21 @@ pub trait EmbeddingGenerator: Sync {
self.generate_embedding(input_string, "").await
}

// /// Generates embeddings from the given list of input strings and ids.
// fn generate_embeddings(&self, input_strings: &[&str], ids: &[&str]) -> Result<Vec<Embedding>, VectorResourceError> {
// input_strings
// .iter()
// .zip(ids)
// .map(|(input, id)| self.generate_embedding(input, id))
// .collect()
// }

// /// Generate Embeddings for a list of input strings, sets ids to default
// fn generate_embeddings_default(&self, input_strings: &[&str]) -> Result<Vec<Embedding>, VectorResourceError> {
// input_strings
// .iter()
// .map(|input| self.generate_embedding_default_blocking(input))
// .collect()
// }
/// Generates embeddings from the given list of input strings and ids.
///
/// This is the async counterpart of `generate_embeddings_blocking`.
/// `input_strings` and `ids` are parallel lists; presumably the id at index
/// `i` is assigned to the embedding produced from the input string at index
/// `i` (TODO: confirm against each implementor).
///
/// # Errors
/// Returns a `VectorResourceError` when the implementor fails to generate
/// the embeddings.
async fn generate_embeddings(
&self,
input_strings: &Vec<String>,
ids: &Vec<String>,
) -> Result<Vec<Embedding>, VectorResourceError>;

/// Async convenience wrapper around `generate_embeddings` that pairs every
/// input string with the default (empty string) id.
async fn generate_embeddings_default(
    &self,
    input_strings: &Vec<String>,
) -> Result<Vec<Embedding>, VectorResourceError> {
    // One empty id per input string, so the two lists stay the same length.
    let input_count = input_strings.len();
    let default_ids: Vec<String> = vec![String::new(); input_count];
    self.generate_embeddings(input_strings, &default_ids).await
}
}

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -73,56 +90,94 @@ pub struct RemoteEmbeddingGenerator {
#[cfg(feature = "native-http")]
#[async_trait]
impl EmbeddingGenerator for RemoteEmbeddingGenerator {
/// Produces Embeddings for a list of input strings through the external API,
/// choosing the backend from the configured model type.
///
/// * Bert-CPP: one request per input string; the matching id is attached to
///   the returned vector.
/// * TextEmbeddingsInference: the full lists are handed over in one call so
///   the server can batch them.
/// * Any other model type: one OpenAI-style request per input string.
///
/// Inputs and ids are paired positionally with `zip`, so in the per-item
/// branches any surplus entries of the longer list are ignored. The first
/// failing request aborts the run and its error is returned.
/// Note this method is blocking.
fn generate_embeddings_blocking(
    &self,
    input_strings: &Vec<String>,
    ids: &Vec<String>,
) -> Result<Vec<Embedding>, VectorResourceError> {
    match self.model_type {
        EmbeddingModelType::BertCPP(_) => input_strings
            .iter()
            .zip(ids)
            .map(|(text, id)| -> Result<Embedding, VectorResourceError> {
                // The Bert-CPP helper only yields the raw vector; wrap it with its id here.
                let vector = self.generate_embedding_bert_cpp_blocking(text)?;
                Ok(Embedding::new(id, vector))
            })
            .collect(),
        EmbeddingModelType::TextEmbeddingsInference(_) => {
            // The TEI helper takes owned lists, so copy both before handing them over.
            self.generate_embedding_tei_blocking(input_strings.to_vec(), ids.to_vec())
        }
        _ => input_strings
            .iter()
            .zip(ids)
            .map(|(text, id)| self.generate_embedding_open_ai_blocking(text, id))
            .collect(),
    }
}

/// Generate an Embedding for an input string by using the external API.
/// Note this method is blocking.
fn generate_embedding_blocking(&self, input_string: &str, id: &str) -> Result<Embedding, VectorResourceError> {
// If we're using a Bert model with a Bert-CPP server
if let EmbeddingModelType::BertCPP(_) = self.model_type {
let vector = self.generate_embedding_bert_cpp_blocking(input_string)?;
return Ok(Embedding::new(id, vector));
let input_strings = vec![input_string.to_string()];
let ids = vec![id.to_string()];

let results = self.generate_embeddings_blocking(&input_strings, &ids)?;
if results.is_empty() {
Err(VectorResourceError::FailedEmbeddingGeneration(
"No results returned from the embedding generation".to_string(),
))
} else {
Ok(results[0].clone())
}
// We're using hugging face TextEmbeddingsInference
if let EmbeddingModelType::TextEmbeddingsInference(_) = self.model_type {
let results = self.generate_embedding_tei_blocking(vec![input_string.to_string()], vec![id.to_string()])?;
if results.len() < 1 {
return Err(VectorResourceError::FailedEmbeddingGeneration(
"No results returned from the embedding generation".to_string(),
));
}

/// Generate an Embedding for an input string by using the external API.
/// This method batch generates whenever possible to increase speed.
async fn generate_embeddings(
&self,
input_strings: &Vec<String>,
ids: &Vec<String>,
) -> Result<Vec<Embedding>, VectorResourceError> {
match self.model_type {
EmbeddingModelType::BertCPP(_) => Err(VectorResourceError::FailedEmbeddingGeneration(
"BertCPP support does not include async operation".to_string(),
)),
EmbeddingModelType::TextEmbeddingsInference(_) => {
self.generate_embedding_tei(input_strings.clone(), ids.clone()).await
}
_ => {
let mut embeddings = Vec::new();
for (input_string, id) in input_strings.iter().zip(ids) {
let embedding = self.generate_embedding_open_ai(input_string, id).await?;
embeddings.push(embedding);
}
Ok(embeddings)
}
return Ok(results[0].clone());
}
// Else we're using OpenAI API
else {
return self.generate_embedding_open_ai_blocking(input_string, id);
}
}

/// Generate an Embedding for an input string by using the external API.
async fn generate_embedding(&self, input_string: &str, id: &str) -> Result<Embedding, VectorResourceError> {
// If we're using a Bert model with a Bert-CPP server
if let EmbeddingModelType::BertCPP(_) = self.model_type {
return Err(VectorResourceError::FailedEmbeddingGeneration(
"BertCPP model does not support async operation".to_string(),
));
}
// We're using hugging face TextEmbeddingsInference
else if let EmbeddingModelType::TextEmbeddingsInference(_) = self.model_type {
let results = self
.generate_embedding_tei(vec![input_string.to_string()], vec![id.to_string()])
.await?;
if results.len() < 1 {
return Err(VectorResourceError::FailedEmbeddingGeneration(
"No results returned from the embedding generation".to_string(),
));
}
return Ok(results[0].clone());
}
// Else we're using OpenAI API
else {
return self.generate_embedding_open_ai(input_string, id).await;
let input_strings = vec![input_string.to_string()];
let ids = vec![id.to_string()];

let results = self.generate_embeddings(&input_strings, &ids).await?;
if results.is_empty() {
Err(VectorResourceError::FailedEmbeddingGeneration(
"No results returned from the embedding generation".to_string(),
))
} else {
Ok(results[0].clone())
}
}

/// Returns a clone of the EmbeddingModelType this generator is configured with.
fn model_type(&self) -> EmbeddingModelType {
    let configured = &self.model_type;
    configured.clone()
}
Expand Down

0 comments on commit 5cbe11e

Please sign in to comment.