diff --git a/Cargo.lock b/Cargo.lock
index 485bd9a3f..b55163969 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4707,7 +4707,7 @@ dependencies = [
 
 [[package]]
 name = "shinkai_node"
-version = "0.5.0"
+version = "0.5.3"
 dependencies = [
  "aes-gcm",
  "anyhow",
diff --git a/Cargo.toml b/Cargo.toml
index 7666edcb8..9f43ff106 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,8 +1,7 @@
 [package]
 name = "shinkai_node"
-version = "0.5.2"
+version = "0.5.3"
 edition = "2018"
-authors = ["Nico Arqueros "]
 
 [features]
 default = ["telemetry"]
diff --git a/src/agent/agent.rs b/src/agent/agent.rs
index b8dc3cbda..eba1198c5 100644
--- a/src/agent/agent.rs
+++ b/src/agent/agent.rs
@@ -1,8 +1,6 @@
 use super::execution::job_prompts::{Prompt, SubPromptType};
 use super::providers::LLMProvider;
 use super::{error::AgentError, job_manager::JobManager};
-use rand::rngs::StdRng;
-use rand::{Rng, SeedableRng};
 use reqwest::Client;
 use serde_json::{Map, Value as JsonValue};
 use shinkai_message_primitives::{
@@ -86,9 +84,6 @@ impl Agent {
         let mut response = self.internal_inference_matching_model(prompt.clone()).await;
         let mut attempts = 0;
 
-        let mut rng = StdRng::from_entropy(); // This uses the system's source of entropy to seed the RNG
-        let random_number: i32 = rng.gen();
-
         let mut new_prompt = prompt.clone();
         while let Err(err) = &response {
             if attempts >= 3 {
@@ -96,9 +91,8 @@
             }
             attempts += 1;
 
-            // If serde failed parsing the json string, then use advanced rertrying
+            // If serde failed parsing the json string, then use advanced retrying
             if let AgentError::FailedSerdeParsingJSONString(response_json, serde_error) = err {
-                println!("101 - {} - Failed parsing json of: {}", random_number, response_json);
                 new_prompt.add_content(response_json.to_string(), SubPromptType::Assistant, 100);
                 new_prompt.add_content(
                     format!(
@@ -118,11 +112,6 @@
         let cleaned_json = JobManager::convert_inference_response_to_internal_strings(response?);
-        println!(
-            "105 - {} - Succeeded parsing json of: {}",
-            random_number,
-            cleaned_json.to_string()
-        );
         Ok(cleaned_json)
     }
diff --git a/src/agent/execution/chains/qa_inference_chain.rs b/src/agent/execution/chains/qa_inference_chain.rs
index 00a4eaa90..2bae06572 100644
--- a/src/agent/execution/chains/qa_inference_chain.rs
+++ b/src/agent/execution/chains/qa_inference_chain.rs
@@ -6,6 +6,7 @@ use crate::agent::job_manager::JobManager;
 use crate::db::ShinkaiDB;
 use crate::vector_fs::vector_fs::VectorFS;
 use async_recursion::async_recursion;
+use serde_json::Value as JsonValue;
 use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent;
 use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
 use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
@@ -54,13 +55,19 @@ impl JobManager {
             true,
         )
         .await?;
+        // Text from the first node, which is the summary of the most similar VR
+        let summary_node_text = ret_nodes
+            .get(0)
+            .and_then(|node| node.node.get_text_content().ok())
+            .map(|text| text.to_string());
 
         // Use the default prompt if not reached final iteration count, else use final prompt
-        let filled_prompt = if iteration_count < max_iterations && !full_job.scope.is_empty() {
+        let is_not_final = iteration_count < max_iterations && !full_job.scope.is_empty();
+        let filled_prompt = if is_not_final {
             JobPromptGenerator::response_prompt_with_vector_search(
                 job_task.clone(),
                 ret_nodes,
-                summary_text,
+                summary_text.clone(),
                 Some(query_text),
                 Some(full_job.step_history.clone()),
             )
@@ -68,24 +75,55 @@ impl JobManager {
             JobPromptGenerator::response_prompt_with_vector_search_final(
                 job_task.clone(),
                 ret_nodes,
-                summary_text,
+                summary_text.clone(),
                 Some(full_job.step_history.clone()),
             )
         };
 
-        // Inference the agent's LLM with the prompt. If it has an answer, the chain
-        // is finished and so just return the answer response as a cleaned String
-        let response_json = JobManager::inference_agent(agent.clone(), filled_prompt.clone()).await?;
-        if let Ok(answer_str) = JobManager::direct_extract_key_inference_json_response(response_json.clone(), "answer")
-        {
+        // Inference the agent's LLM with the prompt
+        let response = JobManager::inference_agent(agent.clone(), filled_prompt.clone()).await;
+        // Check if it failed to produce a proper json object at all, and if so go through more advanced retry logic
+        if response.is_err() {
+            return no_json_object_retry_logic(
+                response,
+                db,
+                vector_fs,
+                full_job,
+                job_task,
+                agent,
+                execution_context,
+                generator,
+                user_profile,
+                summary_text,
+                summary_node_text, // This needs to be defined or passed appropriately
+                iteration_count,
+                max_iterations,
+            )
+            .await;
+        }
+
+        // Extract the JSON from the inference response Result and proceed forward
+        let response_json = response?;
+        let answer = JobManager::direct_extract_key_inference_json_response(response_json.clone(), "answer");
+
+        // If it has an answer, the chain is finished and so just return the answer response as a cleaned String
+        if let Ok(answer_str) = answer {
             let cleaned_answer = ParsingHelper::flatten_to_content_if_json(&ParsingHelper::ending_stripper(&answer_str));
-            // println!("QA Chain Final Answer: {:?}", cleaned_answer);
             return Ok(cleaned_answer);
         }
-        // If iteration_count is > max_iterations and we still don't have an answer, return an error
-        else if iteration_count > max_iterations {
-            return Err(AgentError::InferenceRecursionLimitReached(job_task.clone()));
+        // If it errored and past max iterations, try to use the summary from the previous iteration, or return error
+        else if let Err(_) = answer {
+            if iteration_count > max_iterations {
+                if let Some(summary_str) = &summary_text {
+                    let cleaned_answer = ParsingHelper::flatten_to_content_if_json(&ParsingHelper::ending_stripper(
+                        summary_str.as_str(),
+                    ));
+                    return Ok(cleaned_answer);
+                } else {
+                    return Err(AgentError::InferenceRecursionLimitReached(job_task.clone()));
+                }
+            }
         }
 
         // If not an answer, then the LLM must respond with a search/summary, so we parse them
@@ -94,50 +132,56 @@ impl JobManager {
                 agent.clone(),
                 response_json.clone(),
                 filled_prompt.clone(),
-                vec!["search".to_string(), "search_term".to_string()],
-                2,
+                vec!["summary".to_string(), "answer".to_string(), "text".to_string()],
+                3,
             )
             .await
             {
-                Ok((search_str, new_resp_json)) => {
-                    let summary_str = match &JobManager::advanced_extract_key_from_inference_response(
+                Ok((summary_str, new_resp_json)) => {
+                    let new_search_text = match &JobManager::advanced_extract_key_from_inference_response(
                         agent.clone(),
                         new_resp_json.clone(),
                         filled_prompt.clone(),
-                        vec!["summary".to_string(), "answer".to_string()],
-                        4,
+                        vec!["search".to_string(), "lookup".to_string()],
+                        2,
                     )
                     .await
                     {
-                        Ok((summary, _)) => Some(summary.to_string()),
+                        Ok((search_text, _)) => Some(search_text.to_string()),
                         Err(_) => None,
                     };
-                    (search_str.to_string(), summary_str)
+                    // Just use summary string as search text if LLM didn't provide one to decrease # of inferences
+                    (
+                        new_search_text.unwrap_or(summary_str.to_string()),
+                        summary_str.to_string(),
+                    )
                 }
                 Err(_) => {
println!("Failed qa inference chain"); - return Err(AgentError::InferenceJSONResponseMissingField("search".to_string())); + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Error, + &format!("Failed qa inference chain: Missing Field {}", "summary"), + ); + return Err(AgentError::InferenceJSONResponseMissingField("summary".to_string())); } }; // If the new search text is the same as the previous one, prompt the agent for a new search term let mut new_search_text = new_search_text.clone(); if Some(new_search_text.clone()) == search_text && !full_job.scope.is_empty() { - let retry_prompt = JobPromptGenerator::retry_new_search_term_prompt( - new_search_text.clone(), - summary.clone().unwrap_or_default(), - ); - let response_json = JobManager::inference_agent(agent.clone(), retry_prompt).await?; - match JobManager::direct_extract_key_inference_json_response(response_json, "search") { - Ok(search_str) => { - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Info, - &format!("QA Chain New Search Retry Term: {:?}", search_str), - ); - new_search_text = search_str; + let retry_prompt = + JobPromptGenerator::retry_new_search_term_prompt(new_search_text.clone(), summary.clone()); + let response = JobManager::inference_agent(agent.clone(), retry_prompt).await; + if let Ok(response_json) = response { + match JobManager::direct_extract_key_inference_json_response(response_json, "search") { + Ok(search_str) => { + new_search_text = search_str; + } + // If extracting fails, use summary to make the new search text likely different compared to last iteration + Err(_) => new_search_text = summary.clone(), } - Err(_) => {} + } else { + new_search_text = summary.clone(); } } @@ -152,10 +196,74 @@ impl JobManager { generator, user_profile, Some(new_search_text), - summary, + Some(summary.to_string()), iteration_count + 1, max_iterations, ) .await } } + +async fn no_json_object_retry_logic( + response: Result, + db: Arc>, + vector_fs: Arc>, + full_job: Job, + job_task: String, + agent: SerializedAgent, + execution_context: HashMap, + generator: &dyn EmbeddingGenerator, + user_profile: ShinkaiName, + summary_text: Option, + summary_node_text: Option, + iteration_count: u64, + max_iterations: u64, +) -> Result { + if let Err(e) = &response { + // If still more iterations left, then recurse to try one more time, using summary as the new search text to likely get different LLM output + if iteration_count < max_iterations { + return JobManager::start_qa_inference_chain( + db, + vector_fs, + full_job, + job_task.to_string(), + agent, + execution_context, + generator, + user_profile, + summary_text.clone(), + summary_text, + iteration_count + 1, + max_iterations, + ) + .await; + } + // Else if we're past the max iterations, return either last valid summary from previous iterations or VR summary + else { + shinkai_log( + ShinkaiLogOption::JobExecution, + ShinkaiLogLevel::Error, + &format!("Qa inference chain failure due to no parsable JSON produced: {}\nUsing summary backup to respond to user.", e), + ); + let mut summary_answer = String::new(); + // Try from previous iteration + if let Some(summary_str) = &summary_text { + summary_answer = summary_str.to_string() + } + // Else use the VR summary. 
We create _temp_res to have `response?` resolve to pushing the error properly + else { + let mut _temp_resp = JsonValue::Null; + match summary_node_text { + Some(text) => summary_answer = text.to_string(), + None => _temp_resp = response?, + } + } + + // Return the cleaned summary + let cleaned_answer = + ParsingHelper::flatten_to_content_if_json(&ParsingHelper::ending_stripper(summary_answer.as_str())); + return Ok(cleaned_answer); + } + } + Err(AgentError::InferenceFailed) +} diff --git a/src/agent/execution/job_execution_helpers.rs b/src/agent/execution/job_execution_helpers.rs index 54733d625..3d81b894e 100644 --- a/src/agent/execution/job_execution_helpers.rs +++ b/src/agent/execution/job_execution_helpers.rs @@ -42,7 +42,7 @@ impl JobManager { let mut current_response_json = response_json; for _ in 0..retry_attempts { for key in &potential_keys { - let new_response_json = JobManager::json_not_found_retry( + let new_response_json = internal_json_not_found_retry( agent.clone(), current_response_json.to_string(), filled_prompt.clone(), @@ -101,26 +101,14 @@ impl JobManager { }) .await; - let response = match task_response { - Ok(res) => res, - Err(e) => { - eprintln!("Task panicked with error: {:?}", e); - return Err(AgentError::InferenceFailed); - } - }; - + let response = task_response?; shinkai_log( ShinkaiLogOption::JobExecution, ShinkaiLogLevel::Debug, format!("inference_agent> response: {:?}", response).as_str(), ); - // Validates that the response is a proper JSON object, else inferences again to get the - // LLM to parse the previous response into proper JSON - let json_resp = - JobManager::_extract_json_value_from_inference_result(response, agent.clone(), filled_prompt).await?; - let cleaned_json = JobManager::convert_inference_response_to_internal_strings(json_resp); - Ok(cleaned_json) + response } /// Internal method that attempts to extract the JsonValue out of the LLM's response. If it is not proper JSON @@ -146,7 +134,7 @@ impl JobManager { } // - match JobManager::json_not_found_retry(agent.clone(), text.clone(), filled_prompt, None).await { + match internal_json_not_found_retry(agent.clone(), text.clone(), filled_prompt, None).await { Ok(json) => Ok(json), Err(e) => Err(e), } @@ -155,35 +143,6 @@ impl JobManager { } } - /// Inferences the LLM again asking it to take its previous answer and make sure it responds with a proper JSON object - /// that we can parse. json_key_to_correct allows providing a specific key that the LLM should make sure to correct. - async fn json_not_found_retry( - agent: SerializedAgent, - invalid_json_answer: String, - original_prompt: Prompt, - json_key_to_correct: Option, - ) -> Result { - let response = tokio::spawn(async move { - let agent = Agent::from_serialized_agent(agent); - let prompt = JobPromptGenerator::basic_json_retry_response_prompt( - invalid_json_answer, - original_prompt, - json_key_to_correct, - ); - agent.inference(prompt).await - }) - .await; - let response = match response { - Ok(res) => res?, - Err(e) => { - eprintln!("Task panicked with error: {:?}", e); - return Err(AgentError::InferenceFailed); - } - }; - - Ok(response) - } - /// Fetches boilerplate/relevant data required for a job to process a step pub async fn fetch_relevant_job_data( job_id: &str, @@ -293,3 +252,32 @@ fn to_dash_case(s: &str) -> String { }) .collect() } + +/// Inferences the LLM again asking it to take its previous answer and make sure it responds with a proper JSON object +/// that we can parse. 
+async fn internal_json_not_found_retry(
+    agent: SerializedAgent,
+    invalid_json_answer: String,
+    original_prompt: Prompt,
+    json_key_to_correct: Option<String>,
+) -> Result<JsonValue, AgentError> {
+    let response = tokio::spawn(async move {
+        let agent = Agent::from_serialized_agent(agent);
+        let prompt = JobPromptGenerator::basic_json_retry_response_prompt(
+            invalid_json_answer,
+            original_prompt,
+            json_key_to_correct,
+        );
+        agent.inference(prompt).await
+    })
+    .await;
+    let response = match response {
+        Ok(res) => res?,
+        Err(e) => {
+            eprintln!("Task panicked with error: {:?}", e);
+            return Err(AgentError::InferenceFailed);
+        }
+    };
+
+    Ok(response)
+}
diff --git a/src/agent/execution/job_vector_search.rs b/src/agent/execution/job_vector_search.rs
index d4d62e1a8..79ad275bb 100644
--- a/src/agent/execution/job_vector_search.rs
+++ b/src/agent/execution/job_vector_search.rs
@@ -41,7 +41,7 @@ impl JobManager {
     }
 
     /// Perform a vector search on all local & VectorFS-held Vector Resources specified in the JobScope.
-    /// If include_description is true then adds the description of the Vector Resource as an auto-included
+    /// If include_description is true then adds the description of the highest scored Vector Resource as an auto-included
     /// RetrievedNode at the front of the returned list.
     pub async fn job_scope_vector_search(
         db: Arc<Mutex<ShinkaiDB>>,
diff --git a/src/agent/file_parsing.rs b/src/agent/file_parsing.rs
index 3375c5f4e..fd93c65f4 100644
--- a/src/agent/file_parsing.rs
+++ b/src/agent/file_parsing.rs
@@ -20,24 +20,52 @@ use std::io::Cursor;
 
 impl JobManager {
     /// Given a list of UnstructuredElements generates a description using the Agent's LLM
+    // TODO: the 2000 should be dynamic depending on the agent model capabilities
     pub async fn generate_description(
         elements: &Vec<UnstructuredElement>,
         agent: SerializedAgent,
+        max_node_size: u64,
     ) -> Result<String, AgentError> {
-        // TODO: the 2000 should be dynamic depending on the LLM model
         let prompt = ParsingHelper::process_elements_into_description_prompt(&elements, 2000);
-        let response_json = JobManager::inference_agent(agent.clone(), prompt.clone()).await?;
-        let (answer, _new_resp_json) = &JobManager::advanced_extract_key_from_inference_response(
-            agent.clone(),
-            response_json,
-            prompt,
-            vec!["summary".to_string(), "answer".to_string()],
-            1,
-        )
-        .await?;
-        let desc = Some(ParsingHelper::ending_stripper(answer));
-        eprintln!("LLM Generated File Description: {:?}", desc);
-        Ok(desc.unwrap_or_else(|| "".to_string()))
+
+        let mut extracted_answer: Option<String> = None;
+        for _ in 0..5 {
+            let response_json = match JobManager::inference_agent(agent.clone(), prompt.clone()).await {
+                Ok(json) => json,
+                Err(e) => {
+                    continue; // Continue to the next iteration on error
+                }
+            };
+            let (answer, _new_resp_json) = match JobManager::advanced_extract_key_from_inference_response(
+                agent.clone(),
+                response_json,
+                prompt.clone(),
+                vec!["summary".to_string(), "answer".to_string()],
+                1,
+            )
+            .await
+            {
+                Ok(result) => result,
+                Err(e) => {
+                    continue; // Continue to the next iteration on error
+                }
+            };
+            extracted_answer = Some(answer.clone());
+            break; // Exit the loop if successful
+        }
+
+        if let Some(answer) = extracted_answer {
+            let desc = ParsingHelper::ending_stripper(&answer);
+            Ok(desc)
+        } else {
+            eprintln!(
+                "Failed to generate VR description after multiple attempts. Defaulting to text from first N nodes."
+            );
+
+            let concat_text = ParsingHelper::concatenate_elements_up_to_max_size(&elements, max_node_size as usize);
+            let desc = ParsingHelper::ending_stripper(&concat_text);
+            Ok(desc)
+        }
     }
 
     /// Processes the list of files into VRKai structs ready to be used/saved/etc.
@@ -117,14 +145,9 @@ impl JobManager {
             ParsingHelper::parse_file_helper(file_buffer.clone(), name.clone(), unstructured_api).await?;
         let mut desc = String::new();
         if let Some(actual_agent) = agent {
-            desc = Self::generate_description(&elements, actual_agent).await?;
+            desc = Self::generate_description(&elements, actual_agent, max_node_size).await?;
         } else {
-            desc = elements
-                .iter()
-                .take_while(|e| desc.len() + e.text.len() <= max_node_size as usize)
-                .map(|e| e.text.as_str())
-                .collect::<Vec<&str>>()
-                .join(" ");
+            desc = ParsingHelper::concatenate_elements_up_to_max_size(&elements, max_node_size as usize);
         }
 
         ParsingHelper::parse_elements_into_resource(
@@ -148,6 +171,19 @@ impl ParsingHelper {
         UnstructuredParser::generate_data_hash(content)
     }
 
+    /// Concatenate elements text up to a maximum size.
+    pub fn concatenate_elements_up_to_max_size(elements: &[UnstructuredElement], max_size: usize) -> String {
+        let mut desc = String::new();
+        for e in elements {
+            if desc.len() + e.text.len() + 1 > max_size {
+                break; // Stop appending if adding the next element would exceed max_size
+            }
+            desc.push_str(&e.text);
+            desc.push('\n'); // Add a line break after each element's text
+        }
+        desc.trim_end().to_string() // Trim any trailing space before returning
+    }
+
     /// Processes the file buffer through Unstructured, our hierarchical structuring algo,
     /// generates all embeddings, and returns a finalized BaseVectorResource.
     /// Note: The file name must include the extension ie. `*.pdf`
diff --git a/src/agent/providers/mod.rs b/src/agent/providers/mod.rs
index ca0898785..a02333fde 100644
--- a/src/agent/providers/mod.rs
+++ b/src/agent/providers/mod.rs
@@ -25,6 +25,7 @@ pub trait LLMProvider {
     /// Given an input string, parses the largest JSON object that it finds. Largest allows us to skip over
     /// cases where LLMs repeat your schema or post other broken/smaller json objects for whatever reason.
     fn extract_largest_json_object(s: &str) -> Result<JsonValue, AgentError> {
+        println!("909 - Starting parsing of inference response: {} ", s);
         match internal_extract_json_string(s) {
             Ok(json_str) => match serde_json::from_str(&json_str) {
                 Ok(json_val) => Ok(json_val),
@@ -151,7 +152,7 @@ fn internal_extract_json_string(s: &str) -> Result<String, AgentError> {
     }
 
     // Return the longest JSON string
-    match json_strings.into_iter().max_by_key(|s| s.len()) {
+    match json_strings.into_iter().max_by_key(|jstr| jstr.len()) {
         Some(longest_json_string) => Ok(longest_json_string),
         None => Err(AgentError::FailedExtractingJSONObjectFromResponse(
            s.to_string() + " - No JSON strings found",