Merge pull request #268 from dcSpark/rob/inference-cleanup
Further Inferencing Robustness Improvements
nicarq authored Mar 7, 2024
2 parents f377090 + 08d7ee5 commit c38272d
Showing 8 changed files with 240 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

3 changes: 1 addition & 2 deletions Cargo.toml
@@ -1,8 +1,7 @@
[package]
name = "shinkai_node"
version = "0.5.2"
version = "0.5.3"
edition = "2018"
authors = ["Nico Arqueros <[email protected]>"]

[features]
default = ["telemetry"]
13 changes: 1 addition & 12 deletions src/agent/agent.rs
@@ -1,8 +1,6 @@
use super::execution::job_prompts::{Prompt, SubPromptType};
use super::providers::LLMProvider;
use super::{error::AgentError, job_manager::JobManager};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use reqwest::Client;
use serde_json::{Map, Value as JsonValue};
use shinkai_message_primitives::{
@@ -86,19 +84,15 @@ impl Agent {
let mut response = self.internal_inference_matching_model(prompt.clone()).await;
let mut attempts = 0;

let mut rng = StdRng::from_entropy(); // This uses the system's source of entropy to seed the RNG
let random_number: i32 = rng.gen();

let mut new_prompt = prompt.clone();
while let Err(err) = &response {
if attempts >= 3 {
break;
}
attempts += 1;

// If serde failed parsing the json string, then use advanced rertrying
// If serde failed parsing the json string, then use advanced retrying
if let AgentError::FailedSerdeParsingJSONString(response_json, serde_error) = err {
println!("101 - {} - Failed parsing json of: {}", random_number, response_json);
new_prompt.add_content(response_json.to_string(), SubPromptType::Assistant, 100);
new_prompt.add_content(
format!(
@@ -118,11 +112,6 @@ impl Agent {

let cleaned_json = JobManager::convert_inference_response_to_internal_strings(response?);

println!(
"105 - {} - Succeeded parsing json of: {}",
random_number,
cleaned_json.to_string()
);
Ok(cleaned_json)
}

182 changes: 145 additions & 37 deletions src/agent/execution/chains/qa_inference_chain.rs
@@ -6,6 +6,7 @@ use crate::agent::job_manager::JobManager;
use crate::db::ShinkaiDB;
use crate::vector_fs::vector_fs::VectorFS;
use async_recursion::async_recursion;
use serde_json::Value as JsonValue;
use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogLevel, ShinkaiLogOption};
@@ -54,38 +55,75 @@ impl JobManager {
true,
)
.await?;
// Text from the first node, which is the summary of the most similar VR
let summary_node_text = ret_nodes
.get(0)
.and_then(|node| node.node.get_text_content().ok())
.map(|text| text.to_string());

// Use the default prompt if not reached final iteration count, else use final prompt
let filled_prompt = if iteration_count < max_iterations && !full_job.scope.is_empty() {
let is_not_final = iteration_count < max_iterations && !full_job.scope.is_empty();
let filled_prompt = if is_not_final {
JobPromptGenerator::response_prompt_with_vector_search(
job_task.clone(),
ret_nodes,
summary_text,
summary_text.clone(),
Some(query_text),
Some(full_job.step_history.clone()),
)
} else {
JobPromptGenerator::response_prompt_with_vector_search_final(
job_task.clone(),
ret_nodes,
summary_text,
summary_text.clone(),
Some(full_job.step_history.clone()),
)
};

// Inference the agent's LLM with the prompt. If it has an answer, the chain
// is finished and so just return the answer response as a cleaned String
let response_json = JobManager::inference_agent(agent.clone(), filled_prompt.clone()).await?;
if let Ok(answer_str) = JobManager::direct_extract_key_inference_json_response(response_json.clone(), "answer")
{
// Inference the agent's LLM with the prompt
let response = JobManager::inference_agent(agent.clone(), filled_prompt.clone()).await;
// Check if it failed to produce a proper json object at all, and if so go through more advanced retry logic
if response.is_err() {
return no_json_object_retry_logic(
response,
db,
vector_fs,
full_job,
job_task,
agent,
execution_context,
generator,
user_profile,
summary_text,
summary_node_text, // This needs to be defined or passed appropriately
iteration_count,
max_iterations,
)
.await;
}

// Extract the JSON from the inference response Result and proceed forward
let response_json = response?;
let answer = JobManager::direct_extract_key_inference_json_response(response_json.clone(), "answer");

// If it has an answer, the chain is finished and so just return the answer response as a cleaned String
if let Ok(answer_str) = answer {
let cleaned_answer =
ParsingHelper::flatten_to_content_if_json(&ParsingHelper::ending_stripper(&answer_str));
// println!("QA Chain Final Answer: {:?}", cleaned_answer);
return Ok(cleaned_answer);
}
// If iteration_count is > max_iterations and we still don't have an answer, return an error
else if iteration_count > max_iterations {
return Err(AgentError::InferenceRecursionLimitReached(job_task.clone()));
// If it errored and past max iterations, try to use the summary from the previous iteration, or return error
else if let Err(_) = answer {
if iteration_count > max_iterations {
if let Some(summary_str) = &summary_text {
let cleaned_answer = ParsingHelper::flatten_to_content_if_json(&ParsingHelper::ending_stripper(
summary_str.as_str(),
));
return Ok(cleaned_answer);
} else {
return Err(AgentError::InferenceRecursionLimitReached(job_task.clone()));
}
}
}

// If not an answer, then the LLM must respond with a search/summary, so we parse them
@@ -94,50 +132,56 @@
agent.clone(),
response_json.clone(),
filled_prompt.clone(),
vec!["search".to_string(), "search_term".to_string()],
2,
vec!["summary".to_string(), "answer".to_string(), "text".to_string()],
3,
)
.await
{
Ok((search_str, new_resp_json)) => {
let summary_str = match &JobManager::advanced_extract_key_from_inference_response(
Ok((summary_str, new_resp_json)) => {
let new_search_text = match &JobManager::advanced_extract_key_from_inference_response(
agent.clone(),
new_resp_json.clone(),
filled_prompt.clone(),
vec!["summary".to_string(), "answer".to_string()],
4,
vec!["search".to_string(), "lookup".to_string()],
2,
)
.await
{
Ok((summary, _)) => Some(summary.to_string()),
Ok((search_text, _)) => Some(search_text.to_string()),
Err(_) => None,
};
(search_str.to_string(), summary_str)
// Just use summary string as search text if LLM didn't provide one to decrease # of inferences
(
new_search_text.unwrap_or(summary_str.to_string()),
summary_str.to_string(),
)
}
Err(_) => {
println!("Failed qa inference chain");
return Err(AgentError::InferenceJSONResponseMissingField("search".to_string()));
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Failed qa inference chain: Missing Field {}", "summary"),
);
return Err(AgentError::InferenceJSONResponseMissingField("summary".to_string()));
}
};

// If the new search text is the same as the previous one, prompt the agent for a new search term
let mut new_search_text = new_search_text.clone();
if Some(new_search_text.clone()) == search_text && !full_job.scope.is_empty() {
let retry_prompt = JobPromptGenerator::retry_new_search_term_prompt(
new_search_text.clone(),
summary.clone().unwrap_or_default(),
);
let response_json = JobManager::inference_agent(agent.clone(), retry_prompt).await?;
match JobManager::direct_extract_key_inference_json_response(response_json, "search") {
Ok(search_str) => {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Info,
&format!("QA Chain New Search Retry Term: {:?}", search_str),
);
new_search_text = search_str;
let retry_prompt =
JobPromptGenerator::retry_new_search_term_prompt(new_search_text.clone(), summary.clone());
let response = JobManager::inference_agent(agent.clone(), retry_prompt).await;
if let Ok(response_json) = response {
match JobManager::direct_extract_key_inference_json_response(response_json, "search") {
Ok(search_str) => {
new_search_text = search_str;
}
// If extracting fails, use summary to make the new search text likely different compared to last iteration
Err(_) => new_search_text = summary.clone(),
}
Err(_) => {}
} else {
new_search_text = summary.clone();
}
}

Expand All @@ -152,10 +196,74 @@ impl JobManager {
generator,
user_profile,
Some(new_search_text),
summary,
Some(summary.to_string()),
iteration_count + 1,
max_iterations,
)
.await
}
}

async fn no_json_object_retry_logic(
response: Result<JsonValue, AgentError>,
db: Arc<Mutex<ShinkaiDB>>,
vector_fs: Arc<Mutex<VectorFS>>,
full_job: Job,
job_task: String,
agent: SerializedAgent,
execution_context: HashMap<String, String>,
generator: &dyn EmbeddingGenerator,
user_profile: ShinkaiName,
summary_text: Option<String>,
summary_node_text: Option<String>,
iteration_count: u64,
max_iterations: u64,
) -> Result<String, AgentError> {
if let Err(e) = &response {
// If still more iterations left, then recurse to try one more time, using summary as the new search text to likely get different LLM output
if iteration_count < max_iterations {
return JobManager::start_qa_inference_chain(
db,
vector_fs,
full_job,
job_task.to_string(),
agent,
execution_context,
generator,
user_profile,
summary_text.clone(),
summary_text,
iteration_count + 1,
max_iterations,
)
.await;
}
// Else if we're past the max iterations, return either last valid summary from previous iterations or VR summary
else {
shinkai_log(
ShinkaiLogOption::JobExecution,
ShinkaiLogLevel::Error,
&format!("Qa inference chain failure due to no parsable JSON produced: {}\nUsing summary backup to respond to user.", e),
);
let mut summary_answer = String::new();
// Try from previous iteration
if let Some(summary_str) = &summary_text {
summary_answer = summary_str.to_string()
}
// Else use the VR summary. We create _temp_res to have `response?` resolve to pushing the error properly
else {
let mut _temp_resp = JsonValue::Null;
match summary_node_text {
Some(text) => summary_answer = text.to_string(),
None => _temp_resp = response?,
}
}

// Return the cleaned summary
let cleaned_answer =
ParsingHelper::flatten_to_content_if_json(&ParsingHelper::ending_stripper(summary_answer.as_str()));
return Ok(cleaned_answer);
}
}
Err(AgentError::InferenceFailed)
}