diff --git a/Cargo.lock b/Cargo.lock index 3e9eee18d..18724cbb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3002,6 +3002,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "tokio", ] [[package]] diff --git a/shinkai-libs/shinkai-vector-resources/Cargo.toml b/shinkai-libs/shinkai-vector-resources/Cargo.toml index d6f2b6d70..eb3a1cf06 100644 --- a/shinkai-libs/shinkai-vector-resources/Cargo.toml +++ b/shinkai-libs/shinkai-vector-resources/Cargo.toml @@ -17,6 +17,7 @@ keyphrases = "0.3.2" async-trait = "0.1.74" async-recursion = "1.0.5" # llm = { git = "https://github.com/rustformers/llm", branch="main"} +tokio = "1" [dependencies.reqwest] version = "0.11" diff --git a/shinkai-libs/shinkai-vector-resources/src/resource_errors.rs b/shinkai-libs/shinkai-vector-resources/src/resource_errors.rs index ffb6bc2c1..c5fc802c1 100644 --- a/shinkai-libs/shinkai-vector-resources/src/resource_errors.rs +++ b/shinkai-libs/shinkai-vector-resources/src/resource_errors.rs @@ -1,6 +1,7 @@ use serde_json::Error as SerdeError; use std::error::Error; use std::fmt; +use tokio::task::JoinError; use crate::vector_resource::VRPath; @@ -22,6 +23,7 @@ pub enum VectorResourceError { InvalidVRPath(VRPath), FailedParsingUnstructedAPIJSON(String), CouldNotDetectFileType(String), + TaskFailed(String), } impl fmt::Display for VectorResourceError { @@ -54,12 +56,21 @@ impl fmt::Display for VectorResourceError { VectorResourceError::CouldNotDetectFileType(ref s) => { write!(f, "Could not detect file type from file name: {}", s) } + VectorResourceError::TaskFailed(ref s) => { + write!(f, "Tokio task failed: {}", s) + } } } } impl Error for VectorResourceError {} +impl From for VectorResourceError { + fn from(error: JoinError) -> Self { + VectorResourceError::TaskFailed(error.to_string()) + } +} + impl From for VectorResourceError { fn from(err: regex::Error) -> VectorResourceError { VectorResourceError::RegexError(err) diff --git a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_api.rs b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_api.rs index 27798c532..537a0141a 100644 --- a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_api.rs +++ b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_api.rs @@ -33,7 +33,7 @@ impl UnstructuredAPI { /// Makes a blocking request to process a file in a buffer to Unstructured server, /// and then processing the returned results into a BaseVectorResource /// Note: Requires name to include the extension ie. `*.pdf` - pub fn process_file( + pub fn process_file_blocking( &self, file_buffer: Vec, generator: &dyn EmbeddingGenerator, @@ -62,7 +62,7 @@ impl UnstructuredAPI { /// Makes an async request to process a file in a buffer to Unstructured server, /// and then processing the returned results into a BaseVectorResource /// Note: Requires name to include the extension ie. `*.pdf` - pub async fn process_file_async( + pub async fn process_file( &self, file_buffer: Vec, generator: &dyn EmbeddingGenerator, @@ -74,7 +74,7 @@ impl UnstructuredAPI { ) -> Result { // Parse pdf into groups of lines + a resource_id from the hash of the data let resource_id = UnstructuredParser::generate_data_hash(&file_buffer); - let elements = self.file_request_async(file_buffer, &name).await?; + let elements = self.file_request(file_buffer, &name).await?; UnstructuredParser::process_elements_into_resource( elements, @@ -125,7 +125,7 @@ impl UnstructuredAPI { /// Makes an async request to process a file in a buffer into a list of /// UnstructuredElements - pub async fn file_request_async( + pub async fn file_request( &self, file_buffer: Vec, file_name: &str, diff --git a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs index 5830fd0ea..109ba0092 100644 --- a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs +++ b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs @@ -2,7 +2,7 @@ use super::unstructured_types::{ElementType, GroupedText, UnstructuredElement}; use crate::base_vector_resources::BaseVectorResource; use crate::data_tags::DataTag; use crate::document_resource::DocumentVectorResource; -use crate::embedding_generator::EmbeddingGenerator; +use crate::embedding_generator::{EmbeddingGenerator, RemoteEmbeddingGenerator}; use crate::embeddings::Embedding; use crate::resource_errors::VectorResourceError; use crate::source::VRSource; @@ -87,10 +87,15 @@ impl UnstructuredParser { max_chunk_size: u64, ) -> Result { // Group elements together before generating the doc - let mut text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size); - Self::generate_text_group_embeddings(&mut text_groups, generator, 5, max_chunk_size).await?; + let text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size); + let task = tokio::spawn(async move { + let generator = RemoteEmbeddingGenerator::new_default(); + Self::generate_text_group_embeddings(&text_groups, &generator, 5, max_chunk_size).await + }); + let new_text_groups = task.await??; + Self::process_new_doc_resource( - text_groups, + new_text_groups, generator, &name, desc, @@ -256,18 +261,22 @@ impl UnstructuredParser { } /// Recursively goes through all of the text groups and batch generates embeddings - /// for all of them. Of note this is using a mutable reference, thus the text_groups are mutated - /// with the new embeddings just set directly. + /// for all of them. pub async fn generate_text_group_embeddings( - text_groups: &mut Vec, + text_groups: &Vec, generator: &dyn EmbeddingGenerator, max_batch_size: u64, max_chunk_size: u64, - ) -> Result<(), VectorResourceError> { + ) -> Result, VectorResourceError> { + // Clone the input text_groups + let mut text_groups = text_groups.clone(); + // Collect all texts from the text groups and their subgroups let mut texts = Vec::new(); let mut indices = Vec::new(); - Self::collect_texts_and_indices(text_groups, &mut texts, &mut indices, max_chunk_size); + Self::collect_texts_and_indices(&text_groups, &mut texts, &mut indices, max_chunk_size); + + let generator = RemoteEmbeddingGenerator::new_default(); // Generate embeddings for all texts in batches let ids: Vec = vec!["".to_string(); texts.len()]; @@ -284,10 +293,10 @@ impl UnstructuredParser { println!("All embeddings generated. Total: {}", embeddings.len()); // Assign the generated embeddings back to the text groups and their subgroups - Self::assign_embeddings(text_groups, &mut embeddings, &indices); + Self::assign_embeddings(&mut text_groups, &mut embeddings, &indices); println!("All embeddings set."); - Ok(()) + Ok(text_groups) } /// Helper method for processing a grouped text for process_new_doc_resource diff --git a/shinkai-libs/shinkai-vector-resources/tests/unstructured_tests.rs b/shinkai-libs/shinkai-vector-resources/tests/unstructured_tests.rs index fd3f4b0e1..514a887c9 100644 --- a/shinkai-libs/shinkai-vector-resources/tests/unstructured_tests.rs +++ b/shinkai-libs/shinkai-vector-resources/tests/unstructured_tests.rs @@ -70,7 +70,7 @@ fn test_unstructured_parse_pdf_vector_resource() { let api = UnstructuredAPI::new(UNSTRUCTURED_API_URL.to_string(), None); let resource = api - .process_file( + .process_file_blocking( file_buffer, &generator, file_name.to_string(), diff --git a/src/agent/file_parsing.rs b/src/agent/file_parsing.rs index c8b6c48c6..690556d41 100644 --- a/src/agent/file_parsing.rs +++ b/src/agent/file_parsing.rs @@ -49,9 +49,9 @@ impl JobManager { ) -> Result { // Parse file into needed data let resource_id = UnstructuredParser::generate_data_hash(&file_buffer); - let api = UnstructuredAPI::new(UNSTRUCTURED_API_URL.to_string(), None); + let unstructured_api = UnstructuredAPI::new(UNSTRUCTURED_API_URL.to_string(), None); let source = VRSource::from_file(&name, &file_buffer)?; - let elements = api.file_request_async(file_buffer, &name).await?; + let elements = unstructured_api.file_request(file_buffer, &name).await?; // Automatically generate description if none is provided let mut desc = desc;