diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs index ace57fd64..c8df01aca 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_inference_chain.rs @@ -71,7 +71,6 @@ impl InferenceChain for GenericInferenceChain { self.context.message_hash_id.clone(), self.context.image_files.clone(), self.context.llm_provider.clone(), - self.context.execution_context.clone(), self.context.generator.clone(), self.context.user_profile.clone(), self.context.max_iterations, @@ -110,7 +109,6 @@ impl GenericInferenceChain { message_hash_id: Option, image_files: HashMap, llm_provider: ProviderOrAgent, - execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, max_iterations: u64, @@ -324,7 +322,7 @@ impl GenericInferenceChain { image_files.clone(), ret_nodes.clone(), summary_node_text.clone(), - Some(full_job.step_history.clone()), + full_job.prompts.clone(), tools.clone(), None, ); @@ -378,7 +376,6 @@ impl GenericInferenceChain { message_hash_id.clone(), image_files.clone(), llm_provider.clone(), - execution_context.clone(), generator.clone(), user_profile.clone(), max_iterations, @@ -394,7 +391,10 @@ impl GenericInferenceChain { // 6) Call workflow or tooling // Find the ShinkaiTool that has a tool with the function name - let shinkai_tool = tools.iter().find(|tool| tool.name() == function_call.name || tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default()); + let shinkai_tool = tools.iter().find(|tool| { + tool.name() == function_call.name + || tool.tool_router_key() == function_call.tool_router_key.clone().unwrap_or_default() + }); if shinkai_tool.is_none() { eprintln!("Function not found: {}", 
function_call.name); return Err(LLMProviderError::FunctionNotFound(function_call.name.clone())); @@ -439,7 +439,7 @@ impl GenericInferenceChain { image_files.clone(), ret_nodes.clone(), summary_node_text.clone(), - Some(full_job.step_history.clone()), + full_job.prompts.clone(), tools.clone(), Some(function_response), ); @@ -451,7 +451,6 @@ impl GenericInferenceChain { response.response_string, response.tps.map(|tps| tps.to_string()), answer_duration_ms, - execution_context.clone(), Some(tool_calls_history.clone()), ); diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs index 72c68df67..9023205ae 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs @@ -20,7 +20,7 @@ impl JobPromptGenerator { image_files: HashMap, ret_nodes: Vec, _summary_text: Option, - job_step_history: Option>, + job_prompts: Vec, tools: Vec, function_call: Option, ) -> Prompt { @@ -36,9 +36,12 @@ impl JobPromptGenerator { let has_ret_nodes = !ret_nodes.is_empty(); // Add previous messages - // TODO: this should be full messages with assets and not just strings - if let Some(step_history) = job_step_history { - prompt.add_step_history(step_history, 97); + if !job_prompts.is_empty() { + let sub_prompts = job_prompts + .into_iter() + .flat_map(|prompt| prompt.sub_prompts.clone()) + .collect(); + prompt.add_sub_prompts_with_new_priority(sub_prompts, 97); } // Add tools if any. 
Decrease priority every 2 tools diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs index 5988a8205..1ea10e86d 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_router.rs @@ -24,7 +24,7 @@ use tokio::sync::{Mutex, RwLock}; impl JobManager { /// Chooses an inference chain based on the job message (using the agent's LLM) /// and then starts using the chosen chain. - /// Returns the final String result from the inferencing, and a new execution context. + /// Returns the final String result from the inferencing. #[allow(clippy::too_many_arguments)] pub async fn inference_chain_router( db: Arc>, @@ -34,7 +34,6 @@ impl JobManager { job_message: JobMessage, message_hash_id: Option, image_files: HashMap, - prev_execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, ws_manager_trait: Option>>, @@ -75,7 +74,6 @@ impl JobManager { message_hash_id, image_files, llm_provider, - prev_execution_context, generator, user_profile, 3, diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs index 47c5c8639..266bdca56 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/inference_chain_trait.rs @@ -65,7 +65,6 @@ pub trait InferenceChainContextTrait: Send + Sync { fn message_hash_id(&self) -> Option; fn image_files(&self) -> &HashMap; fn agent(&self) -> &ProviderOrAgent; - fn execution_context(&self) -> &HashMap; fn generator(&self) -> &RemoteEmbeddingGenerator; fn user_profile(&self) -> &ShinkaiName; fn max_iterations(&self) -> u64; @@ -134,10 +133,6 @@ impl 
InferenceChainContextTrait for InferenceChainContext { &self.llm_provider } - fn execution_context(&self) -> &HashMap { - &self.execution_context - } - fn generator(&self) -> &RemoteEmbeddingGenerator { &self.generator } @@ -207,8 +202,6 @@ pub struct InferenceChainContext { pub message_hash_id: Option, pub image_files: HashMap, pub llm_provider: ProviderOrAgent, - /// Job's execution context, used to store potentially relevant data across job steps. - pub execution_context: HashMap, pub generator: RemoteEmbeddingGenerator, pub user_profile: ShinkaiName, pub max_iterations: u64, @@ -235,7 +228,6 @@ impl InferenceChainContext { message_hash_id: Option, image_files: HashMap, llm_provider: ProviderOrAgent, - execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, max_iterations: u64, @@ -257,7 +249,6 @@ impl InferenceChainContext { message_hash_id, image_files, llm_provider, - execution_context, generator, user_profile, max_iterations, @@ -296,7 +287,6 @@ impl fmt::Debug for InferenceChainContext { .field("message_hash_id", &self.message_hash_id) .field("image_files", &self.image_files.len()) .field("llm_provider", &self.llm_provider) - .field("execution_context", &self.execution_context) .field("generator", &self.generator) .field("user_profile", &self.user_profile) .field("max_iterations", &self.max_iterations) @@ -319,15 +309,13 @@ pub struct InferenceChainResult { pub response: String, pub tps: Option, pub answer_duration: Option, - pub new_job_execution_context: HashMap, pub tool_calls: Option>, } impl InferenceChainResult { - pub fn new(response: String, new_job_execution_context: HashMap) -> Self { + pub fn new(response: String) -> Self { Self { response, - new_job_execution_context, tps: None, answer_duration: None, tool_calls: None, @@ -338,26 +326,16 @@ impl InferenceChainResult { response: String, tps: Option, answer_duration_ms: Option, - new_job_execution_context: HashMap, tool_calls: Option>, ) -> Self { Self { 
response, tps, answer_duration: answer_duration_ms, - new_job_execution_context, tool_calls, } } - pub fn new_empty_execution_context(response: String) -> Self { - Self::new(response, HashMap::new()) - } - - pub fn new_empty() -> Self { - Self::new_empty_execution_context(String::new()) - } - pub fn tool_calls_metadata(&self) -> Option> { self.tool_calls .as_ref() @@ -452,10 +430,6 @@ impl InferenceChainContextTrait for Box { (**self).agent() } - fn execution_context(&self) -> &HashMap { - (**self).execution_context() - } - fn generator(&self) -> &RemoteEmbeddingGenerator { (**self).generator() } @@ -517,7 +491,6 @@ impl InferenceChainContextTrait for Box { pub struct MockInferenceChainContext { pub user_message: ParsedUserMessage, pub image_files: HashMap, - pub execution_context: HashMap, pub user_profile: ShinkaiName, pub max_iterations: u64, pub iteration_count: u64, @@ -535,7 +508,6 @@ impl MockInferenceChainContext { #[allow(dead_code)] pub fn new( user_message: ParsedUserMessage, - execution_context: HashMap, user_profile: ShinkaiName, max_iterations: u64, iteration_count: u64, @@ -550,7 +522,6 @@ impl MockInferenceChainContext { Self { user_message, image_files: HashMap::new(), - execution_context, user_profile, max_iterations, iteration_count, @@ -575,7 +546,6 @@ impl Default for MockInferenceChainContext { Self { user_message, image_files: HashMap::new(), - execution_context: HashMap::new(), user_profile, max_iterations: 10, iteration_count: 0, @@ -635,10 +605,6 @@ impl InferenceChainContextTrait for MockInferenceChainContext { unimplemented!() } - fn execution_context(&self) -> &HashMap { - &self.execution_context - } - fn generator(&self) -> &RemoteEmbeddingGenerator { unimplemented!() } @@ -701,7 +667,6 @@ impl Clone for MockInferenceChainContext { Self { user_message: self.user_message.clone(), image_files: self.image_files.clone(), - execution_context: self.execution_context.clone(), user_profile: self.user_profile.clone(), max_iterations: 
self.max_iterations, iteration_count: self.iteration_count, diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs index b2ea59872..d82cba171 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/sheet_ui_chain/sheet_ui_inference_chain.rs @@ -66,7 +66,6 @@ impl InferenceChain for SheetUIInferenceChain { self.context.message_hash_id.clone(), self.context.image_files.clone(), self.context.llm_provider.clone(), - self.context.execution_context.clone(), self.context.generator.clone(), self.context.user_profile.clone(), self.context.max_iterations, @@ -81,8 +80,7 @@ impl InferenceChain for SheetUIInferenceChain { self.context.llm_stopper.clone(), ) .await?; - let job_execution_context = self.context.execution_context.clone(); - Ok(InferenceChainResult::new(response, job_execution_context)) + Ok(InferenceChainResult::new(response)) } } @@ -110,7 +108,6 @@ impl SheetUIInferenceChain { message_hash_id: Option, image_files: HashMap, llm_provider: ProviderOrAgent, - execution_context: HashMap, generator: RemoteEmbeddingGenerator, user_profile: ShinkaiName, max_iterations: u64, @@ -280,7 +277,7 @@ impl SheetUIInferenceChain { image_files.clone(), ret_nodes.clone(), summary_node_text.clone(), - Some(full_job.step_history.clone()), + full_job.prompts.clone(), tools.clone(), None, ); @@ -380,7 +377,6 @@ impl SheetUIInferenceChain { message_hash_id.clone(), image_files.clone(), llm_provider.clone(), - execution_context.clone(), generator.clone(), user_profile.clone(), max_iterations, @@ -416,7 +412,7 @@ impl SheetUIInferenceChain { image_files.clone(), ret_nodes.clone(), summary_node_text.clone(), - Some(full_job.step_history.clone()), + full_job.prompts.clone(), tools.clone(), 
Some(function_response), ); diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs index 27032d3c8..4764b9aad 100644 --- a/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs +++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/job_execution_core.rs @@ -273,14 +273,6 @@ impl JobManager { ShinkaiLogLevel::Debug, &format!("Retrieved {} image files", image_files.len()), ); - - // Setup initial data to get ready to call a specific inference chain - let prev_execution_context = full_job.execution_context.clone(); - shinkai_log( - ShinkaiLogOption::JobExecution, - ShinkaiLogLevel::Debug, - &format!("Prev Execution Context: {:?}", prev_execution_context), - ); let start = Instant::now(); // Call the inference chain router to choose which chain to use, and call it @@ -292,7 +284,6 @@ impl JobManager { job_message.clone(), message_hash_id, image_files.clone(), - prev_execution_context, generator, user_profile.clone(), ws_manager.clone(), @@ -305,7 +296,6 @@ impl JobManager { ) .await?; let inference_response_content = inference_response.response.clone(); - let new_execution_context = inference_response.new_job_execution_context.clone(); let duration = start.elapsed(); shinkai_log( @@ -340,7 +330,7 @@ impl JobManager { ); // Save response data to DB - db.write().await.add_step_history( + db.write().await.add_job_prompt( job_message.job_id.clone(), job_message.content, Some(image_files), @@ -352,9 +342,6 @@ impl JobManager { .await .add_message_to_job_inbox(&job_message.job_id.clone(), &shinkai_message, None, ws_manager) .await?; - db.write() - .await - .set_job_execution_context(job_message.job_id.clone(), new_execution_context, None)?; // Check for callbacks and add them to the JobManagerQueue if required if let Some(callback) = &job_message.callback { @@ -466,7 +453,6 @@ impl JobManager { job_message.clone(), message_hash_id, empty_files, - 
HashMap::new(), // Assuming prev_execution_context is an empty HashMap generator, user_profile.clone(), ws_manager.clone(), diff --git a/shinkai-bin/shinkai-node/tests/it/db_job_tests.rs b/shinkai-bin/shinkai-node/tests/it/db_job_tests.rs index f343e77a1..987a6daaf 100644 --- a/shinkai-bin/shinkai-node/tests/it/db_job_tests.rs +++ b/shinkai-bin/shinkai-node/tests/it/db_job_tests.rs @@ -214,7 +214,7 @@ mod tests { } #[tokio::test] - async fn test_update_step_history() { + async fn test_update_job_prompts() { let job_id = "test_job".to_string(); let db = setup_test_db(); let shinkai_db = Arc::new(RwLock::new(db)); @@ -248,11 +248,11 @@ mod tests { .await .unwrap(); - // Update step history + // Update prompts shinkai_db .write() .await - .add_step_history( + .add_job_prompt( job_id.clone(), "What is 10 + 25".to_string(), None, @@ -265,7 +265,7 @@ mod tests { shinkai_db .write() .await - .add_step_history( + .add_job_prompt( job_id.clone(), "2) What is 10 + 25".to_string(), None, @@ -275,9 +275,9 @@ mod tests { ) .unwrap(); - // Retrieve the job and check that step history is updated + // Retrieve the job and check that prompt is updated let job = shinkai_db.read().await.get_job(&job_id.clone()).unwrap(); - assert_eq!(job.step_history.len(), 2); + assert_eq!(job.prompts.len(), 2); } #[tokio::test] @@ -495,7 +495,7 @@ mod tests { } #[tokio::test] - async fn test_job_inbox_tree_structure_with_step_history_and_execution_context() { + async fn test_job_inbox_tree_structure_with_prompts() { let job_id = "job_test".to_string(); let agent_id = "agent_test".to_string(); let scope = JobScope::new_default(); @@ -550,12 +550,12 @@ mod tests { .add_message_to_job_inbox(&job_id.clone(), &shinkai_message, parent_hash.clone(), None) .await; - // Add a step history + // Add a prompt let result = format!("Result {}", i); shinkai_db .write() .await - .add_step_history( + .add_job_prompt( job_id.clone(), format!("Step {} Level {}", i, current_level), None, @@ -568,15 +568,6 @@ mod tests 
{ // Add the result to the results vector results.push(result); - // Set job execution context - let mut execution_context = HashMap::new(); - execution_context.insert("context".to_string(), results.join(", ")); - shinkai_db - .write() - .await - .set_job_execution_context(job_id.clone(), execution_context, None) - .unwrap(); - // Update the parent message according to the tree structure if i == 1 { parent_message_hash = Some(shinkai_message.calculate_message_hash_for_pagination()); @@ -628,45 +619,36 @@ mod tests { let job_message_4: JobMessage = serde_json::from_str(&message_content_4).unwrap(); assert_eq!(job_message_4.content, "Hello World 4".to_string()); - // Check the step history and execution context + // Check the prompts let job = shinkai_db.read().await.get_job(&job_id.clone()).unwrap(); - eprintln!("job execution context: {:?}", job.execution_context); - - // Check the execution context - assert_eq!( - job.execution_context.get("context").unwrap(), - "Result 1, Result 2, Result 4" - ); - - // Check the step history - let step1 = &job.step_history[0]; - let step2 = &job.step_history[1]; - let step4 = &job.step_history[2]; + let prompt1 = &job.prompts[0]; + let prompt2 = &job.prompts[1]; + let prompt4 = &job.prompts[2]; assert_eq!( - step1.step_revisions[0].sub_prompts[0], + prompt1.sub_prompts[0], SubPrompt::Omni(User, "Step 1 Level 0".to_string(), vec![], 100) ); assert_eq!( - step1.step_revisions[0].sub_prompts[1], + prompt1.sub_prompts[1], SubPrompt::Omni(Assistant, "Result 1".to_string(), vec![], 100) ); assert_eq!( - step2.step_revisions[0].sub_prompts[0], + prompt2.sub_prompts[0], SubPrompt::Omni(User, "Step 2 Level 1".to_string(), vec![], 100) ); assert_eq!( - step2.step_revisions[0].sub_prompts[1], + prompt2.sub_prompts[1], SubPrompt::Omni(Assistant, "Result 2".to_string(), vec![], 100) ); assert_eq!( - step4.step_revisions[0].sub_prompts[0], + prompt4.sub_prompts[0], SubPrompt::Omni(User, "Step 4 Level 2".to_string(), vec![], 100) ); 
assert_eq!( - step4.step_revisions[0].sub_prompts[1], + prompt4.sub_prompts[1], SubPrompt::Omni(Assistant, "Result 4".to_string(), vec![], 100) ); } @@ -732,7 +714,7 @@ mod tests { shinkai_db .write() .await - .add_step_history(job_id.to_string(), user_message, None, agent_response, None, None) + .add_job_prompt(job_id.to_string(), user_message, None, agent_response, None, None) .unwrap(); // Update the parent message hash according to the tree structure @@ -767,16 +749,16 @@ mod tests { eprintln!("\n\n Getting steps..."); - let step_history = shinkai_db.read().await.get_step_history(job_id, true).unwrap().unwrap(); + let prompts = shinkai_db.read().await.get_job_prompts(job_id).unwrap(); - let step_history_content: Vec = step_history + let prompt_contents: Vec = prompts .iter() .map(|step| { - let user_message = match &step.step_revisions[0].sub_prompts[0] { + let user_message = match &step.sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }; - let agent_response = match &step.step_revisions[0].sub_prompts[1] { + let agent_response = match &step.sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }; @@ -784,19 +766,19 @@ mod tests { }) .collect(); - eprintln!("Step history: {:?}", step_history_content); + eprintln!("Step history: {:?}", prompt_contents); - assert_eq!(step_history.len(), 3); + assert_eq!(prompts.len(), 3); // Check the content of the steps assert_eq!( format!( "{} {}", - match &step_history[0].step_revisions[0].sub_prompts[0] { + match &prompts[0].sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }, - match &step_history[0].step_revisions[0].sub_prompts[1] { + match &prompts[0].sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), } @@ -806,11 +788,11 @@ mod tests { assert_eq!( format!( "{} {}", - match 
&step_history[1].step_revisions[0].sub_prompts[0] { + match &prompts[1].sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }, - match &step_history[1].step_revisions[0].sub_prompts[1] { + match &prompts[1].sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), } @@ -820,11 +802,11 @@ assert_eq!( format!( "{} {}", - match &step_history[2].step_revisions[0].sub_prompts[0] { + match &prompts[2].sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }, - match &step_history[2].step_revisions[0].sub_prompts[1] { + match &prompts[2].sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), } diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs index fa56e0c0c..49a209424 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/job.rs @@ -1,4 +1,7 @@ -use crate::{shinkai_message::shinkai_message_schemas::AssociatedUI, shinkai_utils::job_scope::{JobScope, MinimalJobScope}}; +use crate::{ + shinkai_message::shinkai_message_schemas::AssociatedUI, + shinkai_utils::job_scope::{JobScope, MinimalJobScope}, +}; use super::{inbox_name::InboxName, job_config::JobConfig, prompts::Prompt}; use serde::{Deserialize, Serialize}; @@ -38,13 +41,11 @@ pub struct Job { /// An inbox where messages to the agent from the user and messages from the agent are stored, /// enabling each job to have a classical chat/conversation UI pub conversation_inbox_name: InboxName, - /// The job's step history (an ordered list of all prompts/outputs from LLM inferencing when processing steps) + /// List of Prompts that hold User->Assistant sub prompt pairs that denote what the user + /// asked, and what the Agent finally responded with.
/// Under the hood this is a tree, but it looks like a simple Vec because we only care about the latest valid path - /// based on the last message sent by the user - pub step_history: Vec, - /// A hashmap which holds a bunch of labeled values which were generated as output from the latest Job step - /// Same as step_history. Under the hood this is a tree, but everything is automagically filtered and converted to a hashmap. - pub execution_context: HashMap, + /// based on the last message sent by the user. + pub prompts: Vec, /// A link to the UI where the user can view the job e.g. Sheet UI pub associated_ui: Option, /// The job's configuration diff --git a/shinkai-libs/shinkai-message-primitives/src/schemas/prompts.rs b/shinkai-libs/shinkai-message-primitives/src/schemas/prompts.rs index 8c9f0bc55..00040d220 100644 --- a/shinkai-libs/shinkai-message-primitives/src/schemas/prompts.rs +++ b/shinkai-libs/shinkai-message-primitives/src/schemas/prompts.rs @@ -249,18 +249,6 @@ impl Prompt { self.add_sub_prompts(updated_sub_prompts); } - /// Adds previous results from step history into the Prompt, up to max_tokens - /// Of note, priority value must be between 0-100. - pub fn add_step_history(&mut self, history: Vec, priority_value: u8) { - let capped_priority_value = std::cmp::min(priority_value, 100) as u8; - let sub_prompts_list: Vec = history - .iter() - .filter_map(|step| step.get_result_prompt()) - .flat_map(|prompt| prompt.sub_prompts.clone()) - .collect(); - self.add_sub_prompts_with_new_priority(sub_prompts_list, capped_priority_value); - } - /// Removes the first sub-prompt from the end of the sub_prompts list that has the lowest priority value. /// Used primarily for cutting down prompt when it is too large to fit in context window. 
pub fn remove_lowest_priority_sub_prompt(&mut self) -> Option { @@ -418,12 +406,12 @@ impl Prompt { SubPrompt::Omni(prompt_type, _, _, _) => { // Process the current sub-prompt let new_message = sub_prompt.into_chat_completion_request_message(); - + if let SubPromptType::UserLastMessage = prompt_type { last_user_message = Some(new_message); } else { current_length += - sub_prompt.count_tokens_with_pregenerated_completion_message(&new_message, token_counter); + sub_prompt.count_tokens_with_pregenerated_completion_message(&new_message, token_counter); tiktoken_messages.push(new_message); } } diff --git a/shinkai-libs/shinkai-sqlite/src/job_manager.rs b/shinkai-libs/shinkai-sqlite/src/job_manager.rs index 0c7e79098..ea032a063 100644 --- a/shinkai-libs/shinkai-sqlite/src/job_manager.rs +++ b/shinkai-libs/shinkai-sqlite/src/job_manager.rs @@ -4,7 +4,7 @@ use rusqlite::params; use shinkai_message_primitives::{ schemas::{ inbox_name::InboxName, - job::{ForkedJob, Job, JobLike, JobStepResult}, + job::{ForkedJob, Job, JobLike}, job_config::JobConfig, prompts::Prompt, subprompts::SubPromptType, @@ -45,10 +45,9 @@ impl SqliteManager { scope, scope_with_files, conversation_inbox_name, - execution_context, associated_ui, config - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)", + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", )?; stmt.execute(params![ @@ -60,9 +59,6 @@ impl SqliteManager { scope_bytes, scope_with_files_bytes, job_inbox_name.clone(), - serde_json::to_vec(&HashMap::::new()).map_err(|e| { - rusqlite::Error::ToSqlConversionFailure(Box::new(SqliteManagerError::SerializationError(e.to_string()))) - })?, serde_json::to_vec(&associated_ui).map_err(|e| { rusqlite::Error::ToSqlConversionFailure(Box::new(SqliteManagerError::SerializationError(e.to_string()))) })?, @@ -114,7 +110,7 @@ impl SqliteManager { pub fn get_job_with_options( &self, job_id: &str, - fetch_step_history: bool, + fetch_prompts: bool, fetch_scope_with_files: bool, ) -> Result { let ( @@ 
-125,12 +121,11 @@ impl SqliteManager { datetime_created, parent_agent_id, conversation_inbox, - step_history, - execution_context, + prompts, associated_ui, config, forked_jobs, - ) = self.get_job_data(job_id, fetch_step_history, fetch_scope_with_files)?; + ) = self.get_job_data(job_id, fetch_prompts, fetch_scope_with_files)?; let job = Job { job_id: job_id.to_string(), @@ -141,8 +136,7 @@ impl SqliteManager { scope, scope_with_files, conversation_inbox_name: conversation_inbox, - step_history: step_history.unwrap_or_else(Vec::new), - execution_context, + prompts, associated_ui, config, forked_jobs, @@ -165,7 +159,6 @@ impl SqliteManager { parent_agent_id, conversation_inbox, _, - execution_context, associated_ui, config, forked_jobs, @@ -180,8 +173,7 @@ impl SqliteManager { scope, scope_with_files, conversation_inbox_name: conversation_inbox, - step_history: Vec::new(), // Empty step history for JobLike - execution_context, + prompts: Vec::new(), associated_ui, config, forked_jobs, @@ -194,7 +186,7 @@ impl SqliteManager { fn get_job_data( &self, job_id: &str, - fetch_step_history: bool, + fetch_prompts: bool, fetch_scope_with_files: bool, ) -> Result< ( @@ -205,8 +197,7 @@ impl SqliteManager { String, String, InboxName, - Option>, - HashMap, + Vec, Option, Option, Vec, @@ -230,7 +221,6 @@ impl SqliteManager { scope, {scope_with_files}, conversation_inbox_name, - execution_context, associated_ui, config FROM jobs WHERE job_id = ?1" @@ -248,9 +238,8 @@ impl SqliteManager { let inbox_name: String = row.get(7)?; let conversation_inbox: InboxName = InboxName::new(inbox_name).map_err(|e| SqliteManagerError::SomeError(e.to_string()))?; - let execution_context_bytes: Option> = row.get(8)?; - let associated_ui_bytes: Option> = row.get(9)?; - let config_bytes: Option> = row.get(10)?; + let associated_ui_bytes: Option> = row.get(8)?; + let config_bytes: Option> = row.get(9)?; let scope = serde_json::from_slice(&scope_bytes)?; let scope_with_files = if fetch_scope_with_files 
{ @@ -260,13 +249,12 @@ impl SqliteManager { None }; - let step_history = if fetch_step_history { - self.get_step_history(job_id, true)? + let prompts = if fetch_prompts { + self.get_job_prompts(job_id)? } else { - None + Vec::new() }; - let execution_context = serde_json::from_slice(&execution_context_bytes.unwrap_or_default())?; let associated_ui = serde_json::from_slice(&associated_ui_bytes.unwrap_or_default())?; let config = serde_json::from_slice(&config_bytes.unwrap_or_default())?; @@ -293,8 +281,7 @@ impl SqliteManager { datetime_created, parent_agent_id, conversation_inbox, - step_history, - execution_context, + prompts, associated_ui, config, forked_jobs, @@ -319,13 +306,10 @@ impl SqliteManager { let inbox_name: String = row.get(7)?; let conversation_inbox: InboxName = InboxName::new(inbox_name).map_err(|e| SqliteManagerError::SomeError(e.to_string()))?; - let execution_context_bytes: Option> = row.get(8)?; - let associated_ui_bytes: Option> = row.get(9)?; - let config_bytes: Option> = row.get(10)?; + let associated_ui_bytes: Option> = row.get(8)?; + let config_bytes: Option> = row.get(9)?; let scope = serde_json::from_slice(&scope_bytes)?; let scope_with_files = serde_json::from_slice(&scope_with_files_bytes.unwrap_or_default())?; - let step_history = self.get_step_history(&job_id, false)?; - let execution_context = serde_json::from_slice(&execution_context_bytes.unwrap_or_default())?; let associated_ui = serde_json::from_slice(&associated_ui_bytes.unwrap_or_default())?; let config = serde_json::from_slice(&config_bytes.unwrap_or_default())?; @@ -354,8 +338,7 @@ impl SqliteManager { scope, scope_with_files, conversation_inbox_name: conversation_inbox, - step_history: step_history.unwrap_or_else(Vec::new), - execution_context, + prompts: Vec::new(), associated_ui, config, forked_jobs, @@ -398,13 +381,10 @@ impl SqliteManager { let inbox_name: String = row.get(7)?; let conversation_inbox: InboxName = InboxName::new(inbox_name).map_err(|e| 
SqliteManagerError::SomeError(e.to_string()))?; - let execution_context_bytes: Option> = row.get(8)?; - let associated_ui_bytes: Option> = row.get(9)?; - let config_bytes: Option> = row.get(10)?; + let associated_ui_bytes: Option> = row.get(8)?; + let config_bytes: Option> = row.get(9)?; let scope = serde_json::from_slice(&scope_bytes)?; let scope_with_files = serde_json::from_slice(&scope_with_files_bytes.unwrap_or_default())?; - let step_history = self.get_step_history(&job_id, false)?; - let execution_context = serde_json::from_slice(&execution_context_bytes.unwrap_or_default())?; let associated_ui = serde_json::from_slice(&associated_ui_bytes.unwrap_or_default())?; let config = serde_json::from_slice(&config_bytes.unwrap_or_default())?; @@ -432,8 +412,7 @@ impl SqliteManager { scope, scope_with_files, conversation_inbox_name: conversation_inbox, - step_history: step_history.unwrap_or_else(Vec::new), - execution_context, + prompts: Vec::new(), associated_ui, config, forked_jobs, @@ -445,36 +424,6 @@ impl SqliteManager { Ok(jobs.into_iter().map(|job| Box::new(job) as Box).collect()) } - pub fn set_job_execution_context( - &self, - job_id: String, - context: HashMap, - _message_key: Option, - ) -> Result<(), SqliteManagerError> { - let conn = self.get_connection()?; - - let context_bytes = serde_json::to_vec(&context)?; - - let mut stmt = conn.prepare("UPDATE jobs SET execution_context = ?1 WHERE job_id = ?2")?; - - stmt.execute(params![context_bytes, job_id])?; - - Ok(()) - } - - pub fn get_job_execution_context(&self, job_id: &str) -> Result, SqliteManagerError> { - let conn = self.get_connection()?; - let mut stmt = conn.prepare("SELECT execution_context FROM jobs WHERE job_id = ?1")?; - let mut rows = stmt.query(params![job_id])?; - - let row = rows.next()?.ok_or(SqliteManagerError::DataNotFound)?; - - let execution_context_bytes: Vec = row.get(0)?; - let execution_context = serde_json::from_slice(&execution_context_bytes)?; - - Ok(execution_context) - } - pub 
fn update_job_to_finished(&self, job_id: &str) -> Result<(), SqliteManagerError> { let conn = self.get_connection()?; let mut stmt = conn.prepare("SELECT COUNT(*) FROM jobs WHERE job_id = ?1")?; @@ -490,7 +439,7 @@ impl SqliteManager { Ok(()) } - pub fn add_step_history( + pub fn add_job_prompt( &self, job_id: String, user_message: String, @@ -522,38 +471,27 @@ impl SqliteManager { } }; - // Create prompt & JobStepResult + // Create prompt let mut prompt = Prompt::new(); let user_files = user_files.unwrap_or_default(); let agent_files = agent_files.unwrap_or_default(); prompt.add_omni(user_message, user_files, SubPromptType::User, 100); prompt.add_omni(agent_response, agent_files, SubPromptType::Assistant, 100); - let mut job_step_result = JobStepResult::new(); - job_step_result.add_new_step_revision(prompt); - let step_result_bytes = serde_json::to_vec(&job_step_result) + let prompt_bytes = serde_json::to_vec(&prompt) .map_err(|e| SqliteManagerError::SerializationError(format!("Error serializing JobStepResult: {}", e)))?; let conn = self.get_connection()?; - let mut stmt = - conn.prepare("INSERT INTO step_history (message_key, job_id, job_step_result) VALUES (?1, ?2, ?3)")?; - stmt.execute(params![message_key, job_id, step_result_bytes])?; + let mut stmt = conn.prepare("INSERT INTO job_prompts (message_key, job_id, prompt) VALUES (?1, ?2, ?3)")?; + stmt.execute(params![message_key, job_id, prompt_bytes])?; Ok(()) } - pub fn get_step_history( - &self, - job_id: &str, - fetch_step_history: bool, - ) -> Result<Option<Vec<JobStepResult>>, SqliteManagerError> { - if !fetch_step_history { - return Ok(None); - } - + pub fn get_job_prompts(&self, job_id: &str) -> Result<Vec<Prompt>, SqliteManagerError> { let inbox_name = InboxName::get_job_inbox_name_from_params(job_id.to_string()) .map_err(|e| SqliteManagerError::SomeError(format!("Error getting inbox name: {}", e)))?; - let mut step_history: Vec<JobStepResult> = Vec::new(); + let mut prompts: Vec<Prompt> = Vec::new(); let mut until_offset_key: Option<String> = None; let conn =
self.get_connection()?; @@ -572,14 +510,14 @@ impl SqliteManager { for message_path in &messages { if let Some(message) = message_path.first() { let message_key = message.calculate_message_hash_for_pagination(); - let mut stmt = conn.prepare("SELECT job_step_result FROM step_history WHERE message_key = ?1")?; + let mut stmt = conn.prepare("SELECT prompt FROM job_prompts WHERE message_key = ?1")?; let mut rows = stmt.query(params![message_key])?; while let Some(row) = rows.next()? { let step_result_bytes: Vec<u8> = row.get(0)?; - let step_result: JobStepResult = serde_json::from_slice(&step_result_bytes)?; + let step_result: Prompt = serde_json::from_slice(&step_result_bytes)?; - step_history.push(step_result); + prompts.push(step_result); } } } @@ -595,9 +533,9 @@ impl SqliteManager { } } - // Reverse the step history before returning - step_history.reverse(); - Ok(Some(step_history)) + // Reverse the prompts before returning + prompts.reverse(); + Ok(prompts) } pub fn is_job_inbox_empty(&self, job_id: &str) -> Result<bool, SqliteManagerError> { @@ -651,7 +589,7 @@ impl SqliteManager { params![inbox_name.to_string()], )?; - tx.execute("DELETE FROM step_history WHERE job_id = ?1", params![job_id])?; + tx.execute("DELETE FROM job_prompts WHERE job_id = ?1", params![job_id])?; tx.execute("DELETE FROM jobs WHERE job_id = ?1", params![job_id])?; tx.commit()?; @@ -837,7 +775,7 @@ mod tests { } #[tokio::test] - async fn test_update_step_history() { + async fn test_update_job_prompts() { let db = setup_test_db(); let job_id = "test_job".to_string(); @@ -866,7 +804,7 @@ mod tests { db.unsafe_insert_inbox_message(&message, None, None).await.unwrap(); // Update step history - db.add_step_history( + db.add_job_prompt( job_id.clone(), "What is 10 + 25".to_string(), None, @@ -876,7 +814,7 @@ mod tests { ) .unwrap(); sleep(Duration::from_millis(10)).await; - db.add_step_history( + db.add_job_prompt( job_id.clone(), "2) What is 10 + 25".to_string(), None, @@ -888,7 +826,7 @@ mod tests { // Retrieve the job and
check that step history is updated let job = db.get_job(&job_id.clone()).unwrap(); - assert_eq!(job.step_history.len(), 2); + assert_eq!(job.prompts.len(), 2); } #[test] @@ -1076,7 +1014,7 @@ mod tests { } #[tokio::test] - async fn test_job_inbox_tree_structure_with_step_history_and_execution_context() { + async fn test_job_inbox_tree_structure_with_job_prompt() { let db = setup_test_db(); let job_id = "job_test".to_string(); let agent_id = "agent_test".to_string(); @@ -1128,9 +1066,9 @@ mod tests { .add_message_to_job_inbox(&job_id.clone(), &shinkai_message, parent_hash.clone(), None) .await; - // Add a step history + // Add a prompt let result = format!("Result {}", i); - db.add_step_history( + db.add_job_prompt( job_id.clone(), format!("Step {} Level {}", i, current_level), None, @@ -1143,12 +1081,6 @@ mod tests { // Add the result to the results vector results.push(result); - // Set job execution context - let mut execution_context = HashMap::new(); - execution_context.insert("context".to_string(), results.join(", ")); - db.set_job_execution_context(job_id.clone(), execution_context, None) - .unwrap(); - // Update the parent message according to the tree structure if i == 1 { parent_message_hash = Some(shinkai_message.calculate_message_hash_for_pagination()); @@ -1196,45 +1128,36 @@ mod tests { let job_message_4: JobMessage = serde_json::from_str(&message_content_4).unwrap(); assert_eq!(job_message_4.content, "Hello World 4".to_string()); - // Check the step history and execution context + // Check the prompts let job = db.get_job(&job_id.clone()).unwrap(); - eprintln!("job execution context: {:?}", job.execution_context); - - // Check the execution context - assert_eq!( - job.execution_context.get("context").unwrap(), - "Result 1, Result 2, Result 4" - ); - - // Check the step history - let step1 = &job.step_history[0]; - let step2 = &job.step_history[1]; - let step4 = &job.step_history[2]; + let prompt1 = &job.prompts[0]; + let prompt2 = &job.prompts[1]; + 
let prompt4 = &job.prompts[2]; assert_eq!( - step1.step_revisions[0].sub_prompts[0], + prompt1.sub_prompts[0], SubPrompt::Omni(User, "Step 1 Level 0".to_string(), vec![], 100) ); assert_eq!( - step1.step_revisions[0].sub_prompts[1], + prompt1.sub_prompts[1], SubPrompt::Omni(Assistant, "Result 1".to_string(), vec![], 100) ); assert_eq!( - step2.step_revisions[0].sub_prompts[0], + prompt2.sub_prompts[0], SubPrompt::Omni(User, "Step 2 Level 1".to_string(), vec![], 100) ); assert_eq!( - step2.step_revisions[0].sub_prompts[1], + prompt2.sub_prompts[1], SubPrompt::Omni(Assistant, "Result 2".to_string(), vec![], 100) ); assert_eq!( - step4.step_revisions[0].sub_prompts[0], + prompt4.sub_prompts[0], SubPrompt::Omni(User, "Step 4 Level 2".to_string(), vec![], 100) ); assert_eq!( - step4.step_revisions[0].sub_prompts[1], + prompt4.sub_prompts[1], SubPrompt::Omni(Assistant, "Result 4".to_string(), vec![], 100) ); } @@ -1293,7 +1216,7 @@ mod tests { .await .unwrap(); - db.add_step_history(job_id.to_string(), user_message, None, agent_response, None, None) + db.add_job_prompt(job_id.to_string(), user_message, None, agent_response, None, None) .unwrap(); // Update the parent message hash according to the tree structure @@ -1324,16 +1247,16 @@ mod tests { eprintln!("\n\n Getting steps..."); - let step_history = db.get_step_history(job_id, true).unwrap().unwrap(); + let job_prompt = db.get_job_prompts(job_id).unwrap(); - let step_history_content: Vec<String> = step_history + let prompt_contents: Vec<String> = job_prompt .iter() - .map(|step| { - let user_message = match &step.step_revisions[0].sub_prompts[0] { + .map(|prompt| { + let user_message = match &prompt.sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }; - let agent_response = match &step.step_revisions[0].sub_prompts[1] { + let agent_response = match &prompt.sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }; @@ -1341,19 +1264,19 @@
mod tests { }) .collect(); - eprintln!("Step history: {:?}", step_history_content); + eprintln!("Step history: {:?}", prompt_contents); - assert_eq!(step_history.len(), 3); + assert_eq!(job_prompt.len(), 3); // Check the content of the steps assert_eq!( format!( "{} {}", - match &step_history[0].step_revisions[0].sub_prompts[0] { + match &job_prompt[0].sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }, - match &step_history[0].step_revisions[0].sub_prompts[1] { + match &job_prompt[0].sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), } @@ -1363,11 +1286,11 @@ mod tests { assert_eq!( format!( "{} {}", - match &step_history[1].step_revisions[0].sub_prompts[0] { + match &job_prompt[1].sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }, - match &step_history[1].step_revisions[0].sub_prompts[1] { + match &job_prompt[1].sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), } @@ -1377,11 +1300,11 @@ mod tests { assert_eq!( format!( "{} {}", - match &step_history[2].step_revisions[0].sub_prompts[0] { + match &job_prompt[2].sub_prompts[0] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), }, - match &step_history[2].step_revisions[0].sub_prompts[1] { + match &job_prompt[2].sub_prompts[1] { SubPrompt::Omni(_, text, _, _) => text, _ => panic!("Unexpected SubPrompt variant"), } diff --git a/shinkai-libs/shinkai-sqlite/src/lib.rs b/shinkai-libs/shinkai-sqlite/src/lib.rs index 647d51947..7b8c848d0 100644 --- a/shinkai-libs/shinkai-sqlite/src/lib.rs +++ b/shinkai-libs/shinkai-sqlite/src/lib.rs @@ -167,7 +167,7 @@ impl SqliteManager { Self::initialize_settings_table(conn)?; Self::initialize_sheets_table(conn)?; Self::initialize_source_file_maps_table(conn)?; - Self::initialize_step_history_table(conn)?; + 
Self::initialize_job_prompts_table(conn)?; Self::initialize_tools_table(conn)?; Self::initialize_tool_micropayments_requirements_table(conn)?; Self::initialize_tool_playground_table(conn)?; @@ -329,7 +329,6 @@ impl SqliteManager { scope BLOB NOT NULL, scope_with_files BLOB, conversation_inbox_name TEXT NOT NULL, - execution_context BLOB, associated_ui BLOB, config BLOB );", @@ -339,12 +338,12 @@ impl SqliteManager { Ok(()) } - fn initialize_step_history_table(conn: &rusqlite::Connection) -> Result<()> { + fn initialize_job_prompts_table(conn: &rusqlite::Connection) -> Result<()> { conn.execute( - "CREATE TABLE IF NOT EXISTS step_history ( + "CREATE TABLE IF NOT EXISTS job_prompts ( message_key TEXT NOT NULL, job_id TEXT NOT NULL, - job_step_result BLOB NOT NULL + prompt BLOB NOT NULL );", [], )?;