From d4459392ea6594efbfa996dec7235d889b4fec45 Mon Sep 17 00:00:00 2001 From: Eddie Date: Wed, 18 Dec 2024 15:46:14 -0300 Subject: [PATCH] oauth for tool execution in chat/api --- .../shinkai-node/src/managers/tool_router.rs | 97 ++++++++++--------- .../tool_execution/execution_coordinator.rs | 38 +++----- .../tool_execution/execution_deno_dynamic.rs | 24 ++--- .../execution_header_generator.rs | 31 ++++++ .../execution_python_dynamic.rs | 23 ++--- .../src/tools/tool_execution/mod.rs | 1 + 6 files changed, 115 insertions(+), 99 deletions(-) create mode 100644 shinkai-bin/shinkai-node/src/tools/tool_execution/execution_header_generator.rs diff --git a/shinkai-bin/shinkai-node/src/managers/tool_router.rs b/shinkai-bin/shinkai-node/src/managers/tool_router.rs index a814be1dc..e48f5bd20 100644 --- a/shinkai-bin/shinkai-node/src/managers/tool_router.rs +++ b/shinkai-bin/shinkai-node/src/managers/tool_router.rs @@ -6,6 +6,7 @@ use std::time::Instant; use crate::llm_provider::error::LLMProviderError; use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, InferenceChainContextTrait}; use crate::tools::tool_definitions::definition_generation::{generate_tool_definitions, get_rust_tools}; +use crate::tools::tool_execution::execution_header_generator::generate_execution_environment; use crate::utils::environment::fetch_node_environment; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -268,7 +269,10 @@ impl ToolRouter { // Check if ADD_TESTING_NETWORK_ECHO is set if std::env::var("ADD_TESTING_NETWORK_ECHO").unwrap_or_else(|_| "false".to_string()) == "true" { - match self.sqlite_manager.get_tool_by_key("local:::shinkai-tool-echo:::shinkai__echo") { + match self + .sqlite_manager + .get_tool_by_key("local:::shinkai-tool-echo:::shinkai__echo") + { Ok(shinkai_tool) => { if let ShinkaiTool::Deno(mut js_tool, _) = shinkai_tool { js_tool.name = "network__echo".to_string(); @@ -522,24 +526,19 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT: generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false) .await .map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?; - let mut envs = HashMap::new(); - - let bearer = context.db().read_api_v2_key().unwrap_or_default().unwrap_or_default(); - let llm_provider = context.agent().clone().get_id().to_string(); - envs.insert("BEARER".to_string(), bearer); - envs.insert( - "X_SHINKAI_TOOL_ID".to_string(), - format!("jid-{}", context.full_job().job_id()), - ); - envs.insert( - "X_SHINKAI_APP_ID".to_string(), - format!("jid-{}", context.full_job().job_id()), - ); - envs.insert( - "X_SHINKAI_INSTANCE_ID".to_string(), - format!("jid-{}", context.full_job().job_id()), - ); - envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider); + + let envs = generate_execution_environment( + context.db(), + context.agent().clone().get_id().to_string(), + format!("jid-{}", tool_id), + format!("jid-{}", app_id), + shinkai_tool.tool_router_key().clone(), + format!("jid-{}", app_id), + &python_tool.oauth, + ) + .await + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + let result = python_tool .run( envs, @@ -597,27 +596,18 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT: generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false) .await .map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?; - let mut envs = HashMap::new(); - let bearer = context - .db() - .read_api_v2_key() - .unwrap_or_default() - .unwrap_or_default(); - let llm_provider = context.agent().clone().get_id().to_string(); - envs.insert("BEARER".to_string(), bearer); - envs.insert( - "X_SHINKAI_TOOL_ID".to_string(), - format!("jid-{}", context.full_job().job_id()), - ); - envs.insert( - "X_SHINKAI_APP_ID".to_string(), - format!("jid-{}", context.full_job().job_id()), - ); - envs.insert( - "X_SHINKAI_INSTANCE_ID".to_string(), - format!("jid-{}", context.full_job().job_id()), - ); - envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider); + let envs = generate_execution_environment( + context.db(), + context.agent().clone().get_id().to_string(), + format!("jid-{}", app_id), + format!("jid-{}", tool_id), + shinkai_tool.tool_router_key().clone(), + format!("jid-{}", app_id), + &deno_tool.oauth, + ) + .await + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; + let result = deno_tool .run( envs, @@ -945,16 +935,29 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT: generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false) .await .map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?; - let mut envs = HashMap::new(); - envs.insert("BEARER".to_string(), "".to_string()); // TODO (How do we get the bearer?) - envs.insert("X_SHINKAI_TOOL_ID".to_string(), "".to_string()); // TODO Pass data from the API - envs.insert("X_SHINKAI_APP_ID".to_string(), "".to_string()); // TODO Pass data from the API - envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API - envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), "".to_string()); // TODO Pass data from the API + + let oauth = match shinkai_tool.clone() { + ShinkaiTool::Deno(deno_tool, _) => deno_tool.oauth.clone(), + ShinkaiTool::Python(python_tool, _) => python_tool.oauth.clone(), + _ => return Err(LLMProviderError::FunctionNotFound(js_tool_name.to_string())), + }; + + let env = generate_execution_environment( + self.sqlite_manager.clone(), + "".to_string(), + format!("xid-{}", app_id), + format!("xid-{}", tool_id), + shinkai_tool.tool_router_key().clone(), + // TODO: Pass data from the API + "".to_string(), + &oauth, + ) + .await + .map_err(|e| ToolError::ExecutionError(e.to_string()))?; let result = js_tool .run( - HashMap::new(), + env, node_env.api_listen_address.ip().to_string(), node_env.api_listen_address.port(), support_files, diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs index c1f61cbc3..2e2032648 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_coordinator.rs @@ -2,6 +2,7 @@ use crate::llm_provider::job_manager::JobManager; use crate::tools::tool_definitions::definition_generation::generate_tool_definitions; use crate::tools::tool_execution::execution_custom::execute_custom_tool; use crate::tools::tool_execution::execution_deno_dynamic::{check_deno_tool, execute_deno_tool}; +use crate::tools::tool_execution::execution_header_generator::generate_execution_environment; use crate::tools::tool_execution::execution_python_dynamic::execute_python_tool; use crate::utils::environment::fetch_node_environment; @@ -99,10 +100,10 @@ pub async fn handle_oauth( let oauth_login_url = format!( "{}?client_id={}&redirect_uri={}&scope={}&state={}", o.authorization_url, - o.client_id, + urlencoding::encode(&o.client_id), urlencoding::encode(&o.redirect_url), - o.scopes.join(" "), - uuid + urlencoding::encode(&o.scopes.join(" ")), + urlencoding::encode(&uuid) ); return Err(ToolError::OAuthError(oauth_login_url)); @@ -157,26 +158,19 @@ pub async fn execute_tool_cmd( .get_tool_by_key(&tool_router_key) .map_err(|e| ToolError::ExecutionError(format!("Failed to get tool: {}", e)))?; - let mut envs = HashMap::new(); - envs.insert("BEARER".to_string(), bearer); - envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone()); - envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone()); - envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API - envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider.clone()); - match tool { ShinkaiTool::Python(python_tool, _) => { - let oauth = handle_oauth( - &python_tool.oauth, - &db, + let env = generate_execution_environment( + db.clone(), + llm_provider.clone(), app_id.clone(), tool_id.clone(), tool_router_key.clone(), + "".to_string(), // TODO Pass data from the API + &python_tool.oauth.clone(), ) .await?; - envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string()); - let node_env = fetch_node_environment(); let node_storage_path = node_env .node_storage_path @@ -192,7 +186,7 @@ pub async fn execute_tool_cmd( .map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?; python_tool .run( - envs, + env, node_env.api_listen_address.ip().to_string(), node_env.api_listen_address.port(), support_files, @@ -208,17 +202,17 @@ pub async fn execute_tool_cmd( .map(|result| json!(result.data)) } ShinkaiTool::Deno(deno_tool, _) => { - let oauth = handle_oauth( - &deno_tool.oauth, - &db, + let env = generate_execution_environment( + db.clone(), + llm_provider.clone(), app_id.clone(), tool_id.clone(), tool_router_key.clone(), + "".to_string(), // TODO Pass data from the API + &deno_tool.oauth.clone(), ) .await?; - envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string()); - let node_env = fetch_node_environment(); let node_storage_path = node_env .node_storage_path @@ -234,7 +228,7 @@ pub async fn execute_tool_cmd( .map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?; deno_tool .run( - envs, + env, node_env.api_listen_address.ip().to_string(), node_env.api_listen_address.port(), support_files, diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs index 8f0b2de0a..792be68cc 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_deno_dynamic.rs @@ -9,14 +9,13 @@ use shinkai_tools_primitives::tools::parameters::Parameters; use shinkai_tools_primitives::tools::tool_config::{OAuth, ToolConfig}; use shinkai_tools_primitives::tools::tool_output_arg::ToolOutputArg; -use super::execution_coordinator::handle_oauth; +use super::execution_header_generator::generate_execution_environment; use crate::utils::environment::fetch_node_environment; use shinkai_sqlite::SqliteManager; use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; pub async fn execute_deno_tool( - bearer: String, + _bearer: String, db: Arc, node_name: ShinkaiName, parameters: Map, @@ -50,22 +49,17 @@ pub async fn execute_deno_tool( assets: None, }; - let mut envs = HashMap::new(); - envs.insert("BEARER".to_string(), bearer); - envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone()); - envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone()); - envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API - envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider.clone()); - - let oauth = handle_oauth( - &oauth.clone(), - &db, + let env = generate_execution_environment( + db.clone(), + llm_provider.clone(), app_id.clone(), tool_id.clone(), + // TODO: Update this value for runtime tool execution "code-execution".to_string(), + "".to_string(), + &oauth.clone(), ) .await?; - envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string()); let node_env = fetch_node_environment(); let node_storage_path = node_env @@ -95,7 +89,7 @@ pub async fn execute_deno_tool( } match tool.run_on_demand( - envs, + env, node_env.api_listen_address.ip().to_string(), node_env.api_listen_address.port(), support_files, diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_header_generator.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_header_generator.rs new file mode 100644 index 000000000..fdf21f9de --- /dev/null +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_header_generator.rs @@ -0,0 +1,31 @@ +use std::{collections::HashMap, sync::Arc}; + +use shinkai_sqlite::SqliteManager; +use shinkai_tools_primitives::tools::{error::ToolError, tool_config::OAuth}; + +use super::execution_coordinator::handle_oauth; + +pub async fn generate_execution_environment( + db: Arc, + llm_provider: String, + app_id: String, + tool_id: String, + tool_router_key: String, + instance_id: String, + oauth: &Option>, +) -> Result, ToolError> { + let mut envs = HashMap::new(); + + let bearer = db.read_api_v2_key().unwrap_or_default().unwrap_or_default(); + envs.insert("BEARER".to_string(), bearer); + envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone()); + envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone()); + envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), instance_id.clone()); + envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider); + + let oauth = handle_oauth(oauth, &db, app_id.clone(), tool_id.clone(), tool_router_key.clone()).await?; + + envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string()); + + Ok(envs) +} diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs index ba2427daa..fadc9ae86 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/execution_python_dynamic.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, path::PathBuf}; -use super::execution_coordinator::handle_oauth; +use super::execution_header_generator::generate_execution_environment; use crate::utils::environment::fetch_node_environment; use serde_json::{Map, Value}; use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName; @@ -14,10 +14,9 @@ use shinkai_tools_primitives::tools::{ tool_output_arg::ToolOutputArg, }; use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; pub async fn execute_python_tool( - bearer: String, + _bearer: String, db: Arc, node_name: ShinkaiName, parameters: Map, @@ -51,22 +50,16 @@ pub async fn execute_python_tool( assets: None, }; - let mut envs = HashMap::new(); - envs.insert("BEARER".to_string(), bearer); - envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone()); - envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone()); - envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API - envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider.clone()); - - let oauth = handle_oauth( - &oauth.clone(), - &db, + let env = generate_execution_environment( + db.clone(), + llm_provider.clone(), app_id.clone(), tool_id.clone(), "code-execution".to_string(), + "".to_string(), + &oauth.clone(), ) .await?; - envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string()); let node_env = fetch_node_environment(); let node_storage_path = node_env @@ -96,7 +89,7 @@ pub async fn execute_python_tool( } match tool.run_on_demand( - envs, + env, node_env.api_listen_address.ip().to_string(), node_env.api_listen_address.port(), support_files, diff --git a/shinkai-bin/shinkai-node/src/tools/tool_execution/mod.rs b/shinkai-bin/shinkai-node/src/tools/tool_execution/mod.rs index 1e85f8b12..6b12b9fa9 100644 --- a/shinkai-bin/shinkai-node/src/tools/tool_execution/mod.rs +++ b/shinkai-bin/shinkai-node/src/tools/tool_execution/mod.rs @@ -1,4 +1,5 @@ pub mod execution_coordinator; pub mod execution_custom; pub mod execution_deno_dynamic; +pub mod execution_header_generator; pub mod execution_python_dynamic;