Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nico/job message middleware #723

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ use crate::network::agent_payments_manager::external_agent_offerings_manager::Ex
use crate::network::agent_payments_manager::my_agent_offerings_manager::MyAgentOfferingsManager;
use ed25519_dalek::SigningKey;

use image::error;
use shinkai_job_queue_manager::job_queue_manager::{JobForProcessing, JobQueueManager};
use shinkai_message_primitives::schemas::job::{Job, JobLike};
use shinkai_message_primitives::schemas::llm_providers::common_agent_llm_provider::ProviderOrAgent;
use shinkai_message_primitives::schemas::sheet::WorkflowSheetJobData;
use shinkai_message_primitives::schemas::ws_types::WSUpdateHandler;
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{CallbackAction, MessageMetadata};
use shinkai_message_primitives::shinkai_message::shinkai_message_schemas::{
CallbackAction, MessageMetadata, MiddlewareTool,
};
use shinkai_message_primitives::shinkai_utils::job_scope::{
LocalScopeVRKaiEntry, LocalScopeVRPackEntry, ScopeEntry, VectorFSFolderScopeEntry, VectorFSItemScopeEntry,
};
Expand All @@ -37,6 +38,9 @@ use std::time::Instant;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{Mutex, RwLock};

use super::chains::inference_chain_trait::{FunctionCall, InferenceChainContext};
use super::user_message_parser::ParsedUserMessage;

impl JobManager {
/// Processes a job message which will trigger a job step
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -119,6 +123,27 @@ impl JobManager {
return Self::handle_error(&db, Some(user_profile), &job_id, &identity_secret_key, e, ws_manager).await;
}

// 1.5.- Process middleware tools
if let Some(middleware_tools) = &job_message.job_message.middleware_tools {
if let Some(tool_router) = &tool_router {
Self::process_middleware_tools(
tool_router,
&job_message.job_message,
middleware_tools,
user_profile.clone(),
db.clone(),
vector_fs.clone(),
&full_job,
generator.clone(),
ws_manager.clone(),
llm_stopper.clone(),
llm_provider_found.clone(),
)
.await?;
// if we return a string here, we need to update the job message to have the new content
}
}

// 2.- *If* a sheet job is found, processing job message is taken over by this alternate logic
let sheet_job_found = JobManager::process_sheet_job(
db.clone(),
Expand Down Expand Up @@ -908,4 +933,88 @@ impl JobManager {

Ok(files_map)
}

/// Process middleware tools by getting the tool and calling its function
pub async fn process_middleware_tools(
tool_router: &Arc<ToolRouter>,
job_message: &JobMessage,
middleware_tools: &Vec<MiddlewareTool>,
user_profile: ShinkaiName,
db: Arc<RwLock<SqliteManager>>,
vector_fs: Arc<VectorFS>,
full_job: &Job,
generator: RemoteEmbeddingGenerator,
ws_manager: Option<Arc<Mutex<dyn WSUpdateHandler + Send>>>,
llm_stopper: Arc<LLMStopper>,
llm_provider_found: Option<ProviderOrAgent>,
) -> Result<(), LLMProviderError> {
// ^ eddie do we want to return a string after all of the processing?
let llm_provider = llm_provider_found.ok_or(LLMProviderError::LLMProviderNotFound)?;
let model = {
if let ProviderOrAgent::LLMProvider(llm_provider) = llm_provider.clone() {
&llm_provider.model.clone()
} else {
// If it's an agent, we need to get the LLM provider from the agent
let llm_id = llm_provider.get_llm_provider_id();
let llm_provider = db
.read()
.await
.get_llm_provider(llm_id, &user_profile)
.map_err(|e| e.to_string())?
.ok_or(LLMProviderError::LLMProviderNotFound)?;
&llm_provider.model.clone()
}
};
let max_tokens_in_prompt = ModelCapabilitiesManager::get_max_input_tokens(&model);
let parsed_user_message = ParsedUserMessage::new(job_message.content.to_string());

for middleware_tool in middleware_tools {
// Get the tool router key - required for tool lookup
let tool_router_key = middleware_tool.tool_router_key.clone();

// Get the tool by name and handle the Option and Result
let shinkai_tool = tool_router
.get_tool_by_name(&tool_router_key)
.await
.map_err(|e| LLMProviderError::ToolRouterError(format!("Failed to get tool: {}", e)))?
.ok_or_else(|| LLMProviderError::ToolRouterError(format!("Tool not found: {}", tool_router_key)))?;

// Create proper context for the tool execution
let context = InferenceChainContext::new(
db.clone(),
vector_fs.clone(),
full_job.clone(),
parsed_user_message.clone(),
job_message.tool_key.clone(),
None,
HashMap::new(),
llm_provider.clone(),
HashMap::new(),
generator.clone(),
user_profile.clone(),
3,
max_tokens_in_prompt,
ws_manager.clone(),
Some(tool_router.clone()),
None, // No sheet manager needed
None, // No agent payments manager needed
None, // No external agent payments manager needed
llm_stopper.clone(),
);

let function_call = FunctionCall {
name: "".to_string(), // Eddie more logic here
arguments: serde_json::Map::new(), // Eddie more logic here
tool_router_key: None, // Not needed
};

// Call the function with the context
tool_router
.call_function(function_call, &context, &shinkai_tool, user_profile.clone())
.await
.map_err(|e| LLMProviderError::ToolRouterError(format!("Failed to call function: {}", e)))?;
}

Ok(())
}
}
1 change: 1 addition & 0 deletions shinkai-bin/shinkai-node/src/managers/sheet_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ impl SheetManager {
callback: None,
metadata: None,
tool_key: None,
middleware_tools: None,
};

job_messages.push((job_message, job_data));
Expand Down
2 changes: 1 addition & 1 deletion shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ impl ToolRouter {
match shinkai_tool {
ShinkaiTool::Python(_, _) => {
return Ok(ToolCallFunctionResponse {
response: "Deno!".to_string(),
response: "🐍 Python not connected".to_string(),
function_call,
});
}
Expand Down
28 changes: 28 additions & 0 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1485,6 +1485,34 @@ impl Node {
.await;
});
}
NodeCommand::V2ApiAddMessagesGodMode {
bearer,
job_id,
messages,
res,
} => {
let db_clone = self.db.clone();
let node_name_clone = self.node_name.clone();
let identity_manager_clone = self.identity_manager.clone();
let encryption_secret_key_clone = self.encryption_secret_key.clone();
let encryption_public_key_clone = self.encryption_public_key;
let signing_secret_key_clone = self.identity_secret_key.clone();
tokio::spawn(async move {
let _ = Node::v2_add_messages_god_mode(
db_clone,
node_name_clone,
identity_manager_clone,
bearer,
job_id,
messages,
encryption_secret_key_clone,
encryption_public_key_clone,
signing_secret_key_clone,
res,
)
.await;
});
}
NodeCommand::V2ApiGetLastMessagesFromInbox {
bearer,
inbox_name,
Expand Down
160 changes: 159 additions & 1 deletion shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_jobs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use shinkai_message_primitives::{
V2ChatMessage,
},
},
shinkai_utils::job_scope::JobScope,
shinkai_utils::{job_scope::JobScope, shinkai_message_builder::ShinkaiMessageBuilder, signatures::clone_signature_secret_key},
};

use shinkai_sqlite::SqliteManager;
Expand Down Expand Up @@ -1610,4 +1610,162 @@ impl Node {

Ok(())
}

pub async fn v2_add_messages_god_mode(
db: Arc<RwLock<SqliteManager>>,
node_name: ShinkaiName,
identity_manager: Arc<Mutex<IdentityManager>>,
bearer: String,
job_id: String,
messages: Vec<JobMessage>,
node_encryption_sk: EncryptionStaticKey,
node_encryption_pk: EncryptionPublicKey,
node_signing_sk: SigningKey,
res: Sender<Result<String, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Get the main identity from the identity manager
let main_identity = {
let identity_manager = identity_manager.lock().await;
match identity_manager.get_main_identity() {
Some(identity) => identity.clone(),
None => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: "Failed to get main identity".to_string(),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
}
};

// Retrieve the job to get the llm_provider
let llm_provider = match db.read().await.get_job_with_options(&job_id, false, false) {
Ok(job) => job.parent_agent_or_llm_provider_id.clone(),
Err(err) => {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("Job with ID {} not found: {}", job_id, err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

// Process each message alternating between user and AI
for (index, message) in messages.into_iter().enumerate() {
if index % 2 == 0 {
// User message
let sender = match ShinkaiName::new(main_identity.get_full_identity_name()) {
Ok(name) => name,
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to create sender name: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

let recipient = match ShinkaiName::from_node_and_profile_names_and_type_and_name(
node_name.node_name.clone(),
"main".to_string(),
ShinkaiSubidentityType::Agent,
llm_provider.clone(),
) {
Ok(name) => name,
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to create recipient name: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

let shinkai_message = match Self::api_v2_create_shinkai_message(
sender,
recipient,
&serde_json::to_string(&message).unwrap(),
MessageSchemaType::JobMessageSchema,
node_encryption_sk.clone(),
node_signing_sk.clone(),
node_encryption_pk,
Some(job_id.clone()),
) {
Ok(message) => message,
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to create Shinkai message: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

// Add the user message to the job inbox
if let Err(err) = db
.write()
.await
.add_message_to_job_inbox(&job_id, &shinkai_message, None, None)
.await
{
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to add user message to job inbox: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
} else {
// AI message
let identity_secret_key_clone = clone_signature_secret_key(&node_signing_sk);
let ai_shinkai_message = ShinkaiMessageBuilder::job_message_from_llm_provider(
job_id.to_string(),
message.content,
message.files_inbox,
None,
identity_secret_key_clone,
node_name.node_name.clone(),
node_name.node_name.clone(),
)
.expect("Failed to build AI message");

// Add the AI message to the job inbox
if let Err(err) = db
.write()
.await
.add_message_to_job_inbox(&job_id, &ai_shinkai_message, None, None)
.await
{
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to add AI message to job inbox: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
}
}

// Send success response
let _ = res.send(Ok("Messages added successfully".to_string())).await;

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,7 @@ impl Node {
callback: None,
metadata: None,
tool_key: None,
middleware_tools: None,
};

let shinkai_message = match Self::api_v2_create_shinkai_message(
Expand Down
1 change: 1 addition & 0 deletions shinkai-bin/shinkai-node/src/tools/tool_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ pub async fn v2_send_basic_job_message_for_existing_job(
callback: None,
metadata: None,
tool_key: None,
middleware_tools: None,
};

let (res_sender, res_receiver) = async_channel::bounded(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ async fn test_process_job_queue_concurrency() {
callback: None,
metadata: None,
tool_key: None,
middleware_tools: None,
},
ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
None,
Expand Down Expand Up @@ -371,6 +372,7 @@ async fn test_sequential_process_for_same_job_id() {
callback: None,
metadata: None,
tool_key: None,
middleware_tools: None,
},
ShinkaiName::new("@@node1.shinkai/main".to_string()).unwrap(),
None,
Expand Down
Loading