From 37b5e6461999c6084b2970c05535ac3ef41806b5 Mon Sep 17 00:00:00 2001 From: Robert Kornacki <11645932+robkorn@users.noreply.github.com> Date: Wed, 18 Oct 2023 19:10:12 +0200 Subject: [PATCH] reimplemented pathing logic for embed gen --- .../src/unstructured/unstructured_parser.rs | 48 ++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs index 80d898942..c7c2e8994 100644 --- a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs +++ b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs @@ -90,7 +90,7 @@ impl UnstructuredParser { let text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size); let task = tokio::spawn(async move { let generator = RemoteEmbeddingGenerator::new_default(); - Self::generate_text_group_embeddings(&text_groups, &generator, 50, max_chunk_size).await + Self::generate_text_group_embeddings(&text_groups, &generator, 31, max_chunk_size).await }); let new_text_groups = task.await??; @@ -181,7 +181,7 @@ impl UnstructuredParser { let mut doc = DocumentVectorResource::new_empty(name, resource_desc.as_deref(), source.clone(), &resource_id); doc.set_embedding_model_used(generator.model_type()); - let keywords = UnstructuredParser::extract_keywords(&text_groups, 30); + let keywords = UnstructuredParser::extract_keywords(&text_groups, 50); doc.update_resource_embedding_blocking(generator, keywords)?; for grouped_text in &text_groups { @@ -225,18 +225,31 @@ impl UnstructuredParser { fn collect_texts_and_indices( text_groups: &[GroupedText], texts: &mut Vec, - indices: &mut Vec<(usize, Option)>, + indices: &mut Vec<(Vec, usize)>, max_chunk_size: u64, + path: Vec, ) { for (i, text_group) in text_groups.iter().enumerate() { println!("Processing text group at index {}", i); texts.push(text_group.text.clone()); - indices.push((i, None)); + let mut current_path = path.clone(); + current_path.push(i); + indices.push((current_path.clone(), texts.len() - 1)); for (j, sub_group) in text_group.sub_groups.iter().enumerate() { println!("Processing sub-group at index {} of text group at index {}", j, i); texts.push(sub_group.text.clone()); - indices.push((i, Some(j))); - Self::collect_texts_and_indices(&sub_group.sub_groups, texts, indices, max_chunk_size); + let mut sub_path = current_path.clone(); + sub_path.push(j); + indices.push((sub_path, texts.len() - 1)); + } + for sub_group in &text_group.sub_groups { + Self::collect_texts_and_indices( + &sub_group.sub_groups, + texts, + indices, + max_chunk_size, + current_path.clone(), + ); } } } @@ -245,18 +258,16 @@ impl UnstructuredParser { fn assign_embeddings( text_groups: &mut [GroupedText], embeddings: &mut Vec, - indices: &[(usize, Option)], + indices: &[(Vec, usize)], ) { - for (i, text_group) in text_groups.iter_mut().enumerate() { - if let Some((index, sub_index)) = indices.get(i) { - if let Some(embedding) = embeddings.get(*index) { - match sub_index { - Some(j) => text_group.sub_groups[*j].embedding = Some(embedding.clone()), - None => text_group.embedding = Some(embedding.clone()), - } + for (path, flat_index) in indices { + if let Some(embedding) = embeddings.get(*flat_index) { + let mut target = &mut text_groups[path[0]]; + for &index in &path[1..] { + target = &mut target.sub_groups[index]; } + target.embedding = Some(embedding.clone()); } - Self::assign_embeddings(&mut text_group.sub_groups, embeddings, indices); } } @@ -275,7 +286,7 @@ impl UnstructuredParser { // Collect all texts from the text groups and their subgroups let mut texts = Vec::new(); let mut indices = Vec::new(); - Self::collect_texts_and_indices(&text_groups, &mut texts, &mut indices, max_chunk_size); + Self::collect_texts_and_indices(&text_groups, &mut texts, &mut indices, max_chunk_size, vec![]); // Generate embeddings for all texts in batches let ids: Vec = vec!["".to_string(); texts.len()]; @@ -308,9 +319,14 @@ impl UnstructuredParser { } } + println!("Number of texts: {}", texts.len()); + println!("Number of indices: {}", indices.len()); + println!("Number of embeddings: {}", embeddings.len()); + // Assign the generated embeddings back to the text groups and their subgroups Self::assign_embeddings(&mut text_groups, &mut embeddings, &indices); + println!("Embeddings successfully assigned!"); Ok(text_groups) }