Commit
Further large reorganization of job execution
Showing 6 changed files with 317 additions and 300 deletions.
@@ -0,0 +1 @@
pub mod qa_inference_chain;
@@ -0,0 +1,101 @@
use crate::agent::agent::Agent;
use crate::agent::error::AgentError;
use crate::agent::execution::job_prompts::JobPromptGenerator;
use crate::agent::job::{Job, JobId, JobLike};
use crate::agent::job_manager::AgentManager;
use async_recursion::async_recursion;
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_vector_resources::embedding_generator::EmbeddingGenerator;
use std::result::Result::Ok;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

impl AgentManager {
    /// An inference chain for question-answer job tasks which vector searches the Vector Resources
    /// in the JobScope to find relevant content for the LLM to use at each step.
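    /// Recurses with a refined search query and a running summary until the LLM returns an
    /// answer, giving one final-prompt attempt at iteration 5 before erroring with
    /// `InferenceRecursionLimitReached`.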
    #[async_recursion]
    pub async fn process_qa_inference_chain(
        &self,
        full_job: Job,
        job_task: String,
        agent: Arc<Mutex<Agent>>,
        execution_context: HashMap<String, String>,
        generator: &dyn EmbeddingGenerator,
        user_profile: Option<ShinkaiName>,
        search_text: Option<String>,
        prev_search_text: Option<String>,
        summary_text: Option<String>,
        iteration_count: u64,
    ) -> Result<String, AgentError> {
        println!("process_qa_inference_chain> message: {:?}", job_task);

        // Use search_text if available (on recursion), otherwise use job_task to generate the query (on first iteration)
        let query_text = search_text.clone().unwrap_or(job_task.clone());
        let query = generator.generate_embedding_default(&query_text).unwrap();
        let ret_data_chunks = self
            .job_scope_vector_search(full_job.scope(), query, 20, &user_profile.clone().unwrap())
            .await?;

        // Use the default prompt if the final iteration count has not been reached, else use the final prompt
        let filled_prompt = if iteration_count < 5 {
            JobPromptGenerator::response_prompt_with_vector_search(
                job_task.clone(),
                ret_data_chunks,
                summary_text,
                prev_search_text,
            )
        } else {
            JobPromptGenerator::response_prompt_with_vector_search_final(
                job_task.clone(),
                ret_data_chunks,
                summary_text,
            )
        };

        // Run inference on the agent's LLM with the filled prompt
        let response_json = self.inference_agent(agent.clone(), filled_prompt).await?;

        // If it has an answer, the chain is finished and so just return the answer response as a String
        if let Some(answer) = response_json.get("answer") {
            let answer_str = answer
                .as_str()
                .ok_or_else(|| AgentError::InferenceJSONResponseMissingField("answer".to_string()))?;
            return Ok(answer_str.to_string());
        }
        // If iteration_count is 5 and we still don't have an answer, return an error
        else if iteration_count >= 5 {
            return Err(AgentError::InferenceRecursionLimitReached(job_task.clone()));
        }

        // If not an answer, then the LLM must respond with a search/summary, so we parse them
        // to use for the next recursive call
        let (new_search_text, summary) = match response_json.get("search") {
            Some(search) => {
                let search_str = search
                    .as_str()
                    .ok_or_else(|| AgentError::InferenceJSONResponseMissingField("search".to_string()))?;
                let summary_str = response_json
                    .get("summary")
                    .and_then(|s| s.as_str())
                    .map(|s| s.to_string());
                (search_str, summary_str)
            }
            None => return Err(AgentError::InferenceJSONResponseMissingField("search".to_string())),
        };

        // Recurse with the new search/summary text and increment iteration_count
        self.process_qa_inference_chain(
            full_job,
            job_task.to_string(),
            agent,
            execution_context,
            generator,
            user_profile,
            Some(new_search_text.to_string()),
            search_text,
            summary,
            iteration_count + 1,
        )
        .await
    }
}
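
For context, a minimal sketch of how a caller might kick off this chain on the first iteration. The `manager`, `job`, `user_message`, `agent`, `generator`, and `profile` bindings are hypothetical stand-ins for values the surrounding job-execution code would already hold; the `None`/`0` arguments are the natural starting values implied by the recursion above:

    // Hypothetical first call into the chain: no prior search text or summary yet,
    // and iteration_count starts at 0 so up to five search/inference rounds can run.
    let answer = manager
        .process_qa_inference_chain(
            job,                      // full_job: Job
            user_message.clone(),     // job_task: String
            agent.clone(),            // Arc<Mutex<Agent>>
            HashMap::new(),           // execution_context
            &generator,               // coerces to &dyn EmbeddingGenerator
            Some(profile.clone()),    // user_profile: Option<ShinkaiName>
            None,                     // search_text (none on the first pass)
            None,                     // prev_search_text
            None,                     // summary_text
            0,                        // iteration_count
        )
        .await?;

Per the parsing logic above, each LLM response is expected to carry either an "answer" key (which ends the chain) or a "search" key with an optional "summary"; on each recursive call the new search text becomes `search_text` while the previous one shifts into `prev_search_text`.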