diff --git a/Cargo.lock b/Cargo.lock
index b09a9a357..6502070a6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8945,7 +8945,7 @@ dependencies = [
 
 [[package]]
 name = "shinkai_node"
-version = "0.8.3"
+version = "0.8.4"
 dependencies = [
  "aes-gcm",
  "anyhow",
diff --git a/shinkai-bin/shinkai-node/Cargo.toml b/shinkai-bin/shinkai-node/Cargo.toml
index f6a444b75..caa9b73bc 100644
--- a/shinkai-bin/shinkai-node/Cargo.toml
+++ b/shinkai-bin/shinkai-node/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "shinkai_node"
-version = "0.8.3"
+version = "0.8.4"
 edition = "2021"
 authors.workspace = true
 # this causes `cargo run` in the workspace root to run this package
diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs
index 101bebde8..358bcbc74 100644
--- a/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs
+++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/chains/generic_chain/generic_prompts.rs
@@ -32,11 +32,6 @@ impl JobPromptGenerator {
         // Add previous messages
         // TODO: this should be full messages with assets and not just strings
         if let Some(step_history) = job_step_history {
-            for step in &step_history {
-                if let Some(prompt) = step.get_result_prompt() {
-                    println!("Step history content: {:?}", prompt);
-                }
-            }
             prompt.add_step_history(step_history, 97);
         }
 
diff --git a/shinkai-bin/shinkai-node/src/llm_provider/execution/prompts/prompts.rs b/shinkai-bin/shinkai-node/src/llm_provider/execution/prompts/prompts.rs
index a19ebc0f9..dc5f7faeb 100644
--- a/shinkai-bin/shinkai-node/src/llm_provider/execution/prompts/prompts.rs
+++ b/shinkai-bin/shinkai-node/src/llm_provider/execution/prompts/prompts.rs
@@ -62,7 +62,13 @@ impl Prompt {
 
     /// Adds a sub-prompt that holds any Omni (String + Assets) content.
     /// Of note, priority value must be between 0-100, where higher is greater priority
-    pub fn add_omni(&mut self, content: String, files: HashMap<String, String>, prompt_type: SubPromptType, priority_value: u8) {
+    pub fn add_omni(
+        &mut self,
+        content: String,
+        files: HashMap<String, String>,
+        prompt_type: SubPromptType,
+        priority_value: u8,
+    ) {
         let capped_priority_value = std::cmp::min(priority_value, 100);
         let assets: Vec<(SubPromptAssetType, SubPromptAssetContent, SubPromptAssetDetail)> = files
             .into_iter()
@@ -306,7 +312,7 @@ impl Prompt {
 
         // Accumulator for ExtraContext content
        let mut extra_context_content = String::new();
-        let mut last_user_message: Option<String> = None;
+        let mut last_user_message: Option<LlmMessage> = None;
         let mut function_calls: Vec<LlmMessage> = Vec::new();
         let mut function_call_responses: Vec<LlmMessage> = Vec::new();
 
@@ -366,13 +372,25 @@ impl Prompt {
                     function_call_responses.push(new_message);
                 }
                 SubPrompt::Content(SubPromptType::UserLastMessage, content, _) => {
-                    last_user_message = Some(content.clone());
+                    last_user_message = Some(LlmMessage {
+                        role: Some(SubPromptType::User.to_string()),
+                        content: Some(content.clone()),
+                        name: None,
+                        function_call: None,
+                        functions: None,
+                        images: None,
+                    });
                 }
-                SubPrompt::Omni(_, _, _, _) => {
+                SubPrompt::Omni(prompt_type, _, _, _) => {
                     // Process the current sub-prompt
                     let new_message = sub_prompt.into_chat_completion_request_message();
                     current_length += sub_prompt.count_tokens_with_pregenerated_completion_message(&new_message);
-                    tiktoken_messages.push(new_message);
+
+                    if let SubPromptType::UserLastMessage = prompt_type {
+                        last_user_message = Some(new_message);
+                    } else {
+                        tiktoken_messages.push(new_message);
+                    }
                 }
                 _ => {
                     // Process the current sub-prompt
@@ -388,18 +406,21 @@ impl Prompt {
             let combined_content = format!(
                 "{}\n{}",
                 extra_context_content.trim(),
-                last_user_message.unwrap_or_default()
+                last_user_message
+                    .as_ref()
+                    .and_then(|msg| msg.content.clone())
+                    .unwrap_or_default()
             )
             .trim()
             .to_string();
 
-            let combined_message = LlmMessage {
+            let mut combined_message = LlmMessage {
                 role: Some(SubPromptType::User.to_string()),
                 content: Some(combined_content),
                 name: None,
                 function_call: None,
                 functions: None,
-                images: None,
+                images: last_user_message.and_then(|msg| msg.images),
             };
             current_length += ModelCapabilitiesManager::num_tokens_from_llama3(&[combined_message.clone()]);
             tiktoken_messages.push(combined_message);