Skip to content

Commit

Permalink
Merge pull request #751 from dcSpark/feature/import-export-agents
Browse files Browse the repository at this point in the history
Export / Import Agents and Crons Endpoints
  • Loading branch information
guillevalin authored Dec 26, 2024
2 parents 2a81333 + 7894b80 commit 5a61e79
Show file tree
Hide file tree
Showing 8 changed files with 766 additions and 70 deletions.
28 changes: 28 additions & 0 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,18 @@ impl Node {
.await;
});
}
NodeCommand::V2ApiExportAgent { bearer, agent_id, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_export_agent(db_clone, bearer, agent_id, res).await;
});
}
NodeCommand::V2ApiImportAgent { bearer, url, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_import_agent(db_clone, bearer, url, res).await;
});
}
NodeCommand::AvailableLLMProviders { full_profile_name, res } => {
let db_clone = self.db.clone();
let node_name_clone = self.node_name.clone();
Expand Down Expand Up @@ -2728,6 +2740,22 @@ impl Node {
let _ = Node::v2_api_get_cron_task_logs(db_clone, bearer, cron_task_id, res).await;
});
}
NodeCommand::V2ApiImportCronTask { bearer, url, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_import_cron_task(db_clone, bearer, url, res).await;
});
}
NodeCommand::V2ApiExportCronTask {
bearer,
cron_task_id,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_export_cron_task(db_clone, bearer, cron_task_id, res).await;
});
}
NodeCommand::V2ApiGenerateToolMetadataImplementation {
bearer,
job_id,
Expand Down
81 changes: 81 additions & 0 deletions shinkai-bin/shinkai-node/src/network/node_shareable_logic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use shinkai_http_api::node_api_router::APIError;
use shinkai_message_primitives::schemas::identity::{Identity, StandardIdentityType};
use std::sync::Arc;
use std::io::Read;
use tokio::sync::Mutex;
use tokio_util::bytes::Bytes;

use crate::managers::identity_manager::IdentityManager;
use crate::managers::identity_manager::IdentityManagerTrait;
Expand Down Expand Up @@ -186,3 +188,82 @@ pub async fn validate_message_main_logic(

Ok((msg, sender_subidentity))
}

pub struct ZipFileContents {
pub buffer: Vec<u8>,
pub archive: zip::ZipArchive<std::io::Cursor<Bytes>>,
}

pub async fn download_zip_file(url: String, file_name: String) -> Result<ZipFileContents, APIError> {
// Download the zip file
let response = match reqwest::get(&url).await {
Ok(response) => response,
Err(err) => {
return Err(APIError {
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Download Failed".to_string(),
message: format!("Failed to download agent from URL: {}", err),
});
}
};

// Get the bytes from the response
let bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Download Failed".to_string(),
message: format!("Failed to read response bytes: {}", err),
});
}
};

// Create a cursor from the bytes
let cursor = std::io::Cursor::new(bytes.clone());

// Create a zip archive from the cursor
let mut archive = match zip::ZipArchive::new(cursor) {
Ok(archive) => archive,
Err(err) => {
return Err(APIError {
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Invalid Zip File".to_string(),
message: format!("Failed to read zip archive: {}", err),
});
}
};

// Extract and parse file
let mut buffer = Vec::new();
{
let mut file = match archive.by_name(&file_name) {
Ok(file) => file,
Err(_) => {
return Err(APIError {
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Invalid Zip File".to_string(),
message: format!("Archive does not contain {}", file_name),
});
}
};

// Read the file contents into a buffer
if let Err(err) = file.read_to_end(&mut buffer) {
return Err(APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Read Error".to_string(),
message: format!("Failed to read file contents: {}", err),
});
}
}

// Create a new cursor and archive for returning
let return_cursor = std::io::Cursor::new(bytes);
let return_archive = zip::ZipArchive::new(return_cursor).unwrap();

Ok(ZipFileContents {
buffer,
archive: return_archive,
})
}
178 changes: 177 additions & 1 deletion shinkai-bin/shinkai-node/src/network/v2_api/api_v2_commands.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
use std::{env, sync::Arc};
use std::path::Path;
use std::fs::File;
use std::io::Write;

use serde_json::{json, Value};
use tokio::fs;
use zip::{write::FileOptions, ZipWriter};

use async_channel::Sender;
use ed25519_dalek::{SigningKey, VerifyingKey};
Expand Down Expand Up @@ -43,7 +50,7 @@ use x25519_dalek::PublicKey as EncryptionPublicKey;
use crate::{
llm_provider::{job_manager::JobManager, llm_stopper::LLMStopper},
managers::{identity_manager::IdentityManagerTrait, IdentityManager},
network::{node_error::NodeError, Node},
network::{node_error::NodeError, Node, node_shareable_logic::download_zip_file},
tools::tool_generation,
utils::update_global_identity::update_global_identity_name,
};
Expand Down Expand Up @@ -1591,4 +1598,173 @@ impl Node {

Ok(())
}

pub async fn v2_api_export_agent(
db: Arc<SqliteManager>,
bearer: String,
agent_id: String,
res: Sender<Result<Vec<u8>, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

// Retrieve the agent from the database
match db.get_agent(&agent_id) {
Ok(Some(agent)) => {
// Serialize the agent to JSON bytes
let agent_bytes = match serde_json::to_vec(&agent) {
Ok(bytes) => bytes,
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to serialize agent: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

// Create a temporary zip file
let name = format!("{}.zip", agent.agent_id.replace(':', "_"));
let path = std::env::temp_dir().join(&name);
let file = match File::create(&path) {
Ok(file) => file,
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to create zip file: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

let mut zip = ZipWriter::new(file);

// Add the agent JSON to the zip file
if let Err(err) = zip.start_file::<_, ()>("__agent.json", FileOptions::default()) {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to create agent file in zip: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}

if let Err(err) = zip.write_all(&agent_bytes) {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to write agent data to zip: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}

// Finalize the zip file
if let Err(err) = zip.finish() {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to finalize zip file: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}

// Read the zip file into memory
match fs::read(&path).await {
Ok(file_bytes) => {
// Clean up the temporary file
if let Err(err) = fs::remove_file(&path).await {
eprintln!("Warning: Failed to remove temporary file: {}", err);
}
let _ = res.send(Ok(file_bytes)).await;
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to read zip file: {}", err),
};
let _ = res.send(Err(api_error)).await;
}
}
}
Ok(None) => {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("Agent not found: {}", agent_id),
};
let _ = res.send(Err(api_error)).await;
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to retrieve agent: {}", err),
};
let _ = res.send(Err(api_error)).await;
}
}

Ok(())
}

pub async fn v2_api_import_agent(
db: Arc<SqliteManager>,
bearer: String,
url: String,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}

let zip_contents = match download_zip_file(url, "__agent.json".to_string()).await {
Ok(contents) => contents,
Err(err) => {
let api_error = APIError {
code: StatusCode::BAD_REQUEST.as_u16(),
error: "Invalid Agent Zip".to_string(),
message: format!("Failed to extract agent.json: {:?}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
};

// Parse the JSON into an Agent
let agent: Agent = serde_json::from_slice(&zip_contents.buffer).map_err(|e| NodeError::from(e.to_string()))?;

// Save the agent to the database
match db.add_agent(agent.clone(), &agent.full_identity_name) {
Ok(_) => {
let response = json!({
"status": "success",
"message": "Agent imported successfully",
"agent_id": agent.agent_id,
"agent": agent
});
let _ = res.send(Ok(response)).await;
}
Err(err) => {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Database Error".to_string(),
message: format!("Failed to save agent to database: {}", err),
};
let _ = res.send(Err(api_error)).await;
}
}

Ok(())
}
}
Loading

0 comments on commit 5a61e79

Please sign in to comment.