diff --git a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs index 6b7cc45ed..74e23d36a 100644 --- a/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs +++ b/shinkai-bin/shinkai-node/src/network/handle_commands_list.rs @@ -232,6 +232,18 @@ impl Node { .await; }); } + NodeCommand::V2ApiExportAgent { bearer, agent_id, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_export_agent(db_clone, bearer, agent_id, res).await; + }); + } + NodeCommand::V2ApiImportAgent { bearer, url, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_import_agent(db_clone, bearer, url, res).await; + }); + } NodeCommand::AvailableLLMProviders { full_profile_name, res } => { let db_clone = self.db.clone(); let node_name_clone = self.node_name.clone(); @@ -2728,6 +2740,22 @@ impl Node { let _ = Node::v2_api_get_cron_task_logs(db_clone, bearer, cron_task_id, res).await; }); } + NodeCommand::V2ApiImportCronTask { bearer, url, res } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_import_cron_task(db_clone, bearer, url, res).await; + }); + } + NodeCommand::V2ApiExportCronTask { + bearer, + cron_task_id, + res, + } => { + let db_clone = Arc::clone(&self.db); + tokio::spawn(async move { + let _ = Node::v2_api_export_cron_task(db_clone, bearer, cron_task_id, res).await; + }); + } NodeCommand::V2ApiGenerateToolMetadataImplementation { bearer, job_id, diff --git a/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs b/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs index 8cedb4d3b..12bbbd7f4 100644 --- a/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs +++ b/shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs @@ -1,7 +1,9 @@ use shinkai_http_api::node_api_router::APIError; use shinkai_message_primitives::schemas::identity::{Identity, StandardIdentityType}; use std::sync::Arc; +use std::io::Read; use tokio::sync::Mutex; +use tokio_util::bytes::Bytes; use crate::managers::identity_manager::IdentityManager; use crate::managers::identity_manager::IdentityManagerTrait; @@ -186,3 +188,82 @@ pub async fn validate_message_main_logic( Ok((msg, sender_subidentity)) } + +pub struct ZipFileContents { + pub buffer: Vec, + pub archive: zip::ZipArchive>, +} + +pub async fn download_zip_file(url: String, file_name: String) -> Result { + // Download the zip file + let response = match reqwest::get(&url).await { + Ok(response) => response, + Err(err) => { + return Err(APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Download Failed".to_string(), + message: format!("Failed to download agent from URL: {}", err), + }); + } + }; + + // Get the bytes from the response + let bytes = match response.bytes().await { + Ok(bytes) => bytes, + Err(err) => { + return Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Download Failed".to_string(), + message: format!("Failed to read response bytes: {}", err), + }); + } + }; + + // Create a cursor from the bytes + let cursor = std::io::Cursor::new(bytes.clone()); + + // Create a zip archive from the cursor + let mut archive = match zip::ZipArchive::new(cursor) { + Ok(archive) => archive, + Err(err) => { + return Err(APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid Zip File".to_string(), + message: format!("Failed to read zip archive: {}", err), + }); + } + }; + + // Extract and parse file + let mut buffer = Vec::new(); + { + let mut file = match archive.by_name(&file_name) { + Ok(file) => file, + Err(_) => { + return Err(APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid Zip File".to_string(), + message: format!("Archive does not contain {}", file_name), + }); + } + }; + + // Read the file contents into a buffer + if let Err(err) = file.read_to_end(&mut buffer) { + return Err(APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Read Error".to_string(), + message: format!("Failed to read file contents: {}", err), + }); + } + } + + // Create a new cursor and archive for returning + let return_cursor = std::io::Cursor::new(bytes); + let return_archive = zip::ZipArchive::new(return_cursor).unwrap(); + + Ok(ZipFileContents { + buffer, + archive: return_archive, + }) +} diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs index dfb8679bd..35694df6e 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs @@ -1,4 +1,11 @@ use std::{env, sync::Arc}; +use std::path::Path; +use std::fs::File; +use std::io::Write; + +use serde_json::{json, Value}; +use tokio::fs; +use zip::{write::FileOptions, ZipWriter}; use async_channel::Sender; use ed25519_dalek::{SigningKey, VerifyingKey}; @@ -43,7 +50,7 @@ use x25519_dalek::PublicKey as EncryptionPublicKey; use crate::{ llm_provider::{job_manager::JobManager, llm_stopper::LLMStopper}, managers::{identity_manager::IdentityManagerTrait, IdentityManager}, - network::{node_error::NodeError, Node}, + network::{node_error::NodeError, Node, node_shareable_logic::download_zip_file}, tools::tool_generation, utils::update_global_identity::update_global_identity_name, }; @@ -1591,4 +1598,173 @@ impl Node { Ok(()) } + + pub async fn v2_api_export_agent( + db: Arc, + bearer: String, + agent_id: String, + res: Sender, APIError>>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve the agent from the database + match db.get_agent(&agent_id) { + Ok(Some(agent)) => { + // Serialize the agent to JSON bytes + let agent_bytes = match serde_json::to_vec(&agent) { + Ok(bytes) => bytes, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to serialize agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Create a temporary zip file + let name = format!("{}.zip", agent.agent_id.replace(':', "_")); + let path = std::env::temp_dir().join(&name); + let file = match File::create(&path) { + Ok(file) => file, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + let mut zip = ZipWriter::new(file); + + // Add the agent JSON to the zip file + if let Err(err) = zip.start_file::<_, ()>("__agent.json", FileOptions::default()) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create agent file in zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + if let Err(err) = zip.write_all(&agent_bytes) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to write agent data to zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Finalize the zip file + if let Err(err) = zip.finish() { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to finalize zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Read the zip file into memory + match fs::read(&path).await { + Ok(file_bytes) => { + // Clean up the temporary file + if let Err(err) = fs::remove_file(&path).await { + eprintln!("Warning: Failed to remove temporary file: {}", err); + } + let _ = res.send(Ok(file_bytes)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to read zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + } + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: format!("Agent not found: {}", agent_id), + }; + let _ = res.send(Err(api_error)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to retrieve agent: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_import_agent( + db: Arc, + bearer: String, + url: String, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + let zip_contents = match download_zip_file(url, "__agent.json".to_string()).await { + Ok(contents) => contents, + Err(err) => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid Agent Zip".to_string(), + message: format!("Failed to extract agent.json: {:?}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Parse the JSON into an Agent + let agent: Agent = serde_json::from_slice(&zip_contents.buffer).map_err(|e| NodeError::from(e.to_string()))?; + + // Save the agent to the database + match db.add_agent(agent.clone(), &agent.full_identity_name) { + Ok(_) => { + let response = json!({ + "status": "success", + "message": "Agent imported successfully", + "agent_id": agent.agent_id, + "agent": agent + }); + let _ = res.send(Ok(response)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Database Error".to_string(), + message: format!("Failed to save agent to database: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } } diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs index 131ea88e7..c88da7f3f 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_cron.rs @@ -1,6 +1,6 @@ use crate::{ cron_tasks::cron_manager::CronManager, - network::{node_error::NodeError, Node}, + network::{node_error::NodeError, Node, node_shareable_logic::download_zip_file}, }; use async_channel::Sender; use reqwest::StatusCode; @@ -9,8 +9,13 @@ use shinkai_http_api::node_api_router::APIError; use shinkai_message_primitives::schemas::crontab::{CronTask, CronTaskAction}; use shinkai_sqlite::SqliteManager; use std::sync::Arc; -use tokio::sync::{Mutex, RwLock}; -use chrono::{Utc, Local}; +use tokio::sync::Mutex; +use chrono::Local; +use std::path::Path; +use std::fs::File; +use std::io::Write; +use tokio::fs; +use zip::{write::FileOptions, ZipWriter}; impl Node { pub async fn v2_api_add_cron_task( @@ -306,4 +311,203 @@ impl Node { } } } + + pub async fn v2_api_import_cron_task( + db: Arc, + bearer: String, + url: String, + res: Sender>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Download and validate the zip file + let zip_contents = match download_zip_file(url, "__cron_task.json".to_string()).await { + Ok(contents) => contents, + Err(err) => { + let _ = res.send(Err(err)).await; + return Ok(()); + } + }; + + // Parse the JSON content + let cron_data: Value = match serde_json::from_slice(&zip_contents.buffer) { + Ok(data) => data, + Err(err) => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid JSON".to_string(), + message: format!("Failed to parse cron task JSON: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Extract and validate required fields + let cron_task = match cron_data.as_object() { + Some(obj) => { + let name = obj.get("name").and_then(|v| v.as_str()).ok_or_else(|| NodeError::from("Missing or invalid 'name' field".to_string()))?; + let cron = obj.get("cron").and_then(|v| v.as_str()).ok_or_else(|| NodeError::from("Missing or invalid 'cron' field".to_string()))?; + let action: CronTaskAction = serde_json::from_value(obj.get("action").cloned() + .ok_or_else(|| NodeError::from("Missing 'action' field".to_string()))?) + .map_err(|e| NodeError::from(format!("Invalid action format: {}", e)))?; + let description = obj.get("description").and_then(|v| v.as_str()).map(String::from); + + (name.to_string(), cron.to_string(), action, description) + } + None => { + let api_error = APIError { + code: StatusCode::BAD_REQUEST.as_u16(), + error: "Invalid JSON".to_string(), + message: "JSON must be an object".to_string(), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + println!("cron_task: {:?}", cron_task); + + // Add the cron task to the database + match db.add_cron_task(&cron_task.0, cron_task.3.as_deref(), &cron_task.1, &cron_task.2) { + Ok(_) => { + let response = json!({ + "status": "success", + "message": "Cron task imported successfully" + }); + let _ = res.send(Ok(response)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Database Error".to_string(), + message: format!("Failed to add cron task: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } + + pub async fn v2_api_export_cron_task( + db: Arc, + bearer: String, + cron_task_id: i64, + res: Sender, APIError>>, + ) -> Result<(), NodeError> { + // Validate the bearer token + if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() { + return Ok(()); + } + + // Retrieve the cron task from the database + match db.get_cron_task(cron_task_id) { + Ok(Some(cron_task)) => { + // Serialize the cron task to JSON bytes + let cron_task_bytes = match serde_json::to_vec(&cron_task) { + Ok(bytes) => bytes, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to serialize cron task: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + // Create a temporary zip file + let name = format!("cron_task_{}.zip", cron_task_id); + let path = std::env::temp_dir().join(&name); + let file = match File::create(&path) { + Ok(file) => file, + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + }; + + let mut zip = ZipWriter::new(file); + + // Add the cron task JSON to the zip file + if let Err(err) = zip.start_file::<_, ()>("__cron_task.json", FileOptions::default()) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to create cron task file in zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + if let Err(err) = zip.write_all(&cron_task_bytes) { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to write cron task data to zip: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Finalize the zip file + if let Err(err) = zip.finish() { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to finalize zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + return Ok(()); + } + + // Read the zip file into memory + match fs::read(&path).await { + Ok(file_bytes) => { + // Clean up the temporary file + if let Err(err) = fs::remove_file(&path).await { + eprintln!("Warning: Failed to remove temporary file: {}", err); + } + let _ = res.send(Ok(file_bytes)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to read zip file: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + } + Ok(None) => { + let api_error = APIError { + code: StatusCode::NOT_FOUND.as_u16(), + error: "Not Found".to_string(), + message: format!("Cron task not found: {}", cron_task_id), + }; + let _ = res.send(Err(api_error)).await; + } + Err(err) => { + let api_error = APIError { + code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), + error: "Internal Server Error".to_string(), + message: format!("Failed to retrieve cron task: {}", err), + }; + let _ = res.send(Err(api_error)).await; + } + } + + Ok(()) + } } diff --git a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs index 48b8a4e0b..78d36d59d 100644 --- a/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs +++ b/shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands_tools.rs @@ -1,7 +1,7 @@ use crate::{ llm_provider::job_manager::JobManager, managers::IdentityManager, - network::{node_error::NodeError, Node}, + network::{node_error::NodeError, node_shareable_logic::download_zip_file, Node}, tools::{ tool_definitions::definition_generation::{generate_tool_definitions, get_all_deno_tools}, tool_execution::execution_coordinator::{execute_code, execute_tool_cmd}, @@ -10,7 +10,6 @@ use crate::{ }, utils::environment::NodeEnvironment, }; -use std::io::Read; use async_channel::Sender; use ed25519_dalek::SigningKey; @@ -43,7 +42,7 @@ use shinkai_tools_primitives::tools::{ tool_playground::ToolPlayground, }; use shinkai_vector_fs::vector_fs::vector_fs::VectorFS; -use std::{fs::File, io::Write, path::Path, sync::Arc, time::Instant}; +use std::{fs::File, io::Write, io::Read, path::Path, sync::Arc, time::Instant}; use tokio::sync::Mutex; use zip::{write::FileOptions, ZipWriter}; @@ -1571,7 +1570,7 @@ impl Node { let tool_bytes = serde_json::to_vec(&tool).unwrap(); let name = format!("{}.zip", tool.tool_router_key().replace(':', "_")); - let path = Path::new(&name); + let path = std::env::temp_dir().join(&name); let file = File::create(&path).map_err(|e| NodeError::from(e.to_string()))?; let mut zip = ZipWriter::new(file); @@ -1649,71 +1648,15 @@ impl Node { node_env: NodeEnvironment, url: String, ) -> Result { - // Download the zip file - let response = match reqwest::get(&url).await { - Ok(response) => response, + let mut zip_contents = match download_zip_file(url, "__tool.json".to_string()).await { + Ok(contents) => contents, Err(err) => { - return Err(APIError { - code: StatusCode::BAD_REQUEST.as_u16(), - error: "Download Failed".to_string(), - message: format!("Failed to download tool from URL: {}", err), - }); - } - }; - - // Get the bytes from the response - let bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(err) => { - return Err(APIError { - code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), - error: "Download Failed".to_string(), - message: format!("Failed to read response bytes: {}", err), - }); - } - }; - - // Create a cursor from the bytes - let cursor = std::io::Cursor::new(bytes); - - // Create a zip archive from the cursor - let mut archive = match zip::ZipArchive::new(cursor) { - Ok(archive) => archive, - Err(err) => { - return Err(APIError { - code: StatusCode::BAD_REQUEST.as_u16(), - error: "Invalid Zip File".to_string(), - message: format!("Failed to read zip archive: {}", err), - }); + return Err(err); } }; - // Extract and parse tool.json - let mut buffer = Vec::new(); - { - let mut tool_file = match archive.by_name("__tool.json") { - Ok(file) => file, - Err(_) => { - return Err(APIError { - code: StatusCode::BAD_REQUEST.as_u16(), - error: "Invalid Tool Archive".to_string(), - message: "Archive does not contain tool.json".to_string(), - }); - } - }; - - // Read the tool file contents into a buffer - if let Err(err) = tool_file.read_to_end(&mut buffer) { - return Err(APIError { - code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(), - error: "Read Error".to_string(), - message: format!("Failed to read tool.json contents: {}", err), - }); - } - } // `tool_file` goes out of scope here - // Parse the JSON into a ShinkaiTool - let tool: ShinkaiTool = match serde_json::from_slice(&buffer) { + let tool: ShinkaiTool = match serde_json::from_slice(&zip_contents.buffer) { Ok(tool) => tool, Err(err) => { return Err(APIError { @@ -1728,7 +1671,7 @@ impl Node { let db_write = db; match db_write.add_tool(tool).await { Ok(tool) => { - let archive_clone = archive.clone(); + let archive_clone = zip_contents.archive.clone(); let files = archive_clone.file_names(); for file in files { println!("File: {:?}", file); @@ -1737,7 +1680,7 @@ impl Node { } let mut buffer = Vec::new(); { - let file = archive.by_name(file); + let file = zip_contents.archive.by_name(file); let mut tool_file = match file { Ok(file) => file, Err(_) => { diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs index 76a8cb4e2..c74f29577 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_cron.rs @@ -67,6 +67,20 @@ pub fn cron_routes( .and(warp::header::("authorization")) .and_then(get_cron_schedule_handler); + let import_cron_task_route = warp::path("import_cron_task") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(import_cron_task_handler); + + let export_cron_task_route = warp::path("export_cron_task") + .and(warp::get()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::query::>()) + .and_then(export_cron_task_handler); + add_cron_task_route .or(list_all_cron_tasks_route) .or(get_specific_cron_task_route) @@ -75,6 +89,8 @@ pub fn cron_routes( .or(update_cron_task_route) .or(force_execute_cron_task_route) .or(get_cron_schedule_route) + .or(import_cron_task_route) + .or(export_cron_task_route) } #[derive(Deserialize)] @@ -506,6 +522,116 @@ pub async fn get_cron_schedule_handler( } } +#[utoipa::path( + post, + path = "/v2/import_cron_task", + request_body = ImportCronTaskRequest, + responses( + (status = 200, description = "Successfully imported cron task", body = Value), + (status = 400, description = "Bad request", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn import_cron_task_handler( + node_commands_sender: Sender, + authorization: String, + payload: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let url = payload.get("url").cloned().unwrap_or_default(); + + let (res_sender, res_receiver) = async_channel::bounded(1); + node_commands_sender + .send(NodeCommand::V2ApiImportCronTask { + bearer, + url, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => { + let response = create_success_response(response); + Ok(warp::reply::with_status(warp::reply::json(&response), StatusCode::OK)) + } + Err(error) => Ok(warp::reply::with_status( + warp::reply::json(&error), + StatusCode::from_u16(error.code).unwrap(), + )), + } +} + +#[utoipa::path( + get, + path = "/v2/export_cron_task", + params( + ("cron_task_id" = String, Query, description = "Cron task ID to export") + ), + responses( + (status = 200, description = "Successfully exported cron task", body = Vec), + (status = 400, description = "Bad request", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn export_cron_task_handler( + node_commands_sender: Sender, + authorization: String, + query_params: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + + let cron_task_id_str = query_params.get("cron_task_id").ok_or_else(|| { + warp::reject::custom(APIError { + code: 400, + error: "Invalid Query".to_string(), + message: "The cron_task_id parameter is required".to_string(), + }) + })?; + + // Parse cron_task_id to i64 + let cron_task_id: i64 = cron_task_id_str.parse().map_err(|_| { + warp::reject::custom(APIError { + code: 400, + error: "Invalid Query".to_string(), + message: "The cron_task_id must be a valid integer.".to_string(), + }) + })?; + + let (res_sender, res_receiver) = async_channel::bounded(1); + node_commands_sender + .send(NodeCommand::V2ApiExportCronTask { + bearer, + cron_task_id, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(file_bytes) => { + // Return the raw bytes with appropriate headers + Ok(warp::reply::with_header( + warp::reply::with_status(file_bytes, StatusCode::OK), + "Content-Type", + "application/octet-stream", + )) + } + Err(error) => Ok(warp::reply::with_header( + warp::reply::with_status( + error.message.as_bytes().to_vec(), + StatusCode::from_u16(error.code).unwrap() + ), + "Content-Type", + "text/plain", + )) + } +} + #[derive(OpenApi)] #[openapi( paths( @@ -517,6 +643,8 @@ pub async fn get_cron_schedule_handler( update_cron_task_handler, force_execute_cron_task_handler, get_cron_schedule_handler, + import_cron_task_handler, + export_cron_task_handler, ), components( schemas(CronTask, CronTaskAction, APIError) diff --git a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs index 801e80de7..d15655225 100644 --- a/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs +++ b/shinkai-libs/shinkai-http-api/src/api_v2/api_v2_handlers_general.rs @@ -178,6 +178,20 @@ pub fn general_routes( .and(warp::header::("authorization")) .and_then(get_all_agents_handler); + let export_agent_route = warp::path("export_agent") + .and(warp::get()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::query::>()) + .and_then(export_agent_handler); + + let import_agent_route = warp::path("import_agent") + .and(warp::post()) + .and(with_sender(node_commands_sender.clone())) + .and(warp::header::("authorization")) + .and(warp::body::json()) + .and_then(import_agent_handler); + let test_llm_provider_route = warp::path("test_llm_provider") .and(warp::post()) .and(with_sender(node_commands_sender.clone())) @@ -207,6 +221,8 @@ pub fn general_routes( .or(update_agent_route) .or(get_agent_route) .or(get_all_agents_route) + .or(export_agent_route) + .or(import_agent_route) .or(test_llm_provider_route) } @@ -934,6 +950,104 @@ pub async fn get_all_agents_handler( } } +#[utoipa::path( + get, + path = "/v2/export_agent", + params( + ("agent_id" = String, Query, description = "Agent identifier") + ), + responses( + (status = 200, description = "Exported agent", body = Vec), + (status = 400, description = "Invalid agent identifier", body = APIError), + ) +)] +pub async fn export_agent_handler( + sender: Sender, + authorization: String, + query_params: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + + let agent_id = query_params + .get("agent_id") + .ok_or_else(|| { + warp::reject::custom(APIError { + code: 400, + error: "Invalid agent identifier".to_string(), + message: "Agent identifier is required".to_string(), + }) + })? + .to_string(); + + let (res_sender, res_receiver) = async_channel::bounded(1); + + sender + .send(NodeCommand::V2ApiExportAgent { + bearer, + agent_id, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(file_bytes) => { + // Return the raw bytes with appropriate headers + Ok(warp::reply::with_header( + warp::reply::with_status(file_bytes, StatusCode::OK), + "Content-Type", + "application/octet-stream", + )) + } + Err(error) => Ok(warp::reply::with_header( + warp::reply::with_status( + error.message.as_bytes().to_vec(), + StatusCode::from_u16(error.code).unwrap() + ), + "Content-Type", + "text/plain", + )) + } +} + +#[utoipa::path( + post, + path = "/v2/import_agent", + request_body = HashMap, + responses( + (status = 200, description = "Successfully imported agent", body = Value), + (status = 400, description = "Invalid URL or agent data", body = APIError), + (status = 500, description = "Internal server error", body = APIError) + ) +)] +pub async fn import_agent_handler( + sender: Sender, + authorization: String, + payload: HashMap, +) -> Result { + let bearer = authorization.strip_prefix("Bearer ").unwrap_or("").to_string(); + let url = payload.get("url").cloned().unwrap_or_default(); + + let (res_sender, res_receiver) = async_channel::bounded(1); + sender + .send(NodeCommand::V2ApiImportAgent { + bearer, + url, + res: res_sender, + }) + .await + .map_err(|_| warp::reject::reject())?; + + let result = res_receiver.recv().await.map_err(|_| warp::reject::reject())?; + + match result { + Ok(response) => Ok(warp::reply::json(&response)), + Err(error) => Err(warp::reject::custom(error)), + } +} + #[utoipa::path( post, path = "/v2/test_llm_provider", @@ -990,6 +1104,8 @@ pub async fn test_llm_provider_handler( add_agent_handler, remove_agent_handler, update_agent_handler, + import_agent_handler, + export_agent_handler, get_agent_handler, get_all_agents_handler, test_llm_provider_handler, diff --git a/shinkai-libs/shinkai-http-api/src/node_commands.rs b/shinkai-libs/shinkai-http-api/src/node_commands.rs index 5b598eafc..a88ad845c 100644 --- a/shinkai-libs/shinkai-http-api/src/node_commands.rs +++ b/shinkai-libs/shinkai-http-api/src/node_commands.rs @@ -255,6 +255,16 @@ pub enum NodeCommand { msg: ShinkaiMessage, res: Sender>, }, + V2ApiImportAgent { + bearer: String, + url: String, + res: Sender>, + }, + V2ApiExportAgent { + bearer: String, + agent_id: String, + res: Sender, APIError>>, + }, AvailableLLMProviders { full_profile_name: String, res: Sender, String>>, @@ -1169,4 +1179,14 @@ pub enum NodeCommand { file_name: String, res: Sender>, }, + V2ApiImportCronTask { + bearer: String, + url: String, + res: Sender>, + }, + V2ApiExportCronTask { + bearer: String, + cron_task_id: i64, + res: Sender, APIError>>, + }, }