Skip to content

Commit

Permalink
Simplified retry logic
Browse files: browse the repository at this point in the history
  • Loading branch information
robkorn committed Oct 18, 2023
1 parent 0cea6e4 commit 0d7e7db
Showing 1 changed file with 26 additions and 33 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,7 +90,7 @@ impl UnstructuredParser {
let text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size);
let task = tokio::spawn(async move {
let generator = RemoteEmbeddingGenerator::new_default();
Self::generate_text_group_embeddings(&text_groups, &generator, 5, max_chunk_size).await
Self::generate_text_group_embeddings(&text_groups, &generator, 50, max_chunk_size).await
});
let new_text_groups = task.await??;

Expand Down Expand Up @@ -181,7 +181,7 @@ impl UnstructuredParser {
let mut doc = DocumentVectorResource::new_empty(name, resource_desc.as_deref(), source.clone(), &resource_id);
doc.set_embedding_model_used(generator.model_type());

let keywords = UnstructuredParser::extract_keywords(&text_groups, 50);
let keywords = UnstructuredParser::extract_keywords(&text_groups, 30);
doc.update_resource_embedding_blocking(generator, keywords)?;

for grouped_text in &text_groups {
Expand Down Expand Up @@ -262,6 +262,7 @@ impl UnstructuredParser {

/// Recursively goes through all of the text groups and batch generates embeddings
/// for all of them.
#[async_recursion]
pub async fn generate_text_group_embeddings(
text_groups: &Vec<GroupedText>,
generator: &dyn EmbeddingGenerator,
Expand All @@ -279,45 +280,37 @@ impl UnstructuredParser {
// Generate embeddings for all texts in batches
let ids: Vec<String> = vec!["".to_string(); texts.len()];
let mut embeddings = Vec::new();

let mut start = 0;
'outer: loop {
for batch in texts[start..].chunks(max_batch_size as usize) {
let batch_ids = &ids[start..start + batch.len()];
println!("Generating batched embeddings for {} text groups", batch_ids.len());

match generator
.generate_embeddings(&batch.to_vec(), &batch_ids.to_vec())
.await
{
Ok(batch_embeddings) => {
embeddings.extend(batch_embeddings);
start += batch.len();
}
Err(e) => {
println!("Generating batched embeddings failed for {:?}", batch);
println!("Error generating embeddings: {:?}", e);
if let VectorResourceError::RequestFailed(_) = e {
if max_batch_size > 1 {
max_batch_size -= 1;
continue 'outer;
} else {
return Err(e);
}
} else {
return Err(e);
}
for batch in texts.chunks(max_batch_size as usize) {
let batch_ids = &ids[..batch.len()];
println!("Generating batched embeddings for {} text groups", batch_ids.len());
match generator
.generate_embeddings(&batch.to_vec(), &batch_ids.to_vec())
.await
{
Ok(batch_embeddings) => {
embeddings.extend(batch_embeddings);
}
Err(e) => {
println!("Error generating embeddings: {:?}", e);
if max_batch_size > 5 {
max_batch_size -= 5;
return Self::generate_text_group_embeddings(
&text_groups,
generator,
max_batch_size,
max_chunk_size,
)
.await;
} else {
return Err(e);
}
}
}
break;
}

println!("All embeddings generated. Total: {}", embeddings.len());
// Assign the generated embeddings back to the text groups and their subgroups
Self::assign_embeddings(&mut text_groups, &mut embeddings, &indices);

println!("All embeddings set.");
Ok(text_groups)
}

Expand Down

0 comments on commit 0d7e7db

Please sign in to comment.