Merge pull request #110 from dcSpark/nico/job_parallelization
Nico/job parallelization
robkorn authored Oct 16, 2023
2 parents 508d6bd + 24c8684 commit 15ee759
Showing 18 changed files with 1,036 additions and 448 deletions.
20 changes: 4 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default.

@@ -78,7 +78,7 @@ pub struct JobCreationInfo
     pub scope: JobScope,
 }

-#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
 pub struct JobMessage {
     // TODO: scope div modifications?
     pub job_id: String,
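
The only change in this file is deriving `Eq` alongside `PartialEq` on `JobMessage`. `Eq` is a marker trait asserting total equality, and it is what hash-based collections require of their key types. A minimal sketch of why that matters (the `JobKey` type below is hypothetical, not from this PR):

```rust
use std::collections::HashMap;

// Hypothetical stand-in: HashMap keys must implement Eq + Hash;
// PartialEq alone would not compile here.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct JobKey {
    job_id: String,
}

fn main() {
    let mut pending: HashMap<JobKey, Vec<String>> = HashMap::new();
    let key = JobKey { job_id: "job_1".to_string() };
    pending.entry(key.clone()).or_default().push("message".to_string());
    assert!(pending.contains_key(&key));
}
```
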
@@ -100,10 +100,13 @@ impl UnstructuredParser

         // Extract keywords from the elements
         let keywords = UnstructuredParser::extract_keywords(&text_groups, 50);
+        eprintln!("Keywords: {:?}", keywords);

         // Set the resource embedding, using the keywords + name + desc + source
         doc.update_resource_embedding(generator, keywords)?;

+        eprintln!("Processing doc: `{}`", doc.name());
+
         // Add each text group as either Vector Resource DataChunks,
         // or data-holding DataChunks depending on if each has any sub-groups
         for grouped_text in &text_groups {
@@ -151,6 +154,7 @@ impl UnstructuredParser
             }
         }

+        eprintln!("Finished processing doc: `{}`", doc.name());
         Ok(BaseVectorResource::Document(doc))
     }

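
The additions in this file are stderr progress traces (keywords, start, and finish) around document parsing and embedding, which make long-running parsing jobs observable without changing any return values. If the traces ever need to be quieted, one option (a possible follow-up, not something this PR does) is to gate them behind an environment variable:

```rust
// Sketch: the same traces, but only emitted when SHINKAI_DEBUG is set.
// `stage` and `doc_name` mirror the values printed in the diff above.
fn trace_doc(stage: &str, doc_name: &str) {
    if std::env::var("SHINKAI_DEBUG").is_ok() {
        eprintln!("{} doc: `{}`", stage, doc_name);
    }
}

fn main() {
    trace_doc("Processing", "example.pdf");
    trace_doc("Finished processing", "example.pdf");
}
```
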
9 changes: 0 additions & 9 deletions src/agent/agent.rs

@@ -17,8 +17,6 @@ use tokio::sync::{mpsc, Mutex};
 pub struct Agent {
     pub id: String,
     pub full_identity_name: ShinkaiName,
-    pub job_manager_sender: mpsc::Sender<(Vec<JobPreMessage>, String)>,
-    pub agent_receiver: Arc<Mutex<mpsc::Receiver<String>>>,
     pub client: Client,
     pub perform_locally: bool,        // Todo: Remove as not used anymore
     pub external_url: Option<String>, // external API URL
@@ -33,7 +31,6 @@ impl Agent {
     pub fn new(
         id: String,
         full_identity_name: ShinkaiName,
-        job_manager_sender: mpsc::Sender<(Vec<JobPreMessage>, String)>,
         perform_locally: bool,
         external_url: Option<String>,
         api_key: Option<String>,
@@ -43,13 +40,9 @@ impl Agent {
         allowed_message_senders: Vec<String>,
     ) -> Self {
         let client = Client::new();
-        let (_, agent_receiver) = mpsc::channel(1); // TODO: I think we can remove this altogether
-        let agent_receiver = Arc::new(Mutex::new(agent_receiver)); // wrap the receiver
         Self {
             id,
             full_identity_name,
-            job_manager_sender,
-            agent_receiver,
             client,
             perform_locally,
             external_url,
@@ -101,12 +94,10 @@ impl Agent {
 impl Agent {
     pub fn from_serialized_agent(
         serialized_agent: SerializedAgent,
-        sender: mpsc::Sender<(Vec<JobPreMessage>, String)>,
     ) -> Self {
         Self::new(
             serialized_agent.id,
             serialized_agent.full_identity_name,
-            sender,
             serialized_agent.perform_locally,
             serialized_agent.external_url,
             serialized_agent.api_key,
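
These deletions strip the `mpsc` channel plumbing out of `Agent`: an agent no longer holds a sender into the job manager or a placeholder receiver, so it is plain data plus an HTTP client and can be built from its serialized form alone. A usage sketch of the new shape (the `load_agents` helper is hypothetical):

```rust
use crate::agent::agent::Agent;
use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent;

// Hypothetical helper: constructing agents no longer requires threading an
// mpsc::Sender through every call site.
fn load_agents(serialized: Vec<SerializedAgent>) -> Vec<Agent> {
    serialized
        .into_iter()
        .map(Agent::from_serialized_agent) // was: from_serialized_agent(agent, sender)
        .collect()
}
```
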
2 changes: 2 additions & 0 deletions src/agent/error.rs

@@ -33,6 +33,7 @@ pub enum AgentError {
     TaskJoinError(String),
     InferenceRecursionLimitReached(String),
     TokenizationError(String),
+    JobDequeueFailed(String)
 }

 impl fmt::Display for AgentError {
@@ -78,6 +79,7 @@ impl fmt::Display for AgentError {
             AgentError::TaskJoinError(s) => write!(f, "Task join error: {}", s),
             AgentError::InferenceRecursionLimitReached(s) => write!(f, "Inferencing the LLM has reached too many iterations of recursion with no progess, and thus has been stopped for this job_task: {}", s),
             AgentError::TokenizationError(s) => write!(f, "Tokenization error: {}", s),
+            AgentError::JobDequeueFailed(s) => write!(f, "Job dequeue failed: {}", s),

         }
     }
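
The new `JobDequeueFailed(String)` variant gives the job-processing code a dedicated error for failed queue reads. A sketch of the kind of call site it enables, where a plain `Vec` stands in for the real job queue (both the helper and the queue type are assumptions, not part of this diff):

```rust
use crate::agent::error::AgentError;

// Hypothetical call site: surface an empty or broken queue as a typed error.
fn dequeue_next(queue: &mut Vec<String>, job_id: &str) -> Result<String, AgentError> {
    queue
        .pop()
        .ok_or_else(|| AgentError::JobDequeueFailed(format!("no queued messages for job {}", job_id)))
}
```
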
38 changes: 20 additions & 18 deletions src/agent/execution/chains/inference_chain_router.rs

@@ -1,7 +1,9 @@
 use crate::agent::agent::Agent;
 use crate::agent::error::AgentError;
 use crate::agent::job::Job;
-use crate::agent::job_manager::AgentManager;
+use crate::agent::job_manager::JobManager;
+use crate::db::ShinkaiDB;
+use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent;
 use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
 use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::JobMessage;
 use shinkai_vector_resources::embedding_generator::EmbeddingGenerator;
@@ -15,13 +17,13 @@ pub enum InferenceChain {
     CodingChain,
 }

-impl AgentManager {
+impl JobManager {
     /// Chooses an inference chain based on the job message (using the agent's LLM)
     /// and then starts using the chosen chain.
     /// Returns the final String result from the inferencing, and a new execution context.
     pub async fn inference_chain_router(
-        &self,
-        agent_found: Option<Arc<Mutex<Agent>>>,
+        db: Arc<Mutex<ShinkaiDB>>,
+        agent_found: Option<SerializedAgent>,
         full_job: Job,
         job_message: JobMessage,
         prev_execution_context: HashMap<String, String>,
@@ -37,20 +39,20 @@ impl AgentManager {
         match chosen_chain {
             InferenceChain::QAChain => {
                 if let Some(agent) = agent_found {
-                    inference_response_content = self
-                        .start_qa_inference_chain(
-                            full_job,
-                            job_message.content.clone(),
-                            agent,
-                            prev_execution_context,
-                            generator,
-                            user_profile,
-                            None,
-                            None,
-                            0,
-                            5,
-                        )
-                        .await?;
+                    inference_response_content = JobManager::start_qa_inference_chain(
+                        db,
+                        full_job,
+                        job_message.content.clone(),
+                        agent,
+                        prev_execution_context,
+                        generator,
+                        user_profile,
+                        None,
+                        None,
+                        0,
+                        5,
+                    )
+                    .await?;
                     new_execution_context
                         .insert("previous_step_response".to_string(), inference_response_content.clone());
                 } else {
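
This file shows the heart of the refactor: `inference_chain_router` changes from an `&self` method on `AgentManager` into an associated function on `JobManager` that receives `Arc<Mutex<ShinkaiDB>>` and a `SerializedAgent` by value. Since nothing borrows a manager instance anymore, each queued job can run on its own task, which is what the branch name (`nico/job_parallelization`) points at. A simplified, self-contained sketch of the pattern this enables (all names below are stand-ins, not the real signatures):

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

// Simplified stand-in; the real type lives in the shinkai-node crate.
struct ShinkaiDB;

async fn process_job(db: Arc<Mutex<ShinkaiDB>>, job_id: String) {
    // Stand-in for JobManager::inference_chain_router(db, ...).await;
    // the lock is held only for the duration of the DB access.
    let _guard = db.lock().await;
    println!("processing {}", job_id);
}

#[tokio::main]
async fn main() {
    let db = Arc::new(Mutex::new(ShinkaiDB));
    let handles: Vec<_> = ["job_1", "job_2", "job_3"]
        .into_iter()
        .map(|id| tokio::spawn(process_job(db.clone(), id.to_string())))
        .collect();
    for handle in handles {
        handle.await.unwrap();
    }
}
```
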
50 changes: 27 additions & 23 deletions src/agent/execution/chains/qa_inference_chain.rs

@@ -3,23 +3,26 @@ use crate::agent::error::AgentError;
 use crate::agent::execution::job_prompts::JobPromptGenerator;
 use crate::agent::file_parsing::ParsingHelper;
 use crate::agent::job::{Job, JobId, JobLike};
-use crate::agent::job_manager::AgentManager;
+use crate::agent::job_manager::JobManager;
+use crate::db::ShinkaiDB;
 use crate::resources::bert_cpp::BertCPPProcess;
 use async_recursion::async_recursion;
+use shinkai_message_primitives::schemas::agents::serialized_agent::SerializedAgent;
 use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
 use shinkai_vector_resources::embedding_generator::EmbeddingGenerator;
 use std::result::Result::Ok;
 use std::{collections::HashMap, sync::Arc};
 use tokio::sync::Mutex;

-impl AgentManager {
+impl JobManager {
     /// An inference chain for question-answer job tasks which vector searches the Vector Resources
     /// in the JobScope to find relevant content for the LLM to use at each step.
     #[async_recursion]
     pub async fn start_qa_inference_chain(
-        &self,
+        db: Arc<Mutex<ShinkaiDB>>,
         full_job: Job,
         job_task: String,
-        agent: Arc<Mutex<Agent>>,
+        agent: SerializedAgent,
         execution_context: HashMap<String, String>,
         generator: &dyn EmbeddingGenerator,
         user_profile: Option<ShinkaiName>,
@@ -33,9 +36,9 @@ impl AgentManager {
         // Use search_text if available (on recursion), otherwise use job_task to generate the query (on first iteration)
         let query_text = search_text.clone().unwrap_or(job_task.clone());
         let query = generator.generate_embedding_default(&query_text).unwrap();
-        let ret_data_chunks = self
-            .job_scope_vector_search(full_job.scope(), query, 20, &user_profile.clone().unwrap(), true)
-            .await?;
+        let ret_data_chunks =
+            JobManager::job_scope_vector_search(db.clone(), full_job.scope(), query, 20, &user_profile.clone().unwrap(), true)
+                .await?;

         // Use the default prompt if not reached final iteration count, else use final prompt
         let filled_prompt = if iteration_count < max_iterations {
@@ -62,8 +65,8 @@ impl AgentManager {

         // Inference the agent's LLM with the prompt. If it has an answer, the chain
         // is finished and so just return the answer response as a cleaned String
-        let response_json = self.inference_agent(agent.clone(), filled_prompt).await?;
-        if let Ok(answer_str) = self.extract_inference_json_response(response_json.clone(), "answer") {
+        let response_json = JobManager::inference_agent(agent.clone(), filled_prompt).await?;
+        if let Ok(answer_str) = JobManager::extract_inference_json_response(response_json.clone(), "answer") {
             let cleaned_answer = ParsingHelper::ending_stripper(&answer_str);
             println!("QA Chain Final Answer: {:?}", cleaned_answer);
             return Ok(cleaned_answer);
@@ -75,26 +78,26 @@ impl AgentManager {

         // If not an answer, then the LLM must respond with a search/summary, so we parse them
         // to use for the next recursive call
-        let (mut new_search_text, summary) = match self.extract_inference_json_response(response_json.clone(), "search")
-        {
-            Ok(search_str) => {
-                let summary_str = response_json
-                    .get("summary")
-                    .and_then(|s| s.as_str())
-                    .map(|s| ParsingHelper::ending_stripper(s));
-                (search_str, summary_str)
-            }
-            Err(_) => return Err(AgentError::InferenceJSONResponseMissingField("search".to_string())),
-        };
+        let (mut new_search_text, summary) =
+            match JobManager::extract_inference_json_response(response_json.clone(), "search") {
+                Ok(search_str) => {
+                    let summary_str = response_json
+                        .get("summary")
+                        .and_then(|s| s.as_str())
+                        .map(|s| ParsingHelper::ending_stripper(s));
+                    (search_str, summary_str)
+                }
+                Err(_) => return Err(AgentError::InferenceJSONResponseMissingField("search".to_string())),
+            };

         // If the new search text is the same as the previous one, prompt the agent for a new search term
         if Some(new_search_text.clone()) == search_text {
             let retry_prompt = JobPromptGenerator::retry_new_search_term_prompt(
                 new_search_text.clone(),
                 summary.clone().unwrap_or_default(),
             );
-            let response_json = self.inference_agent(agent.clone(), retry_prompt).await?;
-            match self.extract_inference_json_response(response_json, "search") {
+            let response_json = JobManager::inference_agent(agent.clone(), retry_prompt).await?;
+            match JobManager::extract_inference_json_response(response_json, "search") {
                 Ok(search_str) => {
                     println!("QA Chain New Search Retry Term: {:?}", search_str);
                     new_search_text = search_str;
@@ -104,7 +107,8 @@ impl AgentManager {
         }

         // Recurse with the new search/summary text and increment iteration_count
-        self.start_qa_inference_chain(
+        JobManager::start_qa_inference_chain(
+            db,
             full_job,
             job_task.to_string(),
             agent,
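
Beyond the mechanical `self` to `JobManager::` rewrites and the threaded-through `db` handle, the chain's control flow is unchanged: each iteration embeds a query, vector-searches the job scope, and asks the LLM for either an `answer` (finish) or a new `search` term (recurse), forcing an answer at the final iteration. A compressed, self-contained sketch of that loop, where `llm_step` stands in for the vector search plus inference call:

```rust
enum LlmStep {
    Answer(String),
    Search(String),
}

// Stand-in for: vector search over the job scope + JobManager::inference_agent.
fn llm_step(query: &str, force_answer: bool) -> LlmStep {
    if force_answer {
        LlmStep::Answer(format!("best-effort answer for `{}`", query))
    } else {
        LlmStep::Search(format!("refined search for `{}`", query))
    }
}

// Mirrors start_qa_inference_chain's recursion: search_text drives the query
// on recursive calls, and max_iterations caps the depth.
fn qa_chain(job_task: &str, search_text: Option<String>, iteration: u32, max_iterations: u32) -> String {
    let query = search_text.unwrap_or_else(|| job_task.to_string());
    match llm_step(&query, iteration >= max_iterations) {
        LlmStep::Answer(answer) => answer,
        LlmStep::Search(next) => qa_chain(job_task, Some(next), iteration + 1, max_iterations),
    }
}

fn main() {
    println!("{}", qa_chain("What is job parallelization?", None, 0, 5));
}
```
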
4 changes: 2 additions & 2 deletions src/agent/execution/chains/tool_execution_chain.rs

@@ -1,6 +1,6 @@
-use crate::agent::job_manager::AgentManager;
+use crate::agent::job_manager::JobManager;

-impl AgentManager {
+impl JobManager {
     pub fn start_tool_execution_inference_chain(&self) -> () {
         self.analysis_phase();

(Remaining changed files are not rendered in this view.)