diff --git a/Cargo.lock b/Cargo.lock index 0fd863cbc..13ab0d406 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -779,16 +779,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" -dependencies = [ - "cfg-if", - "crossbeam-utils", -] - [[package]] name = "crossbeam-deque" version = "0.8.3" @@ -2576,9 +2566,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -2586,14 +2576,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] [[package]] diff --git a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs index 172cb1379..d62f72701 100644 --- a/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs +++ b/shinkai-libs/shinkai-message-primitives/src/shinkai_message/shinkai_message_schemas.rs @@ -78,7 +78,7 @@ pub struct JobCreationInfo { pub scope: JobScope, } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct JobMessage { // TODO: scope div modifications? 
pub job_id: String,
diff --git a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs
index 1cb1d45bc..b8b43d3a7 100644
--- a/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs
+++ b/shinkai-libs/shinkai-vector-resources/src/unstructured/unstructured_parser.rs
@@ -100,10 +100,13 @@ impl UnstructuredParser {
         // Extract keywords from the elements
         let keywords = UnstructuredParser::extract_keywords(&text_groups, 50);
+        eprintln!("Keywords: {:?}", keywords);

         // Set the resource embedding, using the keywords + name + desc + source
         doc.update_resource_embedding(generator, keywords)?;

+        eprintln!("Processing doc: `{}`", doc.name());
+
         // Add each text group as either Vector Resource DataChunks,
         // or data-holding DataChunks depending on if each has any sub-groups
         for grouped_text in &text_groups {
@@ -151,6 +154,7 @@ impl UnstructuredParser {
             }
         }

+        eprintln!("Finished processing doc: `{}`", doc.name());
         Ok(BaseVectorResource::Document(doc))
     }

diff --git a/src/agent/agent.rs b/src/agent/agent.rs
index 15581fcf0..f4964ee5e 100644
--- a/src/agent/agent.rs
+++ b/src/agent/agent.rs
@@ -17,8 +17,6 @@ use tokio::sync::{mpsc, Mutex};
 pub struct Agent {
     pub id: String,
     pub full_identity_name: ShinkaiName,
-    pub job_manager_sender: mpsc::Sender<(Vec, String)>,
-    pub agent_receiver: Arc>>,
     pub client: Client,
     pub perform_locally: bool, // Todo: Remove as not used anymore
     pub external_url: Option, // external API URL
@@ -33,7 +31,6 @@ impl Agent {
     pub fn new(
         id: String,
         full_identity_name: ShinkaiName,
-        job_manager_sender: mpsc::Sender<(Vec, String)>,
         perform_locally: bool,
         external_url: Option,
         api_key: Option,
@@ -43,13 +40,9 @@ impl Agent {
         allowed_message_senders: Vec,
     ) -> Self {
         let client = Client::new();
-        let (_, agent_receiver) = mpsc::channel(1); // TODO: I think we can remove this altogether
-        let agent_receiver = Arc::new(Mutex::new(agent_receiver)); // wrap the receiver
         Self {
             id,
             full_identity_name,
-            job_manager_sender,
-            agent_receiver,
             client,
             perform_locally,
             external_url,
@@ -101,12 +94,10 @@ impl Agent {
 impl Agent {
     pub fn from_serialized_agent(
         serialized_agent: SerializedAgent,
-        sender: mpsc::Sender<(Vec, String)>,
     ) -> Self {
         Self::new(
             serialized_agent.id,
             serialized_agent.full_identity_name,
-            sender,
             serialized_agent.perform_locally,
             serialized_agent.external_url,
             serialized_agent.api_key,
diff --git a/src/agent/error.rs b/src/agent/error.rs
index 4f2275976..f08f04117 100644
--- a/src/agent/error.rs
+++ b/src/agent/error.rs
@@ -33,6 +33,7 @@ pub enum AgentError {
     TaskJoinError(String),
     InferenceRecursionLimitReached(String),
     TokenizationError(String),
+    JobDequeueFailed(String)
 }

 impl fmt::Display for AgentError {
@@ -78,6 +79,7 @@ impl fmt::Display for AgentError {
             AgentError::TaskJoinError(s) => write!(f, "Task join error: {}", s),
             AgentError::InferenceRecursionLimitReached(s) => write!(f, "Inferencing the LLM has reached too many iterations of recursion with no progress, and thus has been stopped for this job_task: {}", s),
             AgentError::TokenizationError(s) => write!(f, "Tokenization error: {}", s),
+            AgentError::JobDequeueFailed(s) => write!(f, "Job dequeue failed: {}", s),
         }
     }
diff --git a/src/agent/execution/chains/inference_chain_router.rs b/src/agent/execution/chains/inference_chain_router.rs
index 660fe8425..ba6f575a2 100644
--- a/src/agent/execution/chains/inference_chain_router.rs
+++ b/src/agent/execution/chains/inference_chain_router.rs
@@ -1,7 +1,9 @@ use crate::agent::agent::Agent; use crate::agent::error::AgentError; use crate::agent::job::Job; -use crate::agent::job_manager::AgentManager; +use crate::agent::job_manager::JobManager; +use crate::db::ShinkaiDB; +use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobMessage; use shinkai_vector_resources::embedding_generator::EmbeddingGenerator; @@ -15,13 +17,13 @@ pub enum InferenceChain { CodingChain, } -impl AgentManager { +impl JobManager { /// Chooses an inference chain based on the job message (using the agent's LLM) /// and then starts using the chosen chain. /// Returns the final String result from the inferencing, and a new execution context. pub async fn inference_chain_router( - &self, - agent_found: Option>>, + db: Arc>, + agent_found: Option, full_job: Job, job_message: JobMessage, prev_execution_context: HashMap, @@ -37,20 +39,20 @@ impl AgentManager { match chosen_chain { InferenceChain::QAChain => { if let Some(agent) = agent_found { - inference_response_content = self - .start_qa_inference_chain( - full_job, - job_message.content.clone(), - agent, - prev_execution_context, - generator, - user_profile, - None, - None, - 0, - 5, - ) - .await?; + inference_response_content = JobManager::start_qa_inference_chain( + db, + full_job, + job_message.content.clone(), + agent, + prev_execution_context, + generator, + user_profile, + None, + None, + 0, + 5, + ) + .await?; new_execution_context .insert("previous_step_response".to_string(), inference_response_content.clone()); } else { diff --git a/src/agent/execution/chains/qa_inference_chain.rs b/src/agent/execution/chains/qa_inference_chain.rs index d72140a9d..2230ac637 100644 --- a/src/agent/execution/chains/qa_inference_chain.rs +++ b/src/agent/execution/chains/qa_inference_chain.rs @@ -3,23 +3,26 @@ use crate::agent::error::AgentError; use crate::agent::execution::job_prompts::JobPromptGenerator; use crate::agent::file_parsing::ParsingHelper; use crate::agent::job::{Job, JobId, JobLike}; -use crate::agent::job_manager::AgentManager; +use crate::agent::job_manager::JobManager; +use crate::db::ShinkaiDB; +use crate::resources::bert_cpp::BertCPPProcess; use async_recursion::async_recursion; +use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_vector_resources::embedding_generator::EmbeddingGenerator; use std::result::Result::Ok; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; -impl AgentManager { +impl JobManager { /// An inference chain for question-answer job tasks which vector searches the Vector Resources /// in the JobScope to find relevant content for the LLM to use at each step. 
#[async_recursion] pub async fn start_qa_inference_chain( - &self, + db: Arc>, full_job: Job, job_task: String, - agent: Arc>, + agent: SerializedAgent, execution_context: HashMap, generator: &dyn EmbeddingGenerator, user_profile: Option, @@ -33,9 +36,9 @@ impl AgentManager { // Use search_text if available (on recursion), otherwise use job_task to generate the query (on first iteration) let query_text = search_text.clone().unwrap_or(job_task.clone()); let query = generator.generate_embedding_default(&query_text).unwrap(); - let ret_data_chunks = self - .job_scope_vector_search(full_job.scope(), query, 20, &user_profile.clone().unwrap(), true) - .await?; + let ret_data_chunks = + JobManager::job_scope_vector_search(db.clone(), full_job.scope(), query, 20, &user_profile.clone().unwrap(), true) + .await?; // Use the default prompt if not reached final iteration count, else use final prompt let filled_prompt = if iteration_count < max_iterations { @@ -62,8 +65,8 @@ impl AgentManager { // Inference the agent's LLM with the prompt. If it has an answer, the chain // is finished and so just return the answer response as a cleaned String - let response_json = self.inference_agent(agent.clone(), filled_prompt).await?; - if let Ok(answer_str) = self.extract_inference_json_response(response_json.clone(), "answer") { + let response_json = JobManager::inference_agent(agent.clone(), filled_prompt).await?; + if let Ok(answer_str) = JobManager::extract_inference_json_response(response_json.clone(), "answer") { let cleaned_answer = ParsingHelper::ending_stripper(&answer_str); println!("QA Chain Final Answer: {:?}", cleaned_answer); return Ok(cleaned_answer); @@ -75,17 +78,17 @@ impl AgentManager { // If not an answer, then the LLM must respond with a search/summary, so we parse them // to use for the next recursive call - let (mut new_search_text, summary) = match self.extract_inference_json_response(response_json.clone(), "search") - { - Ok(search_str) => { - let summary_str = response_json - .get("summary") - .and_then(|s| s.as_str()) - .map(|s| ParsingHelper::ending_stripper(s)); - (search_str, summary_str) - } - Err(_) => return Err(AgentError::InferenceJSONResponseMissingField("search".to_string())), - }; + let (mut new_search_text, summary) = + match JobManager::extract_inference_json_response(response_json.clone(), "search") { + Ok(search_str) => { + let summary_str = response_json + .get("summary") + .and_then(|s| s.as_str()) + .map(|s| ParsingHelper::ending_stripper(s)); + (search_str, summary_str) + } + Err(_) => return Err(AgentError::InferenceJSONResponseMissingField("search".to_string())), + }; // If the new search text is the same as the previous one, prompt the agent for a new search term if Some(new_search_text.clone()) == search_text { @@ -93,8 +96,8 @@ impl AgentManager { new_search_text.clone(), summary.clone().unwrap_or_default(), ); - let response_json = self.inference_agent(agent.clone(), retry_prompt).await?; - match self.extract_inference_json_response(response_json, "search") { + let response_json = JobManager::inference_agent(agent.clone(), retry_prompt).await?; + match JobManager::extract_inference_json_response(response_json, "search") { Ok(search_str) => { println!("QA Chain New Search Retry Term: {:?}", search_str); new_search_text = search_str; @@ -104,7 +107,8 @@ impl AgentManager { } // Recurse with the new search/summary text and increment iteration_count - self.start_qa_inference_chain( + JobManager::start_qa_inference_chain( + db, full_job, job_task.to_string(), agent, 
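Aside: the two chain files above show the heart of this refactor. `&self` methods on `AgentManager`, which needed `Arc<Mutex<Agent>>` handles, become associated functions on `JobManager` that receive the shared DB handle and a plain `SerializedAgent` as arguments, so each step can run on its own tokio task. A minimal, self-contained sketch of that calling convention follows; `Db`, `Manager`, and `process_step` are illustrative stand-ins rather than Shinkai types, and it assumes tokio with the `rt-multi-thread` and `macros` features enabled:

use std::sync::Arc;
use tokio::sync::Mutex;

struct Db {
    steps: Vec<String>,
}

struct Manager;

impl Manager {
    // Associated function instead of `&self`: shared state arrives as an
    // explicit Arc handle, per-job data as owned values.
    async fn process_step(db: Arc<Mutex<Db>>, job_id: String) {
        let mut db = db.lock().await; // hold the lock only for the DB write
        db.steps.push(job_id);
    }
}

#[tokio::main]
async fn main() {
    let db = Arc::new(Mutex::new(Db { steps: Vec::new() }));
    // Because nothing borrows a manager instance, steps can be spawned freely.
    let handles: Vec<_> = (0..3)
        .map(|i| tokio::spawn(Manager::process_step(db.clone(), format!("job-{i}"))))
        .collect();
    for h in handles {
        h.await.unwrap();
    }
    assert_eq!(db.lock().await.steps.len(), 3);
}

Cloning `db` copies only the `Arc`, so every spawned step shares the same database behind the lock, which is the property the diff below relies on.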
diff --git a/src/agent/execution/chains/tool_execution_chain.rs b/src/agent/execution/chains/tool_execution_chain.rs index 66e6f4daf..362d89aa7 100644 --- a/src/agent/execution/chains/tool_execution_chain.rs +++ b/src/agent/execution/chains/tool_execution_chain.rs @@ -1,6 +1,6 @@ -use crate::agent::job_manager::AgentManager; +use crate::agent::job_manager::JobManager; -impl AgentManager { +impl JobManager { pub fn start_tool_execution_inference_chain(&self) -> () { self.analysis_phase(); diff --git a/src/agent/execution/job_execution_core.rs b/src/agent/execution/job_execution_core.rs index 43567202f..f856b594e 100644 --- a/src/agent/execution/job_execution_core.rs +++ b/src/agent/execution/job_execution_core.rs @@ -3,11 +3,15 @@ use crate::agent::agent::Agent; use crate::agent::error::AgentError; use crate::agent::file_parsing::ParsingHelper; use crate::agent::job::{Job, JobLike}; -use crate::agent::job_manager::AgentManager; +use crate::agent::job_manager::JobManager; +use crate::agent::queue::job_queue_manager::JobForProcessing; use crate::db::ShinkaiDB; use crate::resources::bert_cpp::BertCPPProcess; +use ed25519_dalek::SecretKey as SignatureStaticKey; use serde_json::Value as JsonValue; +use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent; use shinkai_message_primitives::shinkai_utils::job_scope::{DBScopeEntry, LocalScopeEntry, ScopeEntry}; +use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use shinkai_message_primitives::{ schemas::shinkai_name::ShinkaiName, shinkai_message::{shinkai_message::ShinkaiMessage, shinkai_message_schemas::JobMessage}, @@ -22,77 +26,51 @@ use std::time::Instant; use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; -impl AgentManager { +impl JobManager { /// Processes a job message which will trigger a job step - pub async fn process_job_step( - &mut self, - message: ShinkaiMessage, - job_message: JobMessage, + pub async fn process_job_message_queued( + job_message: JobForProcessing, + db: Arc>, + identity_secret_key: SignatureStaticKey, ) -> Result { - if let Some(job) = self.jobs.lock().await.get(&job_message.job_id) { - // Basic setup - let job = job.clone(); - let job_id = job.job_id().to_string(); - let mut shinkai_db = self.db.lock().await; - shinkai_db.add_message_to_job_inbox(&job_message.job_id.clone(), &message)?; - println!("process_job_step> job_message: {:?}", job_message); - - // Verify identity/profile match - let sender_subidentity_result = - ShinkaiName::from_shinkai_message_using_sender_subidentity(&message.clone()); - let sender_subidentity = match sender_subidentity_result { - Ok(subidentity) => subidentity, - Err(e) => return Err(AgentError::InvalidSubidentity(e)), - }; - let profile_result = sender_subidentity.extract_profile(); - let profile = match profile_result { - Ok(profile) => profile, - Err(e) => return Err(AgentError::InvalidProfileSubidentity(e.to_string())), - }; - - // TODO: Implement unprocessed messages/queuing logic - // If current unprocessed message count >= 1, then simply add unprocessed message and return success. - // However if unprocessed message count == 0, then: - // 0. You add the unprocessed message to the list in the DB - // 1. Start a while loop where every time you fetch the unprocessed messages for the job from the DB and check if there's >= 1 - // 2. You read the first/front unprocessed message (not pop from the back) - // 3. You start analysis phase to generate the execution plan. - // 4. 
You then take the execution plan and process the execution phase. - // 5. Once execution phase succeeds, you then delete the message from the unprocessed list in the DB - // and take the result and append it both to the Job inbox and step history - // 6. As we're in a while loop, go back to 1, meaning any new unprocessed messages added while the step was happening are now processed sequentially - - // let current_unprocessed_message_count = ... - shinkai_db.add_to_unprocessed_messages_list(job.job_id().to_string(), job_message.content.clone())?; - - std::mem::drop(shinkai_db); // require to avoid deadlock - - // Fetch data we need to execute job step - let (mut full_job, agent_found, profile_name, user_profile) = - self.fetch_relevant_job_data(job.job_id()).await?; + let job_id = job_message.job_message.job_id.clone(); + // Fetch data we need to execute job step + let (mut full_job, agent_found, profile_name, user_profile) = + JobManager::fetch_relevant_job_data(&job_message.job_message.job_id, db.clone()).await?; - // Processes any files which were sent with the job message - self.process_job_message_files(&job_message, agent_found.clone(), &mut full_job, profile, false) - .await?; + // Processes any files which were sent with the job message + JobManager::process_job_message_files( + db.clone(), + &job_message.job_message, + agent_found.clone(), + &mut full_job, + job_message.profile, + false, + ) + .await?; - // TODO(Nico): move this to a parallel thread that runs in the background - let _ = self - .process_inference_chain(job_message, full_job, agent_found.clone(), profile_name, user_profile) - .await?; + let _ = JobManager::process_inference_chain( + db, + identity_secret_key, + job_message.job_message, + full_job, + agent_found.clone(), + profile_name, + user_profile, + ) + .await?; - return Ok(job_id.clone()); - } else { - return Err(AgentError::JobNotFound); - } + return Ok(job_id.clone()); } /// Processes the provided message & job data, routes them to a specific inference chain, /// and then parses + saves the output result to the DB. 
pub async fn process_inference_chain( - &self, + db: Arc>, + identity_secret_key: SignatureStaticKey, job_message: JobMessage, full_job: Job, - agent_found: Option>>, + agent_found: Option, profile_name: String, user_profile: Option, ) -> Result<(), AgentError> { @@ -101,26 +79,30 @@ impl AgentManager { // Setup initial data to get ready to call a specific inference chain let prev_execution_context = full_job.execution_context.clone(); - let bert_process = BertCPPProcess::start(); // Gets killed if out of scope + // let bert_process = BertCPPProcess::start(); // Gets killed if out of scope let generator = RemoteEmbeddingGenerator::new_default(); let start = Instant::now(); // Call the inference chain router to choose which chain to use, and call it - let (inference_response_content, new_execution_context) = self - .inference_chain_router( - agent_found, - full_job, - job_message.clone(), - prev_execution_context, - &generator, - user_profile, - ) - .await?; + let (inference_response_content, new_execution_context) = JobManager::inference_chain_router( + db.clone(), + agent_found, + full_job, + job_message.clone(), + prev_execution_context, + &generator, + user_profile, + ) + .await?; let duration = start.elapsed(); - println!("Time elapsed for inference chain processing is: {:?}", duration); + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + &format!("Time elapsed for inference chain processing is: {:?}", duration), + ); // Prepare data to save inference response to the DB - let identity_secret_key_clone = clone_signature_secret_key(&self.identity_secret_key); + let identity_secret_key_clone = clone_signature_secret_key(&identity_secret_key); let shinkai_message = ShinkaiMessageBuilder::job_message_from_agent( job_id.to_string(), inference_response_content.to_string(), @@ -129,44 +111,48 @@ impl AgentManager { profile_name.clone(), ) .unwrap(); + + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + format!("process_inference_chain> shinkai_message: {:?}", shinkai_message).as_str(), + ); + // Save response data to DB - let mut shinkai_db = self.db.lock().await; + let mut shinkai_db = db.lock().await; shinkai_db.add_step_history(job_message.job_id.clone(), job_message.content)?; shinkai_db.add_step_history(job_message.job_id.clone(), inference_response_content.to_string())?; shinkai_db.add_message_to_job_inbox(&job_message.job_id.clone(), &shinkai_message)?; shinkai_db.set_job_execution_context(&job_message.job_id.clone(), new_execution_context)?; - std::mem::drop(bert_process); - Ok(()) } /// Processes the files sent together with the current job_message into Vector Resources, /// and saves them either into the local job scope, or the DB depending on `save_to_db_directly`. pub async fn process_job_message_files( - &self, + db: Arc>, job_message: &JobMessage, - agent_found: Option>>, + agent_found: Option, full_job: &mut Job, profile: ShinkaiName, save_to_db_directly: bool, ) -> Result<(), AgentError> { if !job_message.files_inbox.is_empty() { - println!( - "process_job_message> processing files_map: ... files: {}", - job_message.files_inbox.len() + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + format!("Processing files_map: ... 
files: {}", job_message.files_inbox.len()).as_str(), ); // TODO: later we should able to grab errors and return them to the user - let new_scope_entries = self - .process_files_inbox( - self.db.clone(), - agent_found, - job_message.files_inbox.clone(), - profile, - save_to_db_directly, - ) - .await?; - eprintln!(">>> new_scope_entries: {:?}", new_scope_entries.keys()); + let new_scope_entries = JobManager::process_files_inbox( + db.clone(), + agent_found, + job_message.files_inbox.clone(), + profile, + save_to_db_directly, + ) + .await?; for (_, value) in new_scope_entries { match value { @@ -187,13 +173,12 @@ impl AgentManager { } } { - let mut shinkai_db = self.db.lock().await; + let mut shinkai_db = db.lock().await; shinkai_db.update_job_scope(full_job.job_id().to_string(), full_job.scope.clone())?; - eprintln!(">>> job_scope updated"); } } else { // TODO: move this somewhere else - let mut shinkai_db = self.db.lock().await; + let mut shinkai_db = db.lock().await; shinkai_db.init_profile_resource_router(&profile)?; std::mem::drop(shinkai_db); // required to avoid deadlock } @@ -205,9 +190,8 @@ impl AgentManager { /// If save_to_db_directly == true, the files will save to the DB and be returned as `DBScopeEntry`s. /// Else, the files will be returned as `LocalScopeEntry`s and thus held inside. pub async fn process_files_inbox( - &self, db: Arc>, - agent: Option>>, + agent: Option, files_inbox: String, profile: ShinkaiName, save_to_db_directly: bool, @@ -218,7 +202,7 @@ impl AgentManager { None => return Err(AgentError::AgentNotFound), }; - let _bert_process = BertCPPProcess::start(); // Gets killed if out of scope + // let _bert_process = BertCPPProcess::start(); // Gets killed if out of scope let mut shinkai_db = db.lock().await; let files_result = shinkai_db.get_all_files_from_inbox(files_inbox.clone()); // Check if there was an error getting the files @@ -232,18 +216,22 @@ impl AgentManager { // Start processing the files for (filename, content) in files.into_iter() { - eprintln!("Processing file: {}", filename); - let resource = self - .parse_file_into_resource( - content.clone(), - &*generator, - filename.clone(), - None, - &vec![], - agent.clone(), - 400, - ) - .await?; + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + &format!("Processing file: {}", filename), + ); + let resource = JobManager::parse_file_into_resource( + db.clone(), + content.clone(), + &*generator, + filename.clone(), + None, + &vec![], + agent.clone(), + 400, + ) + .await?; // Now create Local/DBScopeEntry depending on setting if save_to_db_directly { diff --git a/src/agent/execution/job_execution_helpers.rs b/src/agent/execution/job_execution_helpers.rs index 449747c13..6f9ac4b4e 100644 --- a/src/agent/execution/job_execution_helpers.rs +++ b/src/agent/execution/job_execution_helpers.rs @@ -1,9 +1,12 @@ -use crate::agent::agent::Agent; use crate::agent::error::AgentError; use crate::agent::job::Job; -use crate::agent::job_manager::AgentManager; +use crate::agent::{agent::Agent, job_manager::JobManager}; +use crate::db::db_errors::ShinkaiDBError; +use crate::db::ShinkaiDB; use serde_json::Value as JsonValue; +use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; +use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use shinkai_vector_resources::source::{SourceFileType, VRSource}; use std::result::Result::Ok; use std::sync::Arc; 
@@ -11,10 +14,10 @@ use tokio::sync::Mutex; use super::job_prompts::{JobPromptGenerator, Prompt}; -impl AgentManager { +impl JobManager { /// Extracts a String using the provided key in the JSON response /// Errors if the key is not present. - pub fn extract_inference_json_response(&self, response_json: JsonValue, key: &str) -> Result { + pub fn extract_inference_json_response(response_json: JsonValue, key: &str) -> Result { if let Some(value) = response_json.get(key) { let value_str = value .as_str() @@ -29,48 +32,45 @@ impl AgentManager { /// a valid JSON object/retrying if it isn't, and finally attempt to extract the provided key /// from the JSON object. Errors if the key is not found. pub async fn inference_agent_and_extract( - &self, - agent: Arc>, + agent: SerializedAgent, filled_prompt: Prompt, key: &str, ) -> Result { - let response_json = self.inference_agent(agent.clone(), filled_prompt).await?; - self.extract_inference_json_response(response_json, key) + let response_json = JobManager::inference_agent(agent.clone(), filled_prompt).await?; + JobManager::extract_inference_json_response(response_json, key) } /// Inferences the Agent's LLM with the given prompt. Automatically validates the response is /// a valid JSON object, and if it isn't re-inferences to ensure that it is returned as one. - pub async fn inference_agent( - &self, - agent: Arc>, - filled_prompt: Prompt, - ) -> Result { + pub async fn inference_agent(agent: SerializedAgent, filled_prompt: Prompt) -> Result { let agent_cloned = agent.clone(); let response = tokio::spawn(async move { - let mut agent = agent_cloned.lock().await; + let mut agent = Agent::from_serialized_agent(agent_cloned); agent.inference(filled_prompt).await }) .await?; - println!("inference_agent> response: {:?}", response); + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + format!("inference_agent> response: {:?}", response).as_str(), + ); // Validates that the response is a proper JSON object, else inferences again to get the // LLM to parse the previous response into proper JSON - self.extract_json_value_from_inference_response(response, agent.clone()) - .await + JobManager::extract_json_value_from_inference_response(response, agent.clone()).await } /// Attempts to extract the JsonValue out of the LLM's response. If it is not proper JSON /// then inferences the LLM again asking it to take its previous answer and make sure it responds with a proper JSON object. async fn extract_json_value_from_inference_response( - &self, response: Result, - agent: Arc>, + agent: SerializedAgent, ) -> Result { match response { Ok(json) => Ok(json), Err(AgentError::FailedExtractingJSONObjectFromResponse(text)) => { eprintln!("Retrying inference with new prompt"); - match self.json_not_found_retry(agent.clone(), text.clone()).await { + match JobManager::json_not_found_retry(agent.clone(), text.clone()).await { Ok(json) => Ok(json), Err(e) => Err(e), } @@ -81,9 +81,9 @@ impl AgentManager { /// Inferences the LLM again asking it to take its previous answer and make sure it responds with a proper JSON object /// that we can parse. 
- async fn json_not_found_retry(&self, agent: Arc>, text: String) -> Result { + async fn json_not_found_retry(agent: SerializedAgent, text: String) -> Result { let response = tokio::spawn(async move { - let mut agent = agent.lock().await; + let mut agent = Agent::from_serialized_agent(agent); let prompt = JobPromptGenerator::basic_json_retry_response_prompt(text); agent.inference(prompt).await }) @@ -93,27 +93,32 @@ impl AgentManager { /// Fetches boilerplate/relevant data required for a job to process a step pub async fn fetch_relevant_job_data( - &self, job_id: &str, - ) -> Result<(Job, Option>>, String, Option), AgentError> { + db: Arc>, + ) -> Result<(Job, Option, String, Option), AgentError> { // Fetch the job - let full_job = { self.db.lock().await.get_job(job_id)? }; + let full_job = { db.lock().await.get_job(job_id)? }; // Acquire Agent let agent_id = full_job.parent_agent_id.clone(); let mut agent_found = None; let mut profile_name = String::new(); let mut user_profile: Option = None; - for agent in &self.agents { - let locked_agent = agent.lock().await; - if locked_agent.id == agent_id { + let agents = JobManager::get_all_agents(db).await.unwrap_or(vec![]); + for agent in agents { + if agent.id == agent_id { agent_found = Some(agent.clone()); - profile_name = locked_agent.full_identity_name.full_name.clone(); - user_profile = Some(locked_agent.full_identity_name.extract_profile().unwrap()); + profile_name = agent.full_identity_name.full_name.clone(); + user_profile = Some(agent.full_identity_name.extract_profile().unwrap()); break; } } Ok((full_job, agent_found, profile_name, user_profile)) } + + pub async fn get_all_agents(db: Arc>) -> Result, ShinkaiDBError> { + let db = db.lock().await; + db.get_all_agents() + } } diff --git a/src/agent/execution/job_vector_search.rs b/src/agent/execution/job_vector_search.rs index 49dc9ae2b..9c13c9820 100644 --- a/src/agent/execution/job_vector_search.rs +++ b/src/agent/execution/job_vector_search.rs @@ -1,17 +1,20 @@ -use crate::agent::job_manager::AgentManager; +use crate::agent::job_manager::JobManager; +use crate::db::ShinkaiDB; use crate::db::db_errors::ShinkaiDBError; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_utils::job_scope::JobScope; use shinkai_vector_resources::base_vector_resources::BaseVectorResource; use shinkai_vector_resources::embeddings::Embedding; use shinkai_vector_resources::vector_resource_types::{DataChunk, RetrievedDataChunk, VectorResourcePointer}; +use tokio::sync::Mutex; use std::result::Result::Ok; +use std::sync::Arc; -impl AgentManager { +impl JobManager { /// Helper method which fetches all local & DB-held Vector Resources specified in the given JobScope /// and returns all of them in a single list ready to be used. pub async fn fetch_job_scope_resources( - &self, + db: Arc>, job_scope: &JobScope, profile: &ShinkaiName, ) -> Result, ShinkaiDBError> { @@ -23,7 +26,7 @@ impl AgentManager { } // Fetch DB resources and add them to the list - let db = self.db.lock().await; + let db = db.lock().await; for db_entry in &job_scope.database { let resource = db.get_resource_by_pointer(&db_entry.resource_pointer, profile)?; resources.push(resource); @@ -38,14 +41,14 @@ impl AgentManager { /// If include_description is true then adds the description of the Vector Resource as an auto-included /// RetrievedDataChunk at the front of the returned list. 
pub async fn job_scope_vector_search( - &self, + db: Arc>, job_scope: &JobScope, query: Embedding, num_of_results: u64, profile: &ShinkaiName, include_description: bool, ) -> Result, ShinkaiDBError> { - let resources = self.fetch_job_scope_resources(job_scope, profile).await?; + let resources = JobManager::fetch_job_scope_resources(db, job_scope, profile).await?; println!("Num of resources fetched: {}", resources.len()); // Perform vector search on all resources @@ -59,8 +62,7 @@ impl AgentManager { // Sort the retrieved chunks by score before returning let sorted_retrieved_chunks = RetrievedDataChunk::sort_by_score(&retrieved_chunks, num_of_results); - let updated_chunks = self - .include_description_retrieved_chunk(include_description, sorted_retrieved_chunks, &resources) + let updated_chunks = JobManager::include_description_retrieved_chunk(include_description, sorted_retrieved_chunks, &resources) .await; Ok(updated_chunks) @@ -70,7 +72,7 @@ impl AgentManager { /// If include_description is true then adds the description of the Vector Resource as an auto-included /// RetrievedDataChunk at the front of the returned list. pub async fn job_scope_syntactic_vector_search( - &self, + db: Arc>, job_scope: &JobScope, query: Embedding, num_of_results: u64, @@ -78,7 +80,7 @@ impl AgentManager { data_tag_names: &Vec, include_description: bool, ) -> Result, ShinkaiDBError> { - let resources = self.fetch_job_scope_resources(job_scope, profile).await?; + let resources = JobManager::fetch_job_scope_resources(db, job_scope, profile).await?; // Perform syntactic vector search on all resources let mut retrieved_chunks = Vec::new(); @@ -92,8 +94,7 @@ impl AgentManager { // Sort the retrieved chunks by score before returning let sorted_retrieved_chunks = RetrievedDataChunk::sort_by_score(&retrieved_chunks, num_of_results); - let updated_chunks = self - .include_description_retrieved_chunk(include_description, sorted_retrieved_chunks, &resources) + let updated_chunks = JobManager::include_description_retrieved_chunk(include_description, sorted_retrieved_chunks, &resources) .await; Ok(updated_chunks) @@ -103,7 +104,6 @@ impl AgentManager { /// that the top scored retrieved chunk is from, by prepending a fake RetrievedDataChunk /// with the description inside. Removes the lowest scored chunk to preserve list length. async fn include_description_retrieved_chunk( - &self, include_description: bool, sorted_retrieved_chunks: Vec, resources: &[BaseVectorResource], diff --git a/src/agent/file_parsing.rs b/src/agent/file_parsing.rs index 866509cf5..af45614d4 100644 --- a/src/agent/file_parsing.rs +++ b/src/agent/file_parsing.rs @@ -5,6 +5,7 @@ use lazy_static::lazy_static; use mupdf::pdf::PdfDocument; use regex::Regex; use sha2::{Digest, Sha256}; +use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent; use shinkai_vector_resources::base_vector_resources::BaseVectorResource; use shinkai_vector_resources::document_resource::DocumentVectorResource; use shinkai_vector_resources::embedding_generator::EmbeddingGenerator; @@ -21,27 +22,29 @@ use std::sync::Arc; use std::{io::Cursor, vec}; use tokio::sync::Mutex; +use crate::db::ShinkaiDB; + use super::agent::Agent; use super::error::AgentError; use super::execution::job_prompts::{JobPromptGenerator, Prompt}; -use super::job_manager::AgentManager; +use super::job_manager::JobManager; lazy_static! 
{ pub static ref UNSTRUCTURED_API_URL: &'static str = "https://internal.shinkai.com/"; } -impl AgentManager { +impl JobManager { /// Makes an async request to process a file in a buffer to Unstructured server, /// and then processing the returned results into a BaseVectorResource /// Note: The file name must include the extension ie. `*.pdf` pub async fn parse_file_into_resource( - &self, + db: Arc>, file_buffer: Vec, generator: &dyn EmbeddingGenerator, name: String, desc: Option, parsing_tags: &Vec, - agent: Arc>, + agent: SerializedAgent, max_chunk_size: u64, ) -> Result { // Parse file into needed data @@ -55,9 +58,7 @@ impl AgentManager { if desc.is_none() { let prompt = ParsingHelper::process_elements_into_description_prompt(&elements, 2000); desc = Some(ParsingHelper::ending_stripper( - &self - .inference_agent_and_extract(agent.clone(), prompt, "answer") - .await?, + &JobManager::inference_agent_and_extract(agent.clone(), prompt, "answer").await?, )); eprintln!("LLM Generated File Description: {:?}", desc); } diff --git a/src/agent/job_manager.rs b/src/agent/job_manager.rs index 92de68093..b3975bdd9 100644 --- a/src/agent/job_manager.rs +++ b/src/agent/job_manager.rs @@ -1,12 +1,14 @@ use super::error::AgentError; +use super::queue::job_queue_manager::{JobForProcessing, JobQueueManager}; use crate::agent::agent::Agent; pub use crate::agent::execution::job_execution_core::*; use crate::agent::job::{Job, JobId, JobLike}; -use crate::agent::plan_executor::PlanExecutor; use crate::db::{db_errors::ShinkaiDBError, ShinkaiDB}; use crate::managers::IdentityManager; -use chrono::Utc; +use crate::resources::bert_cpp::BertCPPProcess; use ed25519_dalek::SecretKey as SignatureStaticKey; +use futures::Future; +use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption}; use shinkai_message_primitives::{ schemas::shinkai_name::{ShinkaiName, ShinkaiNameError}, shinkai_message::{ @@ -15,106 +17,40 @@ use shinkai_message_primitives::{ }, shinkai_utils::{shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key}, }; -use std::fmt; +use std::collections::HashSet; +use std::mem; +use std::pin::Pin; use std::result::Result::Ok; use std::{collections::HashMap, error::Error, sync::Arc}; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{mpsc, Mutex, Semaphore}; -pub struct JobManager { - pub agent_manager: Arc>, - pub job_manager_receiver: Arc, JobId)>>>, - pub job_manager_sender: mpsc::Sender<(Vec, JobId)>, - pub identity_secret_key: SignatureStaticKey, - pub node_profile_name: ShinkaiName, -} - -impl JobManager { - pub async fn new( - db: Arc>, - identity_manager: Arc>, - identity_secret_key: SignatureStaticKey, - node_profile_name: ShinkaiName, - ) -> Self { - let (job_manager_sender, job_manager_receiver) = tokio::sync::mpsc::channel(100); - let agent_manager = AgentManager::new( - db, - identity_manager, - job_manager_sender.clone(), - clone_signature_secret_key(&identity_secret_key), - ) - .await; - - let mut job_manager = Self { - agent_manager: Arc::new(Mutex::new(agent_manager)), - job_manager_receiver: Arc::new(Mutex::new(job_manager_receiver)), - job_manager_sender: job_manager_sender.clone(), - identity_secret_key, - node_profile_name, - }; - job_manager.process_received_messages().await; - job_manager - } - - pub async fn process_job_message(&mut self, shinkai_message: ShinkaiMessage) -> Result { - let mut agent_manager = self.agent_manager.lock().await; - if agent_manager.is_job_message(shinkai_message.clone()) 
{ - agent_manager.process_job_message(shinkai_message).await - } else { - Err(AgentError::NotAJobMessage) - } - } - - pub async fn process_received_messages(&mut self) { - let agent_manager = Arc::clone(&self.agent_manager); - let receiver = Arc::clone(&self.job_manager_receiver); - let node_profile_name_clone = self.node_profile_name.clone(); - let identity_secret_key_clone = clone_signature_secret_key(&self.identity_secret_key); - tokio::spawn(async move { - while let Some((messages, job_id)) = receiver.lock().await.recv().await { - for message in messages { - eprintln!("Nico debug> process_received_messages"); - let mut agent_manager = agent_manager.lock().await; - - let shinkai_message_result = ShinkaiMessageBuilder::job_message_from_agent( - job_id.clone(), - message.content.clone(), - clone_signature_secret_key(&identity_secret_key_clone), - node_profile_name_clone.to_string(), - node_profile_name_clone.to_string(), - ); - - if let Ok(shinkai_message) = shinkai_message_result { - if let Err(err) = agent_manager - .handle_pre_message_schema(message, job_id.clone(), shinkai_message) - .await - { - eprintln!("Error while handling pre message schema: {:?}", err); - } - } else if let Err(err) = shinkai_message_result { - eprintln!("Error while building ShinkaiMessage: {:?}", err); - } - } - } - }); - } -} +const NUM_THREADS: usize = 2; -pub struct AgentManager { +pub struct JobManager { pub jobs: Arc>>>, pub db: Arc>, pub identity_manager: Arc>, - pub job_manager_sender: mpsc::Sender<(Vec, JobId)>, pub agents: Vec>>, pub identity_secret_key: SignatureStaticKey, + pub job_queue_manager: Arc>>, + pub node_profile_name: ShinkaiName, + pub job_processing_task: Option>, + _bert_process: BertCPPProcess, + + // TODO: remove them + pub job_manager_receiver: Arc, JobId)>>>, + pub job_manager_sender: mpsc::Sender<(Vec, JobId)>, } -impl AgentManager { +impl JobManager { pub async fn new( db: Arc>, identity_manager: Arc>, - job_manager_sender: mpsc::Sender<(Vec, JobId)>, identity_secret_key: SignatureStaticKey, + node_profile_name: ShinkaiName, ) -> Self { + let (job_manager_sender, job_manager_receiver) = mpsc::channel(100); + let jobs_map = Arc::new(Mutex::new(HashMap::new())); { let shinkai_db = db.lock().await; @@ -131,23 +67,208 @@ impl AgentManager { let identity_manager = identity_manager.lock().await; let serialized_agents = identity_manager.get_all_agents().await.unwrap(); for serialized_agent in serialized_agents { - let agent = Agent::from_serialized_agent(serialized_agent, job_manager_sender.clone()); + let agent = Agent::from_serialized_agent(serialized_agent); agents.push(Arc::new(Mutex::new(agent))); } } - let mut job_manager = Self { - jobs: jobs_map, - db, + let job_queue = JobQueueManager::::new(db.clone()).await.unwrap(); + let job_queue_manager = Arc::new(Mutex::new(job_queue)); + + let _bert_process = BertCPPProcess::start().unwrap(); // Gets killed if out of scope + // Start processing the job queue + let job_queue_handler = JobManager::process_job_queue( + job_queue_manager.clone(), + db.clone(), + NUM_THREADS, + clone_signature_secret_key(&identity_secret_key), + |job, db, identity_sk| Box::pin(JobManager::process_job_message_queued(job, db, identity_sk)), + ) + .await; + + let job_manager = Self { + db: db.clone(), + job_manager_receiver: Arc::new(Mutex::new(job_manager_receiver)), job_manager_sender: job_manager_sender.clone(), + identity_secret_key: clone_signature_secret_key(&identity_secret_key), + node_profile_name, + jobs: jobs_map, identity_manager, agents, - 
identity_secret_key, + job_queue_manager: job_queue_manager.clone(), + job_processing_task: Some(job_queue_handler), + _bert_process, }; job_manager } + pub async fn process_job_queue( + job_queue_manager: Arc>>, + db: Arc>, + max_parallel_jobs: usize, + identity_sk: SignatureStaticKey, + job_processing_fn: impl Fn( + JobForProcessing, + Arc>, + SignatureStaticKey, + ) -> Pin> + Send>> + + Send + + Sync + + 'static, + ) -> tokio::task::JoinHandle<()> { + let job_queue_manager = Arc::clone(&job_queue_manager); + let mut receiver = job_queue_manager.lock().await.subscribe_to_all().await; + let db_clone = db.clone(); + let identity_sk = clone_signature_secret_key(&identity_sk); + let job_processing_fn = Arc::new(job_processing_fn); + + let processing_jobs = Arc::new(Mutex::new(HashSet::new())); + let semaphore = Arc::new(Semaphore::new(max_parallel_jobs)); + + return tokio::spawn(async move { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Info, + "Starting job queue processing loop", + ); + + let mut handles = Vec::new(); + loop { + // Scope for acquiring and releasing the lock quickly + let job_ids_to_process: Vec = { + let mut processing_jobs_lock = processing_jobs.lock().await; + let job_queue_manager_lock = job_queue_manager.lock().await; + let all_jobs = job_queue_manager_lock + .get_all_elements_interleave() + .await + .unwrap_or(Vec::new()); + std::mem::drop(job_queue_manager_lock); + + let jobs = all_jobs + .into_iter() + .filter_map(|job| { + let job_id = job.job_message.job_id.clone().to_string(); + if !processing_jobs_lock.contains(&job_id) { + processing_jobs_lock.insert(job_id.clone()); + Some(job_id) + } else { + None + } + }) + .collect(); + + std::mem::drop(processing_jobs_lock); + jobs + }; + + // Spawn tasks based on filtered job IDs + for job_id in job_ids_to_process { + let job_queue_manager = Arc::clone(&job_queue_manager); + let processing_jobs = Arc::clone(&processing_jobs); + let semaphore = Arc::clone(&semaphore); + let db_clone_2 = db_clone.clone(); + let identity_sk_clone = clone_signature_secret_key(&identity_sk); + let job_processing_fn = Arc::clone(&job_processing_fn); + + let handle = tokio::spawn(async move { + let _permit = semaphore.acquire().await.unwrap(); + + // Acquire the lock, dequeue the job, and immediately release the lock + let job = { + let mut job_queue_manager = job_queue_manager.lock().await; + let job = job_queue_manager.peek(&job_id).await; + job + }; + + match job { + Ok(Some(job)) => { + // Acquire the lock, process the job, and immediately release the lock + let result = { + let result = job_processing_fn(job, db_clone_2, identity_sk_clone).await; + if let Ok(Some(_)) = job_queue_manager.lock().await.dequeue(&job_id.clone()).await { + result + } else { + Err(AgentError::JobDequeueFailed(job_id.clone())) + } + }; + + match result { + Ok(_) => { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Debug, + "Job processed successfully", + ); + } // handle success case + Err(e) => {} // handle error case + } + } + Ok(None) => {} + Err(e) => { + // Log the error + } + } + drop(_permit); + processing_jobs.lock().await.remove(&job_id); + }); + handles.push(handle); + } + + let handles_to_join = mem::replace(&mut handles, Vec::new()); + futures::future::join_all(handles_to_join).await; + handles.clear(); + + // Receive new jobs + if let Some(new_job) = receiver.recv().await { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Info, + format!("Received new job {:?}", 
new_job.job_message.job_id).as_str(), + ); + } + } + }); + } + + pub async fn process_job_message(&mut self, message: ShinkaiMessage) -> Result { + if self.is_job_message(message.clone()) { + match message.clone().body { + MessageBody::Unencrypted(body) => { + match body.message_data { + MessageData::Unencrypted(data) => { + let message_type = data.message_content_schema; + match message_type { + MessageSchemaType::JobCreationSchema => { + let agent_name = + ShinkaiName::from_shinkai_message_using_recipient_subidentity(&message)?; + let agent_id = agent_name.get_agent_name().ok_or(AgentError::AgentNotFound)?; + let job_creation: JobCreationInfo = serde_json::from_str(&data.message_raw_content) + .map_err(|_| AgentError::ContentParseFailed)?; + self.process_job_creation(job_creation, &agent_id).await + } + MessageSchemaType::JobMessageSchema => { + let job_message: JobMessage = serde_json::from_str(&data.message_raw_content) + .map_err(|_| AgentError::ContentParseFailed)?; + self.add_to_job_processing_queue(message, job_message).await + } + _ => { + // Handle Empty message type if needed, or return an error if it's not a valid job message + Err(AgentError::NotAJobMessage) + } + } + } + _ => Err(AgentError::NotAJobMessage), + } + } + _ => Err(AgentError::NotAJobMessage), + } + } else { + Err(AgentError::NotAJobMessage) + } + } + + // From JobManager /// Checks that the provided ShinkaiMessage is an unencrypted job message pub fn is_job_message(&mut self, message: ShinkaiMessage) -> bool { match &message.body { @@ -168,6 +289,7 @@ impl AgentManager { job_creation: JobCreationInfo, agent_id: &String, ) -> Result { + // TODO: add job_id to agent so it's aware let job_id = format!("jobid_{}", uuid::Uuid::new_v4()); { let mut shinkai_db = self.db.lock().await; @@ -192,7 +314,7 @@ impl AgentManager { if agent_found.is_none() { let identity_manager = self.identity_manager.lock().await; if let Some(serialized_agent) = identity_manager.search_local_agent(&agent_id).await { - let agent = Agent::from_serialized_agent(serialized_agent, self.job_manager_sender.clone()); + let agent = Agent::from_serialized_agent(serialized_agent); agent_found = Some(Arc::new(Mutex::new(agent))); self.agents.push(agent_found.clone().unwrap()); } @@ -212,52 +334,32 @@ impl AgentManager { } } - /// Adds pre-message to job inbox - pub async fn handle_pre_message_schema( + pub async fn add_to_job_processing_queue( &mut self, - pre_message: JobPreMessage, - job_id: String, - shinkai_message: ShinkaiMessage, + message: ShinkaiMessage, + job_message: JobMessage, ) -> Result { - println!("handle_pre_message_schema> pre_message: {:?}", pre_message); + // Verify identity/profile match + let sender_subidentity_result = ShinkaiName::from_shinkai_message_using_sender_subidentity(&message.clone()); + let sender_subidentity = match sender_subidentity_result { + Ok(subidentity) => subidentity, + Err(e) => return Err(AgentError::InvalidSubidentity(e)), + }; + let profile_result = sender_subidentity.extract_profile(); + let profile = match profile_result { + Ok(profile) => profile, + Err(e) => return Err(AgentError::InvalidProfileSubidentity(e.to_string())), + }; - self.db - .lock() - .await - .add_message_to_job_inbox(job_id.as_str(), &shinkai_message)?; - Ok(String::new()) - } + let mut shinkai_db = self.db.lock().await; + shinkai_db.add_message_to_job_inbox(&job_message.job_id.clone(), &message)?; + std::mem::drop(shinkai_db); - pub async fn process_job_message(&mut self, message: ShinkaiMessage) -> Result { - match message.clone().body 
{ - MessageBody::Unencrypted(body) => { - match body.message_data { - MessageData::Unencrypted(data) => { - let message_type = data.message_content_schema; - match message_type { - MessageSchemaType::JobCreationSchema => { - let agent_name = - ShinkaiName::from_shinkai_message_using_recipient_subidentity(&message)?; - let agent_id = agent_name.get_agent_name().ok_or(AgentError::AgentNotFound)?; - let job_creation: JobCreationInfo = serde_json::from_str(&data.message_raw_content) - .map_err(|_| AgentError::ContentParseFailed)?; - self.process_job_creation(job_creation, &agent_id).await - } - MessageSchemaType::JobMessageSchema => { - let job_message: JobMessage = serde_json::from_str(&data.message_raw_content) - .map_err(|_| AgentError::ContentParseFailed)?; - self.process_job_step(message, job_message).await - } - _ => { - // Handle Empty message type if needed, or return an error if it's not a valid job message - Err(AgentError::NotAJobMessage) - } - } - } - _ => Err(AgentError::NotAJobMessage), - } - } - _ => Err(AgentError::NotAJobMessage), - } + let job_for_processing = JobForProcessing::new(job_message.clone(), profile.clone()); + + let mut job_queue_manager = self.job_queue_manager.lock().await; + let _ = job_queue_manager.push(&job_message.job_id, job_for_processing).await; + + Ok(job_message.job_id.clone().to_string()) } } diff --git a/src/agent/queue/job_queue_manager.rs b/src/agent/queue/job_queue_manager.rs index 18d91e727..4f8a52b6b 100644 --- a/src/agent/queue/job_queue_manager.rs +++ b/src/agent/queue/job_queue_manager.rs @@ -1,32 +1,92 @@ use crate::db::db_errors::ShinkaiDBError; use crate::db::ShinkaiDB; +use chrono::{DateTime, Utc}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobMessage; +use std::cmp::Ordering; use std::collections::HashMap; -use std::sync::{mpsc, Arc, Mutex}; +use std::fmt::Debug; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; type MutexQueue = Arc>>; type Subscriber = mpsc::Sender; -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +// First Type + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub struct JobForProcessing { - job_message: JobMessage, - profile: ShinkaiName, + pub job_message: JobMessage, + pub profile: ShinkaiName, + pub date_created: String, +} + +impl JobForProcessing { + pub fn new(job_message: JobMessage, profile: ShinkaiName) -> Self { + JobForProcessing { + job_message, + profile, + date_created: Utc::now().to_rfc3339(), + } + } +} + +impl PartialOrd for JobForProcessing { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for JobForProcessing { + fn cmp(&self, other: &Self) -> Ordering { + self.date_created.cmp(&other.date_created) + } +} + +// Second Type +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct OrdJsonValue(JsonValue); + +impl Ord for OrdJsonValue { + fn cmp(&self, other: &Self) -> Ordering { + let self_str = self.0.to_string(); + let other_str = other.0.to_string(); + self_str.cmp(&other_str) + } +} + +impl PartialOrd for OrdJsonValue { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Eq for OrdJsonValue {} + +impl PartialEq for OrdJsonValue { + fn eq(&self, other: &Self) -> bool { + let self_str = self.0.to_string(); + let other_str = other.0.to_string(); 
+ self_str == other_str + } } #[derive(Debug)] -pub struct JobQueueManager { - queues: HashMap>, - subscribers: HashMap>>, +pub struct JobQueueManager { + queues: Arc>>>, + subscribers: Arc>>>>, + all_subscribers: Arc>>>, db: Arc>, } -impl JobQueueManager { - pub fn new(db: Arc>) -> Result { +impl JobQueueManager { + pub async fn new(db: Arc>) -> Result { // Lock the db for safe access - let db_lock = db.lock().unwrap(); + let db_lock = db.lock().await; // Call the get_all_queues method to get all queue data from the db match db_lock.get_all_queues() { @@ -39,8 +99,9 @@ impl JobQueueManager JobQueueManager Result, ShinkaiDBError> { - let db = self.db.lock().unwrap(); + async fn get_queue(&self, key: &str) -> Result, ShinkaiDBError> { + let db = self.db.lock().await; db.get_job_queues(key) } - pub fn push(&mut self, key: &str, value: T) -> Result<(), ShinkaiDBError> { - let queue = self - .queues + pub async fn push(&mut self, key: &str, value: T) -> Result<(), ShinkaiDBError> { + // Lock the Mutex to get mutable access to the HashMap + let mut queues = self.queues.lock().await; + + // Ensure the specified key exists in the queues hashmap, initializing it with an empty queue if necessary + let queue = queues .entry(key.to_string()) .or_insert_with(|| Arc::new(Mutex::new(Vec::new()))); - let mut guarded_queue = queue.lock().unwrap(); + let mut guarded_queue = queue.lock().await; guarded_queue.push(value.clone()); // Persist queue to the database - let db = self.db.lock().unwrap(); + let db = self.db.lock().await; db.persist_queue(key, &guarded_queue)?; // Notify subscribers - if let Some(subs) = self.subscribers.get(key) { + let subscribers = self.subscribers.lock().await; + if let Some(subs) = subscribers.get(key) { for sub in subs.iter() { - sub.send(value.clone()).unwrap(); + let _ = sub.send(value.clone()).await; } } + + // Notify subscribers to all keys + let all_subscribers = self.all_subscribers.lock().await; + for sub in all_subscribers.iter() { + let _ = sub.send(value.clone()).await; + } Ok(()) } - pub fn dequeue(&mut self, key: &str) -> Result, ShinkaiDBError> { + pub async fn dequeue(&mut self, key: &str) -> Result, ShinkaiDBError> { + // Lock the Mutex to get mutable access to the HashMap + let mut queues = self.queues.lock().await; + // Ensure the specified key exists in the queues hashmap, initializing it with an empty queue if necessary - let queue = self - .queues + let queue = queues .entry(key.to_string()) .or_insert_with(|| Arc::new(Mutex::new(Vec::new()))); - let mut guarded_queue = queue.lock().unwrap(); + + let mut guarded_queue = queue.lock().await; // Check if there's an element to dequeue, and remove it if so let result = if guarded_queue.get(0).is_some() { @@ -91,27 +165,79 @@ impl JobQueueManager mpsc::Receiver { - let (tx, rx) = mpsc::channel(); - self.subscribers - .entry(key.to_string()) - .or_insert_with(Vec::new) - .push(tx); + pub async fn peek(&self, key: &str) -> Result, ShinkaiDBError> { + let queues = self.queues.lock().await; + if let Some(queue) = queues.get(key) { + let guarded_queue = queue.lock().await; + if let Some(first) = guarded_queue.first() { + return Ok(Some(first.clone())); + } + } + Ok(None) + } + + pub async fn get_all_elements_interleave(&self) -> Result, ShinkaiDBError> { + let db_lock = self.db.lock().await; + let mut db_queues: HashMap<_, _> = db_lock.get_all_queues::()?; + + // Sort the keys based on the first element in each queue, falling back to key names + let mut keys: Vec<_> = db_queues.keys().cloned().collect(); + 
keys.sort_by(|a, b| { + let a_first = db_queues.get(a).and_then(|q| q.first()); + let b_first = db_queues.get(b).and_then(|q| q.first()); + match (a_first, b_first) { + (Some(a), Some(b)) => a.cmp(b), + _ => a.cmp(b), + } + }); + + let mut all_elements = Vec::new(); + let mut indices: Vec<_> = vec![0; keys.len()]; + let mut added = true; + + while added { + added = false; + for (key, index) in keys.iter().zip(indices.iter_mut()) { + if let Some(queue) = db_queues.get_mut(key) { + if let Some(element) = queue.get(*index) { + all_elements.push(element.clone()); + *index += 1; + added = true; + } + } + } + } + + Ok(all_elements) + } + + pub async fn subscribe(&self, key: &str) -> mpsc::Receiver { + let (tx, rx) = mpsc::channel(100); + let mut subscribers = self.subscribers.lock().await; + subscribers.entry(key.to_string()).or_insert_with(Vec::new).push(tx); + rx + } + + pub async fn subscribe_to_all(&self) -> mpsc::Receiver { + let (tx, rx) = mpsc::channel(100); + let mut all_subscribers = self.all_subscribers.lock().await; + all_subscribers.push(tx); rx } } -impl Clone for JobQueueManager { +impl Clone for JobQueueManager { fn clone(&self) -> Self { JobQueueManager { - queues: self.queues.clone(), - subscribers: self.subscribers.clone(), + queues: Arc::clone(&self.queues), + subscribers: Arc::clone(&self.subscribers), + all_subscribers: Arc::clone(&self.all_subscribers), db: Arc::clone(&self.db), } } @@ -132,21 +258,24 @@ mod tests { let _ = fs::remove_dir_all(&path); } - #[test] - fn test_queue_manager() { + #[tokio::test] + async fn test_queue_manager() { setup(); let db = Arc::new(Mutex::new(ShinkaiDB::new("db_tests/").unwrap())); - let mut manager = JobQueueManager::::new(db).unwrap(); + let mut manager = JobQueueManager::::new(db).await.unwrap(); // Subscribe to notifications from "my_queue" - let receiver = manager.subscribe("my_queue"); + let mut receiver = manager.subscribe("job_id::123::false").await; let mut manager_clone = manager.clone(); - let handle = std::thread::spawn(move || { - for msg in receiver.iter() { + let handle = tokio::spawn(async move { + while let Some(msg) = receiver.recv().await { println!("Received (from subscriber): {:?}", msg); + let results = manager_clone.get_all_elements_interleave().await.unwrap(); + eprintln!("All elements: {:?}", results); + // Dequeue from the queue inside the subscriber thread - if let Ok(Some(message)) = manager_clone.dequeue("my_queue") { + if let Ok(Some(message)) = manager_clone.dequeue("job_id::123::false").await { println!("Dequeued (from subscriber): {:?}", message); // Assert that the subscriber dequeued the correct message @@ -155,7 +284,7 @@ mod tests { eprintln!("Dequeued (from subscriber): {:?}", msg); // Assert that the queue is now empty - match manager_clone.dequeue("my_queue") { + match manager_clone.dequeue("job_id::123::false").await { Ok(None) => (), Ok(Some(_)) => panic!("Queue is not empty!"), Err(e) => panic!("Failed to dequeue from queue: {:?}", e), @@ -165,57 +294,57 @@ mod tests { }); // Push to a queue - let job = JobForProcessing { - job_message: JobMessage { + let job = JobForProcessing::new( + JobMessage { job_id: "job_id::123::false".to_string(), content: "my content".to_string(), files_inbox: "".to_string(), }, - profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), - }; - manager.push("my_queue", job.clone()).unwrap(); + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + manager.push("job_id::123::false", job.clone()).await.unwrap(); // Sleep to allow subscriber to 
@@ -132,21 +258,24 @@ mod tests {
         let _ = fs::remove_dir_all(&path);
     }

-    #[test]
-    fn test_queue_manager() {
+    #[tokio::test]
+    async fn test_queue_manager() {
         setup();
         let db = Arc::new(Mutex::new(ShinkaiDB::new("db_tests/").unwrap()));
-        let mut manager = JobQueueManager::<JobForProcessing>::new(db).unwrap();
+        let mut manager = JobQueueManager::<JobForProcessing>::new(db).await.unwrap();

         // Subscribe to notifications from "my_queue"
-        let receiver = manager.subscribe("my_queue");
+        let mut receiver = manager.subscribe("job_id::123::false").await;
         let mut manager_clone = manager.clone();
-        let handle = std::thread::spawn(move || {
-            for msg in receiver.iter() {
+        let handle = tokio::spawn(async move {
+            while let Some(msg) = receiver.recv().await {
                 println!("Received (from subscriber): {:?}", msg);
+                let results = manager_clone.get_all_elements_interleave().await.unwrap();
+                eprintln!("All elements: {:?}", results);
+
                 // Dequeue from the queue inside the subscriber thread
-                if let Ok(Some(message)) = manager_clone.dequeue("my_queue") {
+                if let Ok(Some(message)) = manager_clone.dequeue("job_id::123::false").await {
                     println!("Dequeued (from subscriber): {:?}", message);

                     // Assert that the subscriber dequeued the correct message
@@ -155,7 +284,7 @@ mod tests {
                     eprintln!("Dequeued (from subscriber): {:?}", msg);

                     // Assert that the queue is now empty
-                    match manager_clone.dequeue("my_queue") {
+                    match manager_clone.dequeue("job_id::123::false").await {
                         Ok(None) => (),
                         Ok(Some(_)) => panic!("Queue is not empty!"),
                         Err(e) => panic!("Failed to dequeue from queue: {:?}", e),
@@ -165,57 +294,57 @@ mod tests {
         });

         // Push to a queue
-        let job = JobForProcessing {
-            job_message: JobMessage {
+        let job = JobForProcessing::new(
+            JobMessage {
                 job_id: "job_id::123::false".to_string(),
                 content: "my content".to_string(),
                 files_inbox: "".to_string(),
             },
-            profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
-        };
-        manager.push("my_queue", job.clone()).unwrap();
+            ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
+        );
+        manager.push("job_id::123::false", job.clone()).await.unwrap();

         // Sleep to allow subscriber to process the message (just for this example)
-        std::thread::sleep(std::time::Duration::from_millis(500));
+        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

-        handle.join().unwrap();
+        handle.await.unwrap();
     }

-    #[test]
-    fn test_queue_manager_consistency() {
+    #[tokio::test]
+    async fn test_queue_manager_consistency() {
         setup();
         let db_path = "db_tests/";
         let db = Arc::new(Mutex::new(ShinkaiDB::new(db_path).unwrap()));
-        let mut manager = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).unwrap();
+        let mut manager = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).await.unwrap();

         // Push to a queue
-        let job = JobForProcessing {
-            job_message: JobMessage {
+        let job = JobForProcessing::new(
+            JobMessage {
                 job_id: "job_id::123::false".to_string(),
                 content: "my content".to_string(),
                 files_inbox: "".to_string(),
             },
-            profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
-        };
-        let job2 = JobForProcessing {
-            job_message: JobMessage {
+            ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
+        );
+        let job2 = JobForProcessing::new(
+            JobMessage {
                 job_id: "job_id::123::false".to_string(),
                 content: "my content 2".to_string(),
                 files_inbox: "".to_string(),
             },
-            profile: ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
-        };
-        manager.push("my_queue", job.clone()).unwrap();
-        manager.push("my_queue", job2.clone()).unwrap();
+            ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
+        );
+        manager.push("my_queue", job.clone()).await.unwrap();
+        manager.push("my_queue", job2.clone()).await.unwrap();

         // Sleep to allow subscriber to process the message (just for this example)
-        std::thread::sleep(std::time::Duration::from_millis(500));
+        tokio::time::sleep(std::time::Duration::from_millis(500)).await;

         // Create a new manager and recover the state
-        let mut new_manager = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).unwrap();
+        let mut new_manager = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).await.unwrap();

         // Try to pop the job from the queue using the new manager
-        match new_manager.dequeue("my_queue") {
+        match new_manager.dequeue("my_queue").await {
             Ok(Some(recovered_job)) => {
                 shinkai_log(
                     ShinkaiLogOption::Tests,
@@ -228,7 +357,7 @@ mod tests {
             Err(e) => panic!("Failed to pop job from queue: {:?}", e),
         }

-        match new_manager.dequeue("my_queue") {
+        match new_manager.dequeue("my_queue").await {
             Ok(Some(recovered_job)) => {
                 shinkai_log(
                     ShinkaiLogOption::Tests,
@@ -242,21 +371,21 @@ mod tests {
         }
     }

-    #[test]
-    fn test_queue_manager_with_jsonvalue() {
+    #[tokio::test]
+    async fn test_queue_manager_with_jsonvalue() {
         setup();
         let db = Arc::new(Mutex::new(ShinkaiDB::new("db_tests/").unwrap()));
-        let mut manager = JobQueueManager::>::new(db).unwrap();
+        let mut manager = JobQueueManager::<OrdJsonValue>::new(db).await.unwrap();

         // Subscribe to notifications from "my_queue"
-        let receiver = manager.subscribe("my_queue");
+        let mut receiver = manager.subscribe("my_queue").await;
         let mut manager_clone = manager.clone();
-        let handle = std::thread::spawn(move || {
-            for msg in receiver.iter() {
+        let handle = tokio::spawn(async move {
+            while let Some(msg) = receiver.recv().await {
                 println!("Received (from subscriber): {:?}", msg);

                 // Dequeue from the queue inside the subscriber thread
-                if let Ok(Some(message)) = manager_clone.dequeue("my_queue") {
+                if let Ok(Some(message)) = manager_clone.dequeue("my_queue").await {
                     println!("Dequeued (from subscriber): {:?}", message);

                     // Assert that the subscriber dequeued the correct message
@@ -265,7 +394,7 @@ mod tests {
                     eprintln!("Dequeued (from subscriber): {:?}", msg);
subscriber): {:?}", msg); // Assert that the queue is now empty - match manager_clone.dequeue("my_queue") { + match manager_clone.dequeue("my_queue").await { Ok(None) => (), Ok(Some(_)) => panic!("Queue is not empty!"), Err(e) => panic!("Failed to dequeue from queue: {:?}", e), @@ -275,11 +404,85 @@ mod tests { }); // Push to a queue - let job = Ok(JsonValue::String("my content".to_string())); - manager.push("my_queue", job.clone()).unwrap(); + let job = JsonValue::String("my content".to_string()); + manager.push("my_queue", OrdJsonValue(job)).await.unwrap(); // Sleep to allow subscriber to process the message (just for this example) - std::thread::sleep(std::time::Duration::from_millis(500)); - handle.join().unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + handle.await.unwrap(); + } + + #[tokio::test] + async fn test_get_all_elements_interleave() { + setup(); + let db = Arc::new(Mutex::new(ShinkaiDB::new("db_tests/").unwrap())); + let mut manager = JobQueueManager::::new(db).await.unwrap(); + + // Create jobs + let job_a1 = JobForProcessing::new( + JobMessage { + job_id: "job_id::a1::false".to_string(), + content: "content a1".to_string(), + files_inbox: "".to_string(), + }, + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + let job_a2 = JobForProcessing::new( + JobMessage { + job_id: "job_id::a2::false".to_string(), + content: "content a2".to_string(), + files_inbox: "".to_string(), + }, + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + let job_a3 = JobForProcessing::new( + JobMessage { + job_id: "job_id::a3::false".to_string(), + content: "content a3".to_string(), + files_inbox: "".to_string(), + }, + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + + let job_b1 = JobForProcessing::new( + JobMessage { + job_id: "job_id::b1::false".to_string(), + content: "content b1".to_string(), + files_inbox: "".to_string(), + }, + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + + let job_c1 = JobForProcessing::new( + JobMessage { + job_id: "job_id::c1::false".to_string(), + content: "content c1".to_string(), + files_inbox: "".to_string(), + }, + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + + let job_c2 = JobForProcessing::new( + JobMessage { + job_id: "job_id::c2::false".to_string(), + content: "content c2".to_string(), + files_inbox: "".to_string(), + }, + ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + + // Push jobs to queues + manager.push("job_a", job_a1.clone()).await.unwrap(); + manager.push("job_a", job_a2.clone()).await.unwrap(); + manager.push("job_a", job_a3.clone()).await.unwrap(); + manager.push("job_b", job_b1.clone()).await.unwrap(); + manager.push("job_c", job_c1.clone()).await.unwrap(); + manager.push("job_c", job_c2.clone()).await.unwrap(); + + // Get all elements interleaved + let all_elements = manager.get_all_elements_interleave().await.unwrap(); + + // Check if the elements are in the correct order + assert_eq!(all_elements, vec![job_a1, job_b1, job_c1, job_a2, job_c2, job_a3]); } } diff --git a/tests/db_agents_tests.rs b/tests/db_agents_tests.rs index 783485978..bd5e7e3d0 100644 --- a/tests/db_agents_tests.rs +++ b/tests/db_agents_tests.rs @@ -220,14 +220,12 @@ mod tests { ) .create(); - let (tx, _rx) = mpsc::channel(1); let openai = OpenAI { model_type: "gpt-3.5-turbo".to_string(), }; let agent = Agent::new( "1".to_string(), ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(), - 
diff --git a/tests/db_agents_tests.rs b/tests/db_agents_tests.rs
index 783485978..bd5e7e3d0 100644
--- a/tests/db_agents_tests.rs
+++ b/tests/db_agents_tests.rs
@@ -220,14 +220,12 @@ mod tests {
         )
         .create();

-        let (tx, _rx) = mpsc::channel(1);
         let openai = OpenAI {
             model_type: "gpt-3.5-turbo".to_string(),
         };
         let agent = Agent::new(
             "1".to_string(),
             ShinkaiName::new("@@alice.shinkai/profileName/agent/myChatGPTAgent".to_string()).unwrap(),
-            tx,
             false,
             Some(server.url()), // use the url of the mock server
             Some("mockapikey".to_string()),
diff --git a/tests/encrypted_files_tests.rs b/tests/encrypted_files_tests.rs
index f8b987e73..c7baf8f10 100644
--- a/tests/encrypted_files_tests.rs
+++ b/tests/encrypted_files_tests.rs
@@ -21,6 +21,7 @@ use shinkai_message_primitives::shinkai_utils::signatures::{
 };
 use shinkai_message_primitives::shinkai_utils::utils::hash_string;
 use shinkai_node::agent::agent;
+use shinkai_node::agent::error::AgentError;
 use shinkai_node::network::node::NodeCommand;
 use shinkai_node::network::node_api::APIError;
 use shinkai_node::network::Node;
@@ -38,7 +39,6 @@ use crate::utils::node_test_api::{
     api_agent_registration, api_create_job, api_get_all_inboxes_from_profile,
     api_initial_registration_with_no_code_for_device, api_message_job,
     api_registration_device_node_profile_main,
 };
-use crate::utils::node_test_local::local_registration_profile_node;
 use mockito::Server;

 #[test]
@@ -356,10 +356,11 @@ fn sandwich_messages_with_files_test() {
     //     )
     //     .create();
     // }
+
+    let job_message_content = "What's Zeko?".to_string();
     {
         // Send a Message to the Job for processing
         eprintln!("\n\nSend a message for the Job");
-        let message = "What's Zeko?".to_string();
         let start = Instant::now();
         api_message_job(
             node1_commands_sender.clone(),
@@ -370,7 +371,7 @@ fn sandwich_messages_with_files_test() {
             node1_profile_name.clone().as_str(),
             &agent_subidentity.clone(),
             &job_id.clone().to_string(),
-            &message,
+            &job_message_content,
             &hash_of_aes_encryption_key_hex(symmetrical_sk),
         )
         .await;
@@ -378,6 +379,45 @@
         let duration = start.elapsed(); // Get the time elapsed since the start of the timer
         eprintln!("Time elapsed in api_message_job is: {:?}", duration);
     }
+    // {
+    //     tokio::time::sleep(Duration::from_secs(1000)).await;
+    // }
+    {
+        eprintln!("Waiting for the Job to finish");
+        for _ in 0..50 {
+            let (res1_sender, res1_receiver) = async_channel::bounded(1);
+            node1_commands_sender
+                .send(NodeCommand::FetchLastMessages {
+                    limit: 2,
+                    res: res1_sender,
+                })
+                .await
+                .unwrap();
+            let node1_last_messages = res1_receiver.recv().await.unwrap();
+            eprintln!("node1_last_messages: {:?}", node1_last_messages);
+
+            match node1_last_messages[0].get_message_content() {
+                Ok(message_content) => {
+                    match serde_json::from_str::<JobMessage>(&message_content) {
+                        Ok(job_message) => {
+                            eprintln!("message_content: {}", message_content);
+                            if job_message.content != job_message_content {
+                                // The newest message is no longer our prompt: the agent replied.
+                                break;
+                            }
+                        }
+                        Err(_) => {
+                            eprintln!("error: message_content: {}", message_content);
+                        }
+                    }
+                }
+                Err(_) => {
+                    // nothing
+                }
+            }
+            tokio::time::sleep(Duration::from_secs(10)).await;
+        }
+    }
     })
     });
 }
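The polling block added above watches the job inbox until the newest message is no longer the prompt that was sent, i.e. the agent has produced a reply, giving up after 50 probes spaced 10 seconds apart. The shape of that loop, extracted into a hypothetical helper (`fetch_newest_content` stands in for the NodeCommand::FetchLastMessages round-trip and is not an API of the crate):

    use std::{future::Future, time::Duration};

    /// Poll until the newest inbox message differs from `prompt` (a reply
    /// arrived) or `attempts` probes have been spent. Illustrative only.
    async fn wait_for_reply<F, Fut>(prompt: &str, mut fetch_newest_content: F, attempts: u32) -> bool
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = String>,
    {
        for _ in 0..attempts {
            if fetch_newest_content().await != prompt {
                return true; // the agent replied
            }
            tokio::time::sleep(Duration::from_secs(10)).await;
        }
        false // gave up waiting
    }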
diff --git a/tests/job_manager_concurrency_tests.rs b/tests/job_manager_concurrency_tests.rs
new file mode 100644
index 000000000..17edca7ef
--- /dev/null
+++ b/tests/job_manager_concurrency_tests.rs
@@ -0,0 +1,260 @@
+use ed25519_dalek::{PublicKey as SignaturePublicKey, SecretKey as SignatureStaticKey};
+use futures::Future;
+use shinkai_message_primitives::schemas::inbox_name::InboxName;
+use shinkai_message_primitives::shinkai_utils::encryption::{
+    unsafe_deterministic_encryption_keypair, EncryptionMethod,
+};
+use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
+use shinkai_message_primitives::shinkai_utils::signatures::unsafe_deterministic_signature_keypair;
+use shinkai_message_primitives::shinkai_utils::utils::hash_string;
+use shinkai_message_primitives::{
+    schemas::shinkai_name::{ShinkaiName, ShinkaiNameError},
+    shinkai_message::{
+        shinkai_message::{MessageBody, MessageData, ShinkaiMessage},
+        shinkai_message_schemas::{JobCreationInfo, JobMessage, JobPreMessage, MessageSchemaType},
+    },
+    shinkai_utils::{shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key},
+};
+use shinkai_node::agent::job_manager::JobManager;
+use shinkai_node::agent::queue::job_queue_manager::{JobForProcessing, JobQueueManager};
+use shinkai_node::db::ShinkaiDB;
+use std::collections::HashSet;
+use std::pin::Pin;
+use std::result::Result::Ok;
+use std::time::{Duration, Instant};
+use std::{collections::HashMap, error::Error, sync::Arc};
+use tokio::sync::{mpsc, Mutex, Semaphore};
+use x25519_dalek::{PublicKey as EncryptionPublicKey, StaticSecret as EncryptionStaticKey};
+
+mod utils;
+
+fn generate_message_with_text(
+    content: String,
+    my_encryption_secret_key: EncryptionStaticKey,
+    my_signature_secret_key: SignatureStaticKey,
+    receiver_public_key: EncryptionPublicKey,
+    recipient_subidentity_name: String,
+    origin_destination_identity_name: String,
+    timestamp: String,
+) -> ShinkaiMessage {
+    let inbox_name = InboxName::get_regular_inbox_name_from_params(
+        origin_destination_identity_name.clone().to_string(),
+        "".to_string(),
+        origin_destination_identity_name.clone().to_string(),
+        recipient_subidentity_name.clone().to_string(),
+        false,
+    )
+    .unwrap();
+
+    let inbox_name_value = match inbox_name {
+        InboxName::RegularInbox { value, .. } | InboxName::JobInbox { value, .. } => value,
+    };
+
+    let message = ShinkaiMessageBuilder::new(my_encryption_secret_key, my_signature_secret_key, receiver_public_key)
+        .message_raw_content(content.to_string())
+        .body_encryption(EncryptionMethod::None)
+        .message_schema_type(MessageSchemaType::TextContent)
+        .internal_metadata_with_inbox(
+            "".to_string(),
+            recipient_subidentity_name.clone().to_string(),
+            inbox_name_value,
+            EncryptionMethod::None,
+        )
+        .external_metadata_with_schedule(
+            origin_destination_identity_name.clone().to_string(),
+            origin_destination_identity_name.clone().to_string(),
+            timestamp,
+        )
+        .build()
+        .unwrap();
+    message
+}
+
+#[tokio::test]
+async fn test_process_job_queue_concurrency() {
+    utils::db_handlers::setup();
+
+    let NUM_THREADS = 8;
+    let db_path = "db_tests/";
+    let db = Arc::new(Mutex::new(ShinkaiDB::new(db_path).unwrap()));
+    let (node_identity_sk, _) = unsafe_deterministic_signature_keypair(0);
+
+    // Mock job processing function
+    let mock_processing_fn = |job: JobForProcessing, db: Arc<Mutex<ShinkaiDB>>, _: SignatureStaticKey| {
+        Box::pin(async move {
+            shinkai_log(
+                ShinkaiLogOption::Tests,
+                ShinkaiLogLevel::Debug,
+                format!("Processing job: {:?}", job.job_message.content).as_str(),
+            );
+            tokio::time::sleep(Duration::from_millis(200)).await;
+
+            let (node1_identity_sk, _) = unsafe_deterministic_signature_keypair(0);
+            let (node1_encryption_sk, node1_encryption_pk) = unsafe_deterministic_encryption_keypair(0);
+
+            // Create a message
+            let message = generate_message_with_text(
+                job.job_message.content,
+                node1_encryption_sk.clone(),
+                clone_signature_secret_key(&node1_identity_sk),
+                node1_encryption_pk,
+                "".to_string(),
+                "@@node1.shinkai".to_string(),
+                "2023-07-02T20:53:34.812Z".to_string(),
+            );
+
+            // Write the message to an inbox with the job name
+            let mut db = db.lock().await;
+            let _ = db.unsafe_insert_inbox_message(&message.clone());
+
+            Ok("Success".to_string())
+        })
+    };
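Note the shape of `mock_processing_fn`: a closure returning `Box::pin(async move { ... })`, i.e. a pinned, boxed future, which is what lets `process_job_queue` accept arbitrary processing logic as a plain function argument. A pair of type aliases roughly capturing that shape (the names and the error parameter are illustrative, not taken from the crate; the real code presumably uses AgentError):

    use std::{future::Future, pin::Pin, sync::Arc};
    use tokio::sync::Mutex;

    // A job processor: takes the job, a shared DB handle, and a signing key,
    // and returns a boxed future with the processing result.
    type JobFuture<E> = Pin<Box<dyn Future<Output = Result<String, E>> + Send>>;
    type ProcessJobFn<Job, Db, Key, E> = dyn Fn(Job, Arc<Mutex<Db>>, Key) -> JobFuture<E> + Send + Sync;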
+    let mut job_queue = JobQueueManager::<JobForProcessing>::new(Arc::clone(&db)).await.unwrap();
+    let job_queue_manager = Arc::new(Mutex::new(job_queue.clone()));
+
+    // Start processing the queue with concurrency
+    let job_queue_handler = JobManager::process_job_queue(
+        job_queue_manager,
+        db.clone(),
+        NUM_THREADS,
+        clone_signature_secret_key(&node_identity_sk),
+        move |job, db, identity_sk| mock_processing_fn(job, db, identity_sk),
+    )
+    .await;
+
+    // Enqueue multiple jobs
+    for i in 0..8 {
+        let job = JobForProcessing::new(
+            JobMessage {
+                job_id: format!("job_id::{}::false", i).to_string(),
+                content: format!("my content {}", i).to_string(),
+                files_inbox: "".to_string(),
+            },
+            ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
+        );
+        job_queue
+            .push(format!("job_id::{}::false", i).as_str(), job)
+            .await
+            .unwrap();
+    }
+
+    // Spawn a checker task: by the 400 ms mark, all eight 200 ms jobs on
+    // distinct job ids should have written their inbox messages
+    let long_running_task = tokio::spawn(async move {
+        tokio::time::sleep(Duration::from_millis(400)).await;
+
+        let last_messages_all = db.lock().await.get_last_messages_from_all(10).unwrap();
+        assert_eq!(last_messages_all.len(), 8);
+    });
+
+    // Set a timeout for both tasks to complete
+    let timeout_duration = Duration::from_millis(400);
+    let job_queue_handler_result = tokio::time::timeout(timeout_duration, job_queue_handler).await;
+    let long_running_task_result = tokio::time::timeout(timeout_duration, long_running_task).await;
+
+    // Check the results of the tasks
+    match job_queue_handler_result {
+        Ok(_) => (),
+        Err(_) => (),
+    }
+
+    match long_running_task_result {
+        Ok(_) => (),
+        Err(_) => (),
+    }
+}
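Timing intuition for the two assertions: each mock job sleeps 200 ms. With 8 worker slots and 8 distinct job ids, every job runs in parallel, so all 8 inbox messages exist by the 400 ms checkpoint. The sequential test below reuses one job id for all 8 messages, so only one 200 ms job can complete before its 300 ms checkpoint, hence the expected count of 1 (assuming the queue serializes work per job id, which is exactly what the test exercises).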
ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(), + ); + job_queue + .push("job_id::123::false", job) + .await + .unwrap(); + } + + // Create a new task that lasts at least 2 seconds + let long_running_task = tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(300)).await; + + let last_messages_all = db.lock().await.get_last_messages_from_all(10).unwrap(); + assert_eq!(last_messages_all.len(), 1); + }); + + // Set a timeout for both tasks to complete + let timeout_duration = Duration::from_millis(400); + let job_queue_handler_result = tokio::time::timeout(timeout_duration, job_queue_handler).await; + let long_running_task_result = tokio::time::timeout(timeout_duration, long_running_task).await; + + // Check the results of the tasks + match job_queue_handler_result { + Ok(_) => (), + Err(_) => (), + } + + match long_running_task_result { + Ok(_) => (), + Err(_) => (), + } +} diff --git a/tests/node_integration_tests.rs b/tests/node_integration_tests.rs index 425c49fd1..272838d49 100644 --- a/tests/node_integration_tests.rs +++ b/tests/node_integration_tests.rs @@ -328,7 +328,7 @@ fn subidentity_registration() { ); { - println!("Checking that the message has the right sender {:?}", message_to_check); + eprintln!("Checking that the message has the right sender {:?}", message_to_check); assert_eq!( message_to_check.get_sender_subidentity().unwrap(), node2_profile_name.to_string(),