Skip to content

Commit

Permalink
Added new thread spawning
Browse files Browse the repository at this point in the history
  • Loading branch information
robkorn committed Oct 17, 2023
1 parent 78db4d9 commit 4f836d6
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions shinkai-libs/shinkai-vector-resources/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ keyphrases = "0.3.2"
async-trait = "0.1.74"
async-recursion = "1.0.5"
# llm = { git = "https://github.com/rustformers/llm", branch="main"}
tokio = "1"

[dependencies.reqwest]
version = "0.11"
Expand Down
11 changes: 11 additions & 0 deletions shinkai-libs/shinkai-vector-resources/src/resource_errors.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde_json::Error as SerdeError;
use std::error::Error;
use std::fmt;
use tokio::task::JoinError;

use crate::vector_resource::VRPath;

Expand All @@ -22,6 +23,7 @@ pub enum VectorResourceError {
InvalidVRPath(VRPath),
FailedParsingUnstructedAPIJSON(String),
CouldNotDetectFileType(String),
TaskFailed(String),
}

impl fmt::Display for VectorResourceError {
Expand Down Expand Up @@ -54,12 +56,21 @@ impl fmt::Display for VectorResourceError {
VectorResourceError::CouldNotDetectFileType(ref s) => {
write!(f, "Could not detect file type from file name: {}", s)
}
VectorResourceError::TaskFailed(ref s) => {
write!(f, "Tokio task failed: {}", s)
}
}
}
}

impl Error for VectorResourceError {}

impl From<JoinError> for VectorResourceError {
fn from(error: JoinError) -> Self {
VectorResourceError::TaskFailed(error.to_string())
}
}

impl From<regex::Error> for VectorResourceError {
fn from(err: regex::Error) -> VectorResourceError {
VectorResourceError::RegexError(err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl UnstructuredAPI {
/// Makes a blocking request to process a file in a buffer to Unstructured server,
/// and then processing the returned results into a BaseVectorResource
/// Note: Requires name to include the extension ie. `*.pdf`
pub fn process_file(
pub fn process_file_blocking(
&self,
file_buffer: Vec<u8>,
generator: &dyn EmbeddingGenerator,
Expand Down Expand Up @@ -62,7 +62,7 @@ impl UnstructuredAPI {
/// Makes an async request to process a file in a buffer to Unstructured server,
/// and then processing the returned results into a BaseVectorResource
/// Note: Requires name to include the extension ie. `*.pdf`
pub async fn process_file_async(
pub async fn process_file(
&self,
file_buffer: Vec<u8>,
generator: &dyn EmbeddingGenerator,
Expand All @@ -74,7 +74,7 @@ impl UnstructuredAPI {
) -> Result<BaseVectorResource, VectorResourceError> {
// Parse pdf into groups of lines + a resource_id from the hash of the data
let resource_id = UnstructuredParser::generate_data_hash(&file_buffer);
let elements = self.file_request_async(file_buffer, &name).await?;
let elements = self.file_request(file_buffer, &name).await?;

UnstructuredParser::process_elements_into_resource(
elements,
Expand Down Expand Up @@ -125,7 +125,7 @@ impl UnstructuredAPI {

/// Makes an async request to process a file in a buffer into a list of
/// UnstructuredElements
pub async fn file_request_async(
pub async fn file_request(
&self,
file_buffer: Vec<u8>,
file_name: &str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::unstructured_types::{ElementType, GroupedText, UnstructuredElement};
use crate::base_vector_resources::BaseVectorResource;
use crate::data_tags::DataTag;
use crate::document_resource::DocumentVectorResource;
use crate::embedding_generator::EmbeddingGenerator;
use crate::embedding_generator::{EmbeddingGenerator, RemoteEmbeddingGenerator};
use crate::embeddings::Embedding;
use crate::resource_errors::VectorResourceError;
use crate::source::VRSource;
Expand Down Expand Up @@ -87,10 +87,15 @@ impl UnstructuredParser {
max_chunk_size: u64,
) -> Result<BaseVectorResource, VectorResourceError> {
// Group elements together before generating the doc
let mut text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size);
Self::generate_text_group_embeddings(&mut text_groups, generator, 5, max_chunk_size).await?;
let text_groups = UnstructuredParser::hierarchical_group_elements_text(&elements, max_chunk_size);
let task = tokio::spawn(async move {
let generator = RemoteEmbeddingGenerator::new_default();
Self::generate_text_group_embeddings(&text_groups, &generator, 5, max_chunk_size).await
});
let new_text_groups = task.await??;

Self::process_new_doc_resource(
text_groups,
new_text_groups,
generator,
&name,
desc,
Expand Down Expand Up @@ -256,18 +261,22 @@ impl UnstructuredParser {
}

/// Recursively goes through all of the text groups and batch generates embeddings
/// for all of them. Of note this is using a mutable reference, thus the text_groups are mutated
/// with the new embeddings just set directly.
/// for all of them.
pub async fn generate_text_group_embeddings(
text_groups: &mut Vec<GroupedText>,
text_groups: &Vec<GroupedText>,
generator: &dyn EmbeddingGenerator,
max_batch_size: u64,
max_chunk_size: u64,
) -> Result<(), VectorResourceError> {
) -> Result<Vec<GroupedText>, VectorResourceError> {
// Clone the input text_groups
let mut text_groups = text_groups.clone();

// Collect all texts from the text groups and their subgroups
let mut texts = Vec::new();
let mut indices = Vec::new();
Self::collect_texts_and_indices(text_groups, &mut texts, &mut indices, max_chunk_size);
Self::collect_texts_and_indices(&text_groups, &mut texts, &mut indices, max_chunk_size);

let generator = RemoteEmbeddingGenerator::new_default();

// Generate embeddings for all texts in batches
let ids: Vec<String> = vec!["".to_string(); texts.len()];
Expand All @@ -284,10 +293,10 @@ impl UnstructuredParser {

println!("All embeddings generated. Total: {}", embeddings.len());
// Assign the generated embeddings back to the text groups and their subgroups
Self::assign_embeddings(text_groups, &mut embeddings, &indices);
Self::assign_embeddings(&mut text_groups, &mut embeddings, &indices);

println!("All embeddings set.");
Ok(())
Ok(text_groups)
}

/// Helper method for processing a grouped text for process_new_doc_resource
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ fn test_unstructured_parse_pdf_vector_resource() {
let api = UnstructuredAPI::new(UNSTRUCTURED_API_URL.to_string(), None);

let resource = api
.process_file(
.process_file_blocking(
file_buffer,
&generator,
file_name.to_string(),
Expand Down
4 changes: 2 additions & 2 deletions src/agent/file_parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ impl JobManager {
) -> Result<BaseVectorResource, AgentError> {
// Parse file into needed data
let resource_id = UnstructuredParser::generate_data_hash(&file_buffer);
let api = UnstructuredAPI::new(UNSTRUCTURED_API_URL.to_string(), None);
let unstructured_api = UnstructuredAPI::new(UNSTRUCTURED_API_URL.to_string(), None);
let source = VRSource::from_file(&name, &file_buffer)?;
let elements = api.file_request_async(file_buffer, &name).await?;
let elements = unstructured_api.file_request(file_buffer, &name).await?;

// Automatically generate description if none is provided
let mut desc = desc;
Expand Down

0 comments on commit 4f836d6

Please sign in to comment.