From 8ade0d02673644a27ffb342c37971318b6ab62c4 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Mon, 23 Dec 2024 21:07:32 -0300 Subject: [PATCH 1/5] Export / Import Agents Endpoints --- .../src/network/handle_commands_list.rs | 12 ++ .../src/network/v2_api/api_v2_commands.rs | 204 ++++++++++++++++++ .../src/api_v2/api_v2_handlers_general.rs | 116 ++++++++++ .../shinkai-http-api/src/node_commands.rs | 10 + 4 files changed, 342 insertions(+) diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 1f64d8296..82dd6b2c7 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -232,6 +232,18 @@ impl Node { .await; }); } + NodeCommand::V2ApiExportAgent { bearer, agent_id, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_export_agent(db_clone, bearer, agent_id, res).await; + }); + } + NodeCommand::V2ApiImportAgent { bearer, url, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_import_agent(db_clone, bearer, url, res).await; + }); + } NodeCommand::AvailableLLMProviders { full_profile_name, res } => { let db_clone = self.db.clone(); let node_name_clone = self.node_name.clone(); diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs index dfb8679bd..94d6377de 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs @@ -1,4 +1,11 @@ use std::{env, sync::Arc}; +use std::path::Path; +use std::fs::File; +use std::io::{Read, Write}; + +use serde_json::{json, Value}; +use tokio::fs; +use zip::{write::FileOptions, ZipWriter}; use async_channel::Sender; use ed25519_dalek::{SigningKey, VerifyingKey}; @@ -1591,4 +1598,201 @@ impl Node { Ok(()) } + + pub async fn v2_api_export_agent( + db: Arc, + bearer: String, + agent_id: String, + res: Sender, APIError>>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve the agent from the database + match db.get_agent(&agent_id) { + Ok(Some(agent)) => { + // Serialize the agent to JSON bytes + let agent_bytes = match serde_json::to_vec(&agent) { + Ok(bytes) => bytes, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to serialize agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Create a temporary zip file + let name = format!("{}.zip", agent.agent_id.replace(':', "_")); + let path = Path::new(&name); + let file = match File::create(&path) { + Ok(file) => file, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + let mut zip = ZipWriter::new(file); + + // Add the agent JSON to the zip file + if let Err(err) = zip.start_file::<_, ()>("__agent.json", FileOptions::default()) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create agent file in zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + if let Err(err) = zip.write_all(&agent_bytes) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to write agent data to zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Finalize the zip file + if let Err(err) = zip.finish() { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to finalize zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Read the zip file into memory + match fs::read(&path).await { + Ok(file_bytes) => { + // Clean up the temporary file + if let Err(err) = fs::remove_file(&path).await { + eprintln!("Warning: Failed to remove temporary file: {}", err); + } + let _ = res.send(Ok(file_bytes)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to read zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + } + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: format!("Agent not found: {}", agent_id), + }; + let _ = res.send(Err(api_error)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to retrieve agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_import_agent( + db: Arc, + bearer: String, + url: String, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Download the zip file + let response = match reqwest::get(&url).await { + Ok(response) => response, + Err(err) => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Download Failed".to_string(), + message: format!("Failed to download agent from URL: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Get the bytes from the response + let bytes = match response.bytes().await { + Ok(bytes) => bytes, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Download Failed".to_string(), + message: format!("Failed to read response bytes: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + let bytes = bytes.to_vec(); + let buffer = tokio::task::spawn_blocking(move || { + let cursor = std::io::Cursor::new(bytes); + let mut archive = zip::ZipArchive::new(cursor)?; + let mut buffer = Vec::new(); + let mut agent_file = archive.by_name("__agent.json")?; + agent_file.read_to_end(&mut buffer)?; + Ok::<_, Box>(buffer) + }) + .await + .map_err(|e| NodeError::from(e.to_string()))? + .map_err(|e| NodeError::from(e.to_string()))?; + + // Parse the JSON into an Agent + let agent: Agent = serde_json::from_slice(&buffer).map_err(|e| NodeError::from(e.to_string()))?; + + // Save the agent to the database + match db.add_agent(agent.clone(), &agent.full_identity_name) { + Ok(_) => { + let response = json!({ + "status": "success", + "message": "Agent imported successfully", + "agent_id": agent.agent_id, + "agent": agent + }); + let _ = res.send(Ok(response)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Database Error".to_string(), + message: format!("Failed to save agent to database: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } } diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs index 801e80de7..d15655225 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs @@ -178,6 +178,20 @@ pub fn general_routes( .and(warp::header::("authorization")) .and_then(get_all_agents_handler); + let export_agent_route = warp::path("export_agent") + .and(warp::get()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::query::>()) + .and_then(export_agent_handler); + + let import_agent_route = warp::path("import_agent") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(import_agent_handler); + let test_llm_provider_route = warp::path("test_llm_provider") .and(warp::post()) .and(with_sender(node_commands_sender.clone())) @@ -207,6 +221,8 @@ pub fn general_routes( .or(update_agent_route) .or(get_agent_route) .or(get_all_agents_route) + .or(export_agent_route) + .or(import_agent_route) .or(test_llm_provider_route) } @@ -934,6 +950,104 @@ pub async fn get_all_agents_handler( } } +#[utoipa::path( + get, + path = "/v2/export_agent", + params( + ("agent_id" = String, Query, description = "Agent identifier") + ), + responses( + (status = 200, description = "Exported agent", body = Vec), + (status = 400, description = "Invalid agent identifier", body = APIError), + ) +)] +pub async fn export_agent_handler( + sender: Sender, + authorization: String, + query_params: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + + let agent_id = query_params + .get("agent_id") + .ok_or_else(|| { + warp::reject::custom(APIError { + code: 400, + error: "Invalid agent identifier".to_string(), + message: "Agent identifier is required".to_string(), + }) + })? + .to_string(); + + let (res_sender, res_receiver) = async_channel::bounded(1); + + sender + .send(NodeCommand::V2ApiExportAgent { + bearer, + agent_id, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(file_bytes) => { + // Return the raw bytes with appropriate headers + Ok(warp::reply::with_header( + warp::reply::with_status(file_bytes, StatusCode::OK), + "Content-Type", + "application/octet-stream", + )) + } + Err(error) => Ok(warp::reply::with_header( + warp::reply::with_status( + error.message.as_bytes().to_vec(), + StatusCode::from_u16(error.code).unwrap() + ), + "Content-Type", + "text/plain", + )) + } +} + +#[utoipa::path( + post, + path = "/v2/import_agent", + request_body = HashMap, + responses( + (status = 200, description = "Successfully imported agent", body = Value), + (status = 400, description = "Invalid URL or agent data", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn import_agent_handler( + sender: Sender, + authorization: String, + payload: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let url = payload.get("url").cloned().unwrap_or_default(); + + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiImportAgent { + bearer, + url, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => Ok(warp::reply::json(&response)), + Err(error) => Err(warp::reject::custom(error)), + } +} + #[utoipa::path( post, path = "/v2/test_llm_provider", @@ -990,6 +1104,8 @@ pub async fn test_llm_provider_handler( add_agent_handler, remove_agent_handler, update_agent_handler, + import_agent_handler, + export_agent_handler, get_agent_handler, get_all_agents_handler, test_llm_provider_handler, diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index a6d8bdb1a..f6798b042 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -255,6 +255,16 @@ pub enum NodeCommand { msg: ShinkaiMessage, res: Sender>, }, + V2ApiImportAgent { + bearer: String, + url: String, + res: Sender>, + }, + V2ApiExportAgent { + bearer: String, + agent_id: String, + res: Sender, APIError>>, + }, AvailableLLMProviders { full_profile_name: String, res: Sender, String>>, From 42d231259df86e45daa25a3a957b11e900867108 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 24 Dec 2024 10:39:45 -0300 Subject: [PATCH 2/5] Fix agent import, refactor zip download function. --- .../src/network/node_shareable_logic.rs | 81 +++++++++++++++++++ .../src/network/v2_api/api_v2_commands.rs | 42 ++-------- .../network/v2_api/api_v2_commands_tools.rs | 73 ++--------------- 3 files changed, 96 insertions(+), 100 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs b/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs index 8cedb4d3b..12bbbd7f4 100644 --- a/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs +++ b/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs @@ -1,7 +1,9 @@ use shinkai_http_api::node_api_router::APIError; use shinkai_message_primitives::schemas::identity::{Identity, StandardIdentityType}; use std::sync::Arc; +use std::io::Read; use tokio::sync::Mutex; +use tokio_util::bytes::Bytes; use crate::managers::identity_manager::IdentityManager; use crate::managers::identity_manager::IdentityManagerTrait; @@ -186,3 +188,82 @@ pub async fn validate_message_main_logic( Ok((msg, sender_subidentity)) } + +pub struct ZipFileContents { + pub buffer: Vec, + pub archive: zip::ZipArchive>, +} + +pub async fn download_zip_file(url: String, file_name: String) -> Result { + // Download the zip file + let response = match reqwest::get(&url).await { + Ok(response) => response, + Err(err) => { + return Err(APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Download Failed".to_string(), + message: format!("Failed to download agent from URL: {}", err), + }); + } + }; + + // Get the bytes from the response + let bytes = match response.bytes().await { + Ok(bytes) => bytes, + Err(err) => { + return Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Download Failed".to_string(), + message: format!("Failed to read response bytes: {}", err), + }); + } + }; + + // Create a cursor from the bytes + let cursor = std::io::Cursor::new(bytes.clone()); + + // Create a zip archive from the cursor + let mut archive = match zip::ZipArchive::new(cursor) { + Ok(archive) => archive, + Err(err) => { + return Err(APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid Zip File".to_string(), + message: format!("Failed to read zip archive: {}", err), + }); + } + }; + + // Extract and parse file + let mut buffer = Vec::new(); + { + let mut file = match archive.by_name(&file_name) { + Ok(file) => file, + Err(_) => { + return Err(APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid Zip File".to_string(), + message: format!("Archive does not contain {}", file_name), + }); + } + }; + + // Read the file contents into a buffer + if let Err(err) = file.read_to_end(&mut buffer) { + return Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Read Error".to_string(), + message: format!("Failed to read file contents: {}", err), + }); + } + } + + // Create a new cursor and archive for returning + let return_cursor = std::io::Cursor::new(bytes); + let return_archive = zip::ZipArchive::new(return_cursor).unwrap(); + + Ok(ZipFileContents { + buffer, + archive: return_archive, + }) +} diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs index 94d6377de..0867f3a43 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs @@ -1,7 +1,7 @@ use std::{env, sync::Arc}; use std::path::Path; use std::fs::File; -use std::io::{Read, Write}; +use std::io::Write; use serde_json::{json, Value}; use tokio::fs; @@ -50,7 +50,7 @@ use x25519_dalek::PublicKey as EncryptionPublicKey; use crate::{ llm_provider::{job_manager::JobManager, llm_stopper::LLMStopper}, managers::{identity_manager::IdentityManagerTrait, IdentityManager}, - network::{node_error::NodeError, Node}, + network::{node_error::NodeError, Node, node_shareable_logic::download_zip_file}, tools::tool_generation, utils::update_global_identity::update_global_identity_name, }; @@ -1728,49 +1728,21 @@ impl Node { return Ok(()); } - // Download the zip file - let response = match reqwest::get(&url).await { - Ok(response) => response, + let zip_contents = match download_zip_file(url, "__agent.json".to_string()).await { + Ok(contents) => contents, Err(err) => { let api_error = APIError { code: StatusCode::BAD_REQUEST.as_u16(), - error: "Download Failed".to_string(), - message: format!("Failed to download agent from URL: {}", err), + error: "Invalid Agent Zip".to_string(), + message: format!("Failed to extract agent.json: {:?}", err), }; let _ = res.send(Err(api_error)).await; return Ok(()); } }; - // Get the bytes from the response - let bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(err) => { - let api_error = APIError { - code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), - error: "Download Failed".to_string(), - message: format!("Failed to read response bytes: {}", err), - }; - let _ = res.send(Err(api_error)).await; - return Ok(()); - } - }; - - let bytes = bytes.to_vec(); - let buffer = tokio::task::spawn_blocking(move || { - let cursor = std::io::Cursor::new(bytes); - let mut archive = zip::ZipArchive::new(cursor)?; - let mut buffer = Vec::new(); - let mut agent_file = archive.by_name("__agent.json")?; - agent_file.read_to_end(&mut buffer)?; - Ok::<_, Box>(buffer) - }) - .await - .map_err(|e| NodeError::from(e.to_string()))? - .map_err(|e| NodeError::from(e.to_string()))?; - // Parse the JSON into an Agent - let agent: Agent = serde_json::from_slice(&buffer).map_err(|e| NodeError::from(e.to_string()))?; + let agent: Agent = serde_json::from_slice(&zip_contents.buffer).map_err(|e| NodeError::from(e.to_string()))?; // Save the agent to the database match db.add_agent(agent.clone(), &agent.full_identity_name) { diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index f7cbd698f..381edb925 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -1,7 +1,7 @@ use crate::{ llm_provider::job_manager::JobManager, managers::IdentityManager, - network::{node_error::NodeError, Node}, + network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, tools::{ tool_definitions::definition_generation::{generate_tool_definitions, get_all_deno_tools}, tool_execution::execution_coordinator::{execute_code, execute_tool_cmd}, @@ -10,7 +10,6 @@ use crate::{ }, utils::environment::NodeEnvironment, }; -use std::io::Read; use async_channel::Sender; use ed25519_dalek::SigningKey; @@ -43,7 +42,7 @@ use shinkai_tools_primitives::tools::{ tool_playground::ToolPlayground, }; use shinkai_vector_fs::vector_fs::vector_fs::VectorFS; -use std::{fs::File, io::Write, path::Path, sync::Arc, time::Instant}; +use std::{fs::File, io::Write, io::Read, path::Path, sync::Arc, time::Instant}; use tokio::sync::Mutex; use zip::{write::FileOptions, ZipWriter}; @@ -1613,71 +1612,15 @@ impl Node { node_env: NodeEnvironment, url: String, ) -> Result { - // Download the zip file - let response = match reqwest::get(&url).await { - Ok(response) => response, + let mut zip_contents = match download_zip_file(url, "__tool.json".to_string()).await { + Ok(contents) => contents, Err(err) => { - return Err(APIError { - code: StatusCode::BAD_REQUEST.as_u16(), - error: "Download Failed".to_string(), - message: format!("Failed to download tool from URL: {}", err), - }); - } - }; - - // Get the bytes from the response - let bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(err) => { - return Err(APIError { - code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), - error: "Download Failed".to_string(), - message: format!("Failed to read response bytes: {}", err), - }); - } - }; - - // Create a cursor from the bytes - let cursor = std::io::Cursor::new(bytes); - - // Create a zip archive from the cursor - let mut archive = match zip::ZipArchive::new(cursor) { - Ok(archive) => archive, - Err(err) => { - return Err(APIError { - code: StatusCode::BAD_REQUEST.as_u16(), - error: "Invalid Zip File".to_string(), - message: format!("Failed to read zip archive: {}", err), - }); + return Err(err); } }; - // Extract and parse tool.json - let mut buffer = Vec::new(); - { - let mut tool_file = match archive.by_name("__tool.json") { - Ok(file) => file, - Err(_) => { - return Err(APIError { - code: StatusCode::BAD_REQUEST.as_u16(), - error: "Invalid Tool Archive".to_string(), - message: "Archive does not contain tool.json".to_string(), - }); - } - }; - - // Read the tool file contents into a buffer - if let Err(err) = tool_file.read_to_end(&mut buffer) { - return Err(APIError { - code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), - error: "Read Error".to_string(), - message: format!("Failed to read tool.json contents: {}", err), - }); - } - } // `tool_file` goes out of scope here - // Parse the JSON into a ShinkaiTool - let tool: ShinkaiTool = match serde_json::from_slice(&buffer) { + let tool: ShinkaiTool = match serde_json::from_slice(&zip_contents.buffer) { Ok(tool) => tool, Err(err) => { return Err(APIError { @@ -1692,7 +1635,7 @@ impl Node { let mut db_write = db; match db_write.add_tool(tool).await { Ok(tool) => { - let archive_clone = archive.clone(); + let archive_clone = zip_contents.archive.clone(); let files = archive_clone.file_names(); for file in files { println!("File: {:?}", file); @@ -1701,7 +1644,7 @@ impl Node { } let mut buffer = Vec::new(); { - let file = archive.by_name(file); + let file = zip_contents.archive.by_name(file); let mut tool_file = match file { Ok(file) => file, Err(_) => { From ee11f039a93f833122f967b9e22a852ebb0e995f Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Tue, 24 Dec 2024 16:31:06 -0300 Subject: [PATCH 3/5] Add support to export and import crons. --- .../src/network/handle_commands_list.rs | 16 ++ .../network/v2_api/api_v2_commands_cron.rs | 208 +++++++++++++++++- .../src/api_v2/api_v2_handlers_cron.rs | 127 +++++++++++ .../shinkai-http-api/src/node_commands.rs | 10 + 4 files changed, 358 insertions(+), 3 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 82dd6b2c7..97cb0fbc0 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -2738,6 +2738,22 @@ impl Node { let _ = Node::v2_api_get_cron_task_logs(db_clone, bearer, cron_task_id, res).await; }); } + NodeCommand::V2ApiImportCronTask { bearer, url, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_import_cron_task(db_clone, bearer, url, res).await; + }); + } + NodeCommand::V2ApiExportCronTask { + bearer, + cron_task_id, + res, + } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_export_cron_task(db_clone, bearer, cron_task_id, res).await; + }); + } NodeCommand::V2ApiGenerateToolMetadataImplementation { bearer, job_id, diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs index 131ea88e7..dc30ecfbd 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs @@ -1,6 +1,6 @@ use crate::{ cron_tasks::cron_manager::CronManager, - network::{node_error::NodeError, Node}, + network::{node_error::NodeError, Node, node_shareable_logic::download_zip_file}, }; use async_channel::Sender; use reqwest::StatusCode; @@ -9,8 +9,13 @@ use shinkai_http_api::node_api_router::APIError; use shinkai_message_primitives::schemas::crontab::{CronTask, CronTaskAction}; use shinkai_sqlite::SqliteManager; use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; -use chrono::{Utc, Local}; +use tokio::sync::Mutex; +use chrono::Local; +use std::path::Path; +use std::fs::File; +use std::io::Write; +use tokio::fs; +use zip::{write::FileOptions, ZipWriter}; impl Node { pub async fn v2_api_add_cron_task( @@ -306,4 +311,201 @@ impl Node { } } } + + pub async fn v2_api_import_cron_task( + db: Arc, + bearer: String, + url: String, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Download and validate the zip file + let zip_contents = match download_zip_file(url, "__cron.json".to_string()).await { + Ok(contents) => contents, + Err(err) => { + let _ = res.send(Err(err)).await; + return Ok(()); + } + }; + + // Parse the JSON content + let cron_data: Value = match serde_json::from_slice(&zip_contents.buffer) { + Ok(data) => data, + Err(err) => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid JSON".to_string(), + message: format!("Failed to parse cron task JSON: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Extract and validate required fields + let cron_task = match cron_data.as_object() { + Some(obj) => { + let name = obj.get("name").and_then(|v| v.as_str()).ok_or_else(|| NodeError::from("Missing or invalid 'name' field".to_string()))?; + let cron = obj.get("cron").and_then(|v| v.as_str()).ok_or_else(|| NodeError::from("Missing or invalid 'cron' field".to_string()))?; + let action: CronTaskAction = serde_json::from_value(obj.get("action").cloned() + .ok_or_else(|| NodeError::from("Missing 'action' field".to_string()))?) + .map_err(|e| NodeError::from(format!("Invalid action format: {}", e)))?; + let description = obj.get("description").and_then(|v| v.as_str()).map(String::from); + + (name.to_string(), cron.to_string(), action, description) + } + None => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid JSON".to_string(), + message: "JSON must be an object".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Add the cron task to the database + match db.add_cron_task(&cron_task.0, cron_task.3.as_deref(), &cron_task.1, &cron_task.2) { + Ok(_) => { + let response = json!({ + "status": "success", + "message": "Cron task imported successfully" + }); + let _ = res.send(Ok(response)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Database Error".to_string(), + message: format!("Failed to add cron task: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_export_cron_task( + db: Arc, + bearer: String, + cron_task_id: i64, + res: Sender, APIError>>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve the cron task from the database + match db.get_cron_task(cron_task_id) { + Ok(Some(cron_task)) => { + // Serialize the cron task to JSON bytes + let cron_task_bytes = match serde_json::to_vec(&cron_task) { + Ok(bytes) => bytes, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to serialize cron task: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Create a temporary zip file + let name = format!("cron_task_{}.zip", cron_task_id); + let path = Path::new(&name); + let file = match File::create(&path) { + Ok(file) => file, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + let mut zip = ZipWriter::new(file); + + // Add the cron task JSON to the zip file + if let Err(err) = zip.start_file::<_, ()>("__cron_task.json", FileOptions::default()) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create cron task file in zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + if let Err(err) = zip.write_all(&cron_task_bytes) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to write cron task data to zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Finalize the zip file + if let Err(err) = zip.finish() { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to finalize zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Read the zip file into memory + match fs::read(&path).await { + Ok(file_bytes) => { + // Clean up the temporary file + if let Err(err) = fs::remove_file(&path).await { + eprintln!("Warning: Failed to remove temporary file: {}", err); + } + let _ = res.send(Ok(file_bytes)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to read zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + } + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: format!("Cron task not found: {}", cron_task_id), + }; + let _ = res.send(Err(api_error)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to retrieve cron task: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } } diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs index 76a8cb4e2..6b2567860 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs @@ -67,6 +67,20 @@ pub fn cron_routes( .and(warp::header::("authorization")) .and_then(get_cron_schedule_handler); + let import_cron_task_route = warp::path("import_cron_task") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(import_cron_task_handler); + + let export_cron_task_route = warp::path("export_cron_task") + .and(warp::get()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::query::>()) + .and_then(export_cron_task_handler); + add_cron_task_route .or(list_all_cron_tasks_route) .or(get_specific_cron_task_route) @@ -75,6 +89,8 @@ pub fn cron_routes( .or(update_cron_task_route) .or(force_execute_cron_task_route) .or(get_cron_schedule_route) + .or(import_cron_task_route) + .or(export_cron_task_route) } #[derive(Deserialize)] @@ -506,6 +522,115 @@ pub async fn get_cron_schedule_handler( } } +#[utoipa::path( + post, + path = "/v2/import_cron_task", + request_body = ImportCronTaskRequest, + responses( + (status = 200, description = "Successfully imported cron task", body = Value), + (status = 400, description = "Bad request", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn import_cron_task_handler( + node_commands_sender: Sender, + authorization: String, + url: String, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let (res_sender, res_receiver) = async_channel::bounded(1); + + node_commands_sender + .send(NodeCommand::V2ApiImportCronTask { + bearer, + url, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => { + let response = create_success_response(response); + Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK)) + } + Err(error) => Ok(warp::reply::with_status( + warp::reply::json(&error), + StatusCode::from_u16(error.code).unwrap(), + )), + } +} + +#[utoipa::path( + get, + path = "/v2/export_cron_task", + params( + ("cron_task_id" = String, Query, description = "Cron task ID to export") + ), + responses( + (status = 200, description = "Successfully exported cron task", body = Vec), + (status = 400, description = "Bad request", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn export_cron_task_handler( + node_commands_sender: Sender, + authorization: String, + query_params: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + + let cron_task_id_str = query_params.get("cron_task_id").ok_or_else(|| { + warp::reject::custom(APIError { + code: 400, + error: "Invalid Query".to_string(), + message: "The cron_task_id parameter is required".to_string(), + }) + })?; + + // Parse cron_task_id to i64 + let cron_task_id: i64 = cron_task_id_str.parse().map_err(|_| { + warp::reject::custom(APIError { + code: 400, + error: "Invalid Query".to_string(), + message: "The cron_task_id must be a valid integer.".to_string(), + }) + })?; + + let (res_sender, res_receiver) = async_channel::bounded(1); + node_commands_sender + .send(NodeCommand::V2ApiExportCronTask { + bearer, + cron_task_id, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(file_bytes) => { + // Return the raw bytes with appropriate headers + Ok(warp::reply::with_header( + warp::reply::with_status(file_bytes, StatusCode::OK), + "Content-Type", + "application/octet-stream", + )) + } + Err(error) => Ok(warp::reply::with_header( + warp::reply::with_status( + error.message.as_bytes().to_vec(), + StatusCode::from_u16(error.code).unwrap() + ), + "Content-Type", + "text/plain", + )) + } +} + #[derive(OpenApi)] #[openapi( paths( @@ -517,6 +642,8 @@ pub async fn get_cron_schedule_handler( update_cron_task_handler, force_execute_cron_task_handler, get_cron_schedule_handler, + import_cron_task_handler, + export_cron_task_handler, ), components( schemas(CronTask, CronTaskAction, APIError) diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index f6798b042..b908757a5 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -1140,4 +1140,14 @@ pub enum NodeCommand { file_name: String, res: Sender>, }, + V2ApiImportCronTask { + bearer: String, + url: String, + res: Sender>, + }, + V2ApiExportCronTask { + bearer: String, + cron_task_id: i64, + res: Sender, APIError>>, + }, } From a8b90f44e1d032e39a806b207042c9a392991fa5 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Thu, 26 Dec 2024 11:10:02 -0300 Subject: [PATCH 4/5] Fixes cron task importer --- .../shinkai-node/src/network/v2_api/api_v2_commands_cron.rs | 4 +++- .../shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs | 5 +++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs index dc30ecfbd..caad09610 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs @@ -324,7 +324,7 @@ impl Node { } // Download and validate the zip file - let zip_contents = match download_zip_file(url, "__cron.json".to_string()).await { + let zip_contents = match download_zip_file(url, "__cron_task.json".to_string()).await { Ok(contents) => contents, Err(err) => { let _ = res.send(Err(err)).await; @@ -369,6 +369,8 @@ impl Node { } }; + println!("cron_task: {:?}", cron_task); + // Add the cron task to the database match db.add_cron_task(&cron_task.0, cron_task.3.as_deref(), &cron_task.1, &cron_task.2) { Ok(_) => { diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs index 6b2567860..c74f29577 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs @@ -535,11 +535,12 @@ pub async fn get_cron_schedule_handler( pub async fn import_cron_task_handler( node_commands_sender: Sender, authorization: String, - url: String, + payload: HashMap, ) -> Result { let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let url = payload.get("url").cloned().unwrap_or_default(); + let (res_sender, res_receiver) = async_channel::bounded(1); - node_commands_sender .send(NodeCommand::V2ApiImportCronTask { bearer, From 7894b80ffe29d6d830430270a71531da6e9453d3 Mon Sep 17 00:00:00 2001 From: Guillermo Valin Date: Thu, 26 Dec 2024 12:01:49 -0300 Subject: [PATCH 5/5] Moved temp files from import to system's temp directory. --- shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs | 2 +- .../shinkai-node/src/network/v2_api/api_v2_commands_cron.rs | 2 +- .../shinkai-node/src/network/v2_api/api_v2_commands_tools.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs index 0867f3a43..35694df6e 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs @@ -1629,7 +1629,7 @@ impl Node { // Create a temporary zip file let name = format!("{}.zip", agent.agent_id.replace(':', "_")); - let path = Path::new(&name); + let path = std::env::temp_dir().join(&name); let file = match File::create(&path) { Ok(file) => file, Err(err) => { diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs index caad09610..c88da7f3f 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs @@ -423,7 +423,7 @@ impl Node { // Create a temporary zip file let name = format!("cron_task_{}.zip", cron_task_id); - let path = Path::new(&name); + let path = std::env::temp_dir().join(&name); let file = match File::create(&path) { Ok(file) => file, Err(err) => { diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index a21fb92f5..233ce2cd2 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -1538,7 +1538,7 @@ impl Node { let tool_bytes = serde_json::to_vec(&tool).unwrap(); let name = format!("{}.zip", tool.tool_router_key().replace(':', "_")); - let path = Path::new(&name); + let path = std::env::temp_dir().join(&name); let file = File::create(&path).map_err(|e| NodeError::from(e.to_string()))?; let mut zip = ZipWriter::new(file);