Skip to content

Commit

Permalink
reimplemented pathing logic for embed gen
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Oct 18, 2023
1 parent 0d7e7db commit 37b5e64
Showing 1 changed file with 32 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl UnstructuredParser {
let text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size);
let task = tokio::spawn(async move {
let generator = RemoteEmbeddingGenerator::new_default();
Self::generate_text_group_embeddings(&text_groups, &generator, 50, max_chunk_size).await
Self::generate_text_group_embeddings(&text_groups, &generator, 31, max_chunk_size).await
});
let new_text_groups = task.await??;

Expand Down Expand Up @@ -181,7 +181,7 @@ impl UnstructuredParser {
let mut doc = DocumentVectorResource::new_empty(name, resource_desc.as_deref(), source.clone(), &resource_id);
doc.set_embedding_model_used(generator.model_type());

let keywords = UnstructuredParser::extract_keywords(&text_groups, 30);
let keywords = UnstructuredParser::extract_keywords(&text_groups, 50);
doc.update_resource_embedding_blocking(generator, keywords)?;

for grouped_text in &text_groups {
Expand Down Expand Up @@ -225,18 +225,31 @@ impl UnstructuredParser {
fn collect_texts_and_indices(
text_groups: &[GroupedText],
texts: &mut Vec<String>,
indices: &mut Vec<(usize, Option<usize>)>,
indices: &mut Vec<(Vec<usize>, usize)>,
max_chunk_size: u64,
path: Vec<usize>,
) {
for (i, text_group) in text_groups.iter().enumerate() {
println!("Processing text group at index {}", i);
texts.push(text_group.text.clone());
indices.push((i, None));
let mut current_path = path.clone();
current_path.push(i);
indices.push((current_path.clone(), texts.len() - 1));
for (j, sub_group) in text_group.sub_groups.iter().enumerate() {
println!("Processing sub-group at index {} of text group at index {}", j, i);
texts.push(sub_group.text.clone());
indices.push((i, Some(j)));
Self::collect_texts_and_indices(&sub_group.sub_groups, texts, indices, max_chunk_size);
let mut sub_path = current_path.clone();
sub_path.push(j);
indices.push((sub_path, texts.len() - 1));
}
for sub_group in &text_group.sub_groups {
Self::collect_texts_and_indices(
&sub_group.sub_groups,
texts,
indices,
max_chunk_size,
current_path.clone(),
);
}
}
}
Expand All @@ -245,18 +258,16 @@ impl UnstructuredParser {
fn assign_embeddings(
text_groups: &mut [GroupedText],
embeddings: &mut Vec<Embedding>,
indices: &[(usize, Option<usize>)],
indices: &[(Vec<usize>, usize)],
) {
for (i, text_group) in text_groups.iter_mut().enumerate() {
if let Some((index, sub_index)) = indices.get(i) {
if let Some(embedding) = embeddings.get(*index) {
match sub_index {
Some(j) => text_group.sub_groups[*j].embedding = Some(embedding.clone()),
None => text_group.embedding = Some(embedding.clone()),
}
for (path, flat_index) in indices {
if let Some(embedding) = embeddings.get(*flat_index) {
let mut target = &mut text_groups[path[0]];
for &index in &path[1..] {
target = &mut target.sub_groups[index];
}
target.embedding = Some(embedding.clone());
}
Self::assign_embeddings(&mut text_group.sub_groups, embeddings, indices);
}
}

Expand All @@ -275,7 +286,7 @@ impl UnstructuredParser {
// Collect all texts from the text groups and their subgroups
let mut texts = Vec::new();
let mut indices = Vec::new();
Self::collect_texts_and_indices(&text_groups, &mut texts, &mut indices, max_chunk_size);
Self::collect_texts_and_indices(&text_groups, &mut texts, &mut indices, max_chunk_size, vec![]);

// Generate embeddings for all texts in batches
let ids: Vec<String> = vec!["".to_string(); texts.len()];
Expand Down Expand Up @@ -308,9 +319,14 @@ impl UnstructuredParser {
}
}

println!("Number of texts: {}", texts.len());
println!("Number of indices: {}", indices.len());
println!("Number of embeddings: {}", embeddings.len());

// Assign the generated embeddings back to the text groups and their subgroups
Self::assign_embeddings(&mut text_groups, &mut embeddings, &indices);

println!("Embeddings successfully assigned!");
Ok(text_groups)
}

Expand Down

0 comments on commit 37b5e64

Please sign in to comment.