Skip to content

Commit

Permalink
oauth for tool execution in chat/api
Browse files Browse the repository at this point in the history
  • Loading branch information
acedward committed Dec 18, 2024
1 parent bf86fa9 commit d445939
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 99 deletions.
97 changes: 50 additions & 47 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Instant;
use crate::llm_provider::error::LLMProviderError;
use crate::llm_provider::execution::chains::inference_chain_trait::{FunctionCall, InferenceChainContextTrait};
use crate::tools::tool_definitions::definition_generation::{generate_tool_definitions, get_rust_tools};
use crate::tools::tool_execution::execution_header_generator::generate_execution_environment;
use crate::utils::environment::fetch_node_environment;
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -268,7 +269,10 @@ impl ToolRouter {

// Check if ADD_TESTING_NETWORK_ECHO is set
if std::env::var("ADD_TESTING_NETWORK_ECHO").unwrap_or_else(|_| "false".to_string()) == "true" {
match self.sqlite_manager.get_tool_by_key("local:::shinkai-tool-echo:::shinkai__echo") {
match self
.sqlite_manager
.get_tool_by_key("local:::shinkai-tool-echo:::shinkai__echo")
{
Ok(shinkai_tool) => {
if let ShinkaiTool::Deno(mut js_tool, _) = shinkai_tool {
js_tool.name = "network__echo".to_string();
Expand Down Expand Up @@ -522,24 +526,19 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false)
.await
.map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?;
let mut envs = HashMap::new();

let bearer = context.db().read_api_v2_key().unwrap_or_default().unwrap_or_default();
let llm_provider = context.agent().clone().get_id().to_string();
envs.insert("BEARER".to_string(), bearer);
envs.insert(
"X_SHINKAI_TOOL_ID".to_string(),
format!("jid-{}", context.full_job().job_id()),
);
envs.insert(
"X_SHINKAI_APP_ID".to_string(),
format!("jid-{}", context.full_job().job_id()),
);
envs.insert(
"X_SHINKAI_INSTANCE_ID".to_string(),
format!("jid-{}", context.full_job().job_id()),
);
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider);

let envs = generate_execution_environment(
context.db(),
context.agent().clone().get_id().to_string(),
format!("jid-{}", tool_id),
format!("jid-{}", app_id),
shinkai_tool.tool_router_key().clone(),
format!("jid-{}", app_id),
&python_tool.oauth,
)
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;

let result = python_tool
.run(
envs,
Expand Down Expand Up @@ -597,27 +596,18 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false)
.await
.map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?;
let mut envs = HashMap::new();
let bearer = context
.db()
.read_api_v2_key()
.unwrap_or_default()
.unwrap_or_default();
let llm_provider = context.agent().clone().get_id().to_string();
envs.insert("BEARER".to_string(), bearer);
envs.insert(
"X_SHINKAI_TOOL_ID".to_string(),
format!("jid-{}", context.full_job().job_id()),
);
envs.insert(
"X_SHINKAI_APP_ID".to_string(),
format!("jid-{}", context.full_job().job_id()),
);
envs.insert(
"X_SHINKAI_INSTANCE_ID".to_string(),
format!("jid-{}", context.full_job().job_id()),
);
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider);
let envs = generate_execution_environment(
context.db(),
context.agent().clone().get_id().to_string(),
format!("jid-{}", app_id),
format!("jid-{}", tool_id),
shinkai_tool.tool_router_key().clone(),
format!("jid-{}", app_id),
&deno_tool.oauth,
)
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;

let result = deno_tool
.run(
envs,
Expand Down Expand Up @@ -945,16 +935,29 @@ async def run(c: CONFIG, p: INPUTS) -> OUTPUT:
generate_tool_definitions(tools, CodeLanguage::Typescript, self.sqlite_manager.clone(), false)
.await
.map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?;
let mut envs = HashMap::new();
envs.insert("BEARER".to_string(), "".to_string()); // TODO (How do we get the bearer?)
envs.insert("X_SHINKAI_TOOL_ID".to_string(), "".to_string()); // TODO Pass data from the API
envs.insert("X_SHINKAI_APP_ID".to_string(), "".to_string()); // TODO Pass data from the API
envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), "".to_string()); // TODO Pass data from the API

let oauth = match shinkai_tool.clone() {
ShinkaiTool::Deno(deno_tool, _) => deno_tool.oauth.clone(),
ShinkaiTool::Python(python_tool, _) => python_tool.oauth.clone(),
_ => return Err(LLMProviderError::FunctionNotFound(js_tool_name.to_string())),
};

let env = generate_execution_environment(
self.sqlite_manager.clone(),
"".to_string(),
format!("xid-{}", app_id),
format!("xid-{}", tool_id),
shinkai_tool.tool_router_key().clone(),
// TODO: Pass data from the API
"".to_string(),
&oauth,
)
.await
.map_err(|e| ToolError::ExecutionError(e.to_string()))?;

let result = js_tool
.run(
HashMap::new(),
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::llm_provider::job_manager::JobManager;
use crate::tools::tool_definitions::definition_generation::generate_tool_definitions;
use crate::tools::tool_execution::execution_custom::execute_custom_tool;
use crate::tools::tool_execution::execution_deno_dynamic::{check_deno_tool, execute_deno_tool};
use crate::tools::tool_execution::execution_header_generator::generate_execution_environment;
use crate::tools::tool_execution::execution_python_dynamic::execute_python_tool;
use crate::utils::environment::fetch_node_environment;

Expand Down Expand Up @@ -99,10 +100,10 @@ pub async fn handle_oauth(
let oauth_login_url = format!(
"{}?client_id={}&redirect_uri={}&scope={}&state={}",
o.authorization_url,
o.client_id,
urlencoding::encode(&o.client_id),
urlencoding::encode(&o.redirect_url),
o.scopes.join(" "),
uuid
urlencoding::encode(&o.scopes.join(" ")),
urlencoding::encode(&uuid)
);

return Err(ToolError::OAuthError(oauth_login_url));
Expand Down Expand Up @@ -157,26 +158,19 @@ pub async fn execute_tool_cmd(
.get_tool_by_key(&tool_router_key)
.map_err(|e| ToolError::ExecutionError(format!("Failed to get tool: {}", e)))?;

let mut envs = HashMap::new();
envs.insert("BEARER".to_string(), bearer);
envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone());
envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone());
envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider.clone());

match tool {
ShinkaiTool::Python(python_tool, _) => {
let oauth = handle_oauth(
&python_tool.oauth,
&db,
let env = generate_execution_environment(
db.clone(),
llm_provider.clone(),
app_id.clone(),
tool_id.clone(),
tool_router_key.clone(),
"".to_string(), // TODO Pass data from the API
&python_tool.oauth.clone(),
)
.await?;

envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string());

let node_env = fetch_node_environment();
let node_storage_path = node_env
.node_storage_path
Expand All @@ -192,7 +186,7 @@ pub async fn execute_tool_cmd(
.map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?;
python_tool
.run(
envs,
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
Expand All @@ -208,17 +202,17 @@ pub async fn execute_tool_cmd(
.map(|result| json!(result.data))
}
ShinkaiTool::Deno(deno_tool, _) => {
let oauth = handle_oauth(
&deno_tool.oauth,
&db,
let env = generate_execution_environment(
db.clone(),
llm_provider.clone(),
app_id.clone(),
tool_id.clone(),
tool_router_key.clone(),
"".to_string(), // TODO Pass data from the API
&deno_tool.oauth.clone(),
)
.await?;

envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string());

let node_env = fetch_node_environment();
let node_storage_path = node_env
.node_storage_path
Expand All @@ -234,7 +228,7 @@ pub async fn execute_tool_cmd(
.map_err(|_| ToolError::ExecutionError("Failed to generate tool definitions".to_string()))?;
deno_tool
.run(
envs,
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ use shinkai_tools_primitives::tools::parameters::Parameters;
use shinkai_tools_primitives::tools::tool_config::{OAuth, ToolConfig};
use shinkai_tools_primitives::tools::tool_output_arg::ToolOutputArg;

use super::execution_coordinator::handle_oauth;
use super::execution_header_generator::generate_execution_environment;
use crate::utils::environment::fetch_node_environment;
use shinkai_sqlite::SqliteManager;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};

pub async fn execute_deno_tool(
bearer: String,
_bearer: String,
db: Arc<SqliteManager>,
node_name: ShinkaiName,
parameters: Map<String, Value>,
Expand Down Expand Up @@ -50,22 +49,17 @@ pub async fn execute_deno_tool(
assets: None,
};

let mut envs = HashMap::new();
envs.insert("BEARER".to_string(), bearer);
envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone());
envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone());
envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider.clone());

let oauth = handle_oauth(
&oauth.clone(),
&db,
let env = generate_execution_environment(
db.clone(),
llm_provider.clone(),
app_id.clone(),
tool_id.clone(),
// TODO: Update this value for runtime tool execution
"code-execution".to_string(),
"".to_string(),
&oauth.clone(),
)
.await?;
envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string());

let node_env = fetch_node_environment();
let node_storage_path = node_env
Expand Down Expand Up @@ -95,7 +89,7 @@ pub async fn execute_deno_tool(
}

match tool.run_on_demand(
envs,
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::{collections::HashMap, sync::Arc};

use shinkai_sqlite::SqliteManager;
use shinkai_tools_primitives::tools::{error::ToolError, tool_config::OAuth};

use super::execution_coordinator::handle_oauth;

pub async fn generate_execution_environment(
db: Arc<SqliteManager>,
llm_provider: String,
app_id: String,
tool_id: String,
tool_router_key: String,
instance_id: String,
oauth: &Option<Vec<OAuth>>,
) -> Result<HashMap<String, String>, ToolError> {
let mut envs = HashMap::new();

let bearer = db.read_api_v2_key().unwrap_or_default().unwrap_or_default();
envs.insert("BEARER".to_string(), bearer);
envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone());
envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone());
envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), instance_id.clone());
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider);

let oauth = handle_oauth(oauth, &db, app_id.clone(), tool_id.clone(), tool_router_key.clone()).await?;

envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string());

Ok(envs)
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{collections::HashMap, path::PathBuf};

use super::execution_coordinator::handle_oauth;
use super::execution_header_generator::generate_execution_environment;
use crate::utils::environment::fetch_node_environment;
use serde_json::{Map, Value};
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
Expand All @@ -14,10 +14,9 @@ use shinkai_tools_primitives::tools::{
tool_output_arg::ToolOutputArg,
};
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};

pub async fn execute_python_tool(
bearer: String,
_bearer: String,
db: Arc<SqliteManager>,
node_name: ShinkaiName,
parameters: Map<String, Value>,
Expand Down Expand Up @@ -51,22 +50,16 @@ pub async fn execute_python_tool(
assets: None,
};

let mut envs = HashMap::new();
envs.insert("BEARER".to_string(), bearer);
envs.insert("X_SHINKAI_TOOL_ID".to_string(), tool_id.clone());
envs.insert("X_SHINKAI_APP_ID".to_string(), app_id.clone());
envs.insert("X_SHINKAI_INSTANCE_ID".to_string(), "".to_string()); // TODO Pass data from the API
envs.insert("X_SHINKAI_LLM_PROVIDER".to_string(), llm_provider.clone());

let oauth = handle_oauth(
&oauth.clone(),
&db,
let env = generate_execution_environment(
db.clone(),
llm_provider.clone(),
app_id.clone(),
tool_id.clone(),
"code-execution".to_string(),
"".to_string(),
&oauth.clone(),
)
.await?;
envs.insert("SHINKAI_OAUTH".to_string(), oauth.to_string());

let node_env = fetch_node_environment();
let node_storage_path = node_env
Expand Down Expand Up @@ -96,7 +89,7 @@ pub async fn execute_python_tool(
}

match tool.run_on_demand(
envs,
env,
node_env.api_listen_address.ip().to_string(),
node_env.api_listen_address.port(),
support_files,
Expand Down
1 change: 1 addition & 0 deletions shinkai-bin/shinkai-node/src/tools/tool_execution/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod execution_coordinator;
pub mod execution_custom;
pub mod execution_deno_dynamic;
pub mod execution_header_generator;
pub mod execution_python_dynamic;

0 comments on commit d445939

Please sign in to comment.