Skip to content

Commit

Permalink
app_file endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
acedward committed Dec 20, 2024
1 parent 2966cf6 commit 5c32ef5
Show file tree
Hide file tree
Showing 8 changed files with 765 additions and 10 deletions.
94 changes: 87 additions & 7 deletions shinkai-bin/shinkai-node/src/network/handle_commands_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2607,11 +2607,7 @@ impl Node {
let _ = Node::v2_api_import_tool(db_clone, bearer, node_env, url, res).await;
});
}
NodeCommand::V2ApiRemoveTool {
bearer,
tool_key,
res,
} => {
NodeCommand::V2ApiRemoveTool { bearer, tool_key, res } => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_remove_tool(db_clone, bearer, tool_key, res).await;
Expand Down Expand Up @@ -2684,7 +2680,9 @@ impl Node {
let db_clone = Arc::clone(&self.db);
let cron_manager_clone = self.cron_manager.clone().unwrap();
tokio::spawn(async move {
let _ = Node::v2_api_force_execute_cron_task(db_clone, cron_manager_clone, bearer, cron_task_id, res).await;
let _ =
Node::v2_api_force_execute_cron_task(db_clone, cron_manager_clone, bearer, cron_task_id, res)
.await;
});
}
NodeCommand::V2ApiGetCronSchedule { bearer, res } => {
Expand Down Expand Up @@ -2774,7 +2772,12 @@ impl Node {
let _ = Node::v2_export_messages_from_inbox(db_clone, bearer, inbox_name, format, res).await;
});
}
NodeCommand::V2ApiSearchShinkaiTool { bearer, query, agent_or_llm, res } => {
NodeCommand::V2ApiSearchShinkaiTool {
bearer,
query,
agent_or_llm,
res,
} => {
let db_clone = Arc::clone(&self.db);
tokio::spawn(async move {
let _ = Node::v2_api_search_shinkai_tool(db_clone, bearer, query, agent_or_llm, res).await;
Expand Down Expand Up @@ -2877,6 +2880,83 @@ impl Node {
.await;
});
}

NodeCommand::V2ApiUploadAppFile {
bearer,
tool_id,
app_id,
file_name,
file_data,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_upload_app_file(
db_clone, bearer, tool_id, app_id, file_name, file_data, node_env, res,
)
.await;
});
}
NodeCommand::V2ApiGetAppFile {
bearer,
tool_id,
app_id,
file_name,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ =
Node::v2_api_get_app_file(db_clone, bearer, tool_id, app_id, file_name, node_env, res).await;
});
}
NodeCommand::V2ApiUpdateAppFile {
bearer,
tool_id,
app_id,
file_name,
new_name,
file_data,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_update_app_file(
db_clone, bearer, tool_id, app_id, file_name, new_name, file_data, node_env, res,
)
.await;
});
}
NodeCommand::V2ApiListAppFiles {
bearer,
tool_id,
app_id,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ = Node::v2_api_list_app_files(db_clone, bearer, tool_id, app_id, node_env, res).await;
});
}
NodeCommand::V2ApiDeleteAppFile {
bearer,
tool_id,
app_id,
file_name,
res,
} => {
let db_clone = Arc::clone(&self.db);
let node_env = fetch_node_environment();
tokio::spawn(async move {
let _ =
Node::v2_api_delete_app_file(db_clone, bearer, tool_id, app_id, file_name, node_env, res).await;
});
}

_ => (),
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use crate::network::node_error::NodeError;
use crate::network::Node;
use crate::utils::environment::NodeEnvironment;

use async_channel::Sender;
use reqwest::StatusCode;
use serde_json::Value;
use shinkai_sqlite::SqliteManager;

use std::path::PathBuf;
use std::sync::Arc;

use shinkai_http_api::node_api_router::APIError;

fn get_app_folder_path(node_env: NodeEnvironment, app_id: String) -> PathBuf {
let mut origin_path: PathBuf = PathBuf::from(node_env.node_storage_path.clone().unwrap_or_default());
origin_path.push("app_files");
origin_path.push(app_id);
origin_path
}

impl Node {
pub async fn v2_api_upload_app_file(
db: Arc<SqliteManager>,
bearer: String,
tool_id: String,
app_id: String,
file_name: String,
file_data: Vec<u8>,
node_env: NodeEnvironment,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}
let app_folder_path = get_app_folder_path(node_env, app_id);
if !app_folder_path.exists() {
let result = std::fs::create_dir_all(&app_folder_path);
if let Err(err) = result {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to create directory: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
}
let file_path = app_folder_path.join(file_name.clone());
if let Err(err) = std::fs::write(&file_path, file_data) {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to write file: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
let _ = res.send(Ok(Value::String(file_name))).await;
Ok(())
}

pub async fn v2_api_get_app_file(
db: Arc<SqliteManager>,
bearer: String,
tool_id: String,
app_id: String,
file_name: String,
node_env: NodeEnvironment,
res: Sender<Result<Vec<u8>, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}
let app_folder_path = get_app_folder_path(node_env, app_id);
let file_path = app_folder_path.join(file_name.clone());
if !file_path.exists() {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("File {} not found", file_name),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
let file_bytes = std::fs::read(&file_path)?;
let _ = res.send(Ok(file_bytes)).await;
Ok(())
}

pub async fn v2_api_update_app_file(
db: Arc<SqliteManager>,
bearer: String,
tool_id: String,
app_id: String,
file_name: String,
new_name: Option<String>,
file_data: Option<Vec<u8>>,
node_env: NodeEnvironment,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}
let app_folder_path = get_app_folder_path(node_env, app_id);
let file_path = app_folder_path.join(file_name.clone());
if !file_path.exists() {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("File {} not found", file_name),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}

if let Some(file_data) = file_data {
std::fs::write(&file_path, file_data)?;
}

if let Some(new_name) = new_name.clone() {
let new_file_path = app_folder_path.join(new_name.clone());
std::fs::rename(&file_path, &new_file_path)?;
}

let _ = res.send(Ok(Value::String(new_name.unwrap_or_default()))).await;
Ok(())
}

pub async fn v2_api_list_app_files(
db: Arc<SqliteManager>,
bearer: String,
tool_id: String,
app_id: String,
node_env: NodeEnvironment,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}
let app_folder_path = get_app_folder_path(node_env, app_id);
let files = std::fs::read_dir(&app_folder_path);
if let Err(err) = files {
let api_error = APIError {
code: StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
error: "Internal Server Error".to_string(),
message: format!("Failed to read directory: {}", err),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
let mut file_list = Vec::new();
let files = files.unwrap();
for file in files {
if let Ok(file) = file {
let file_name = file.file_name().to_string_lossy().to_string();
file_list.push(Value::String(file_name));
}
}
let _ = res.send(Ok(Value::Array(file_list))).await;
Ok(())
}

pub async fn v2_api_delete_app_file(
db: Arc<SqliteManager>,
bearer: String,
tool_id: String,
app_id: String,
file_name: String,
node_env: NodeEnvironment,
res: Sender<Result<Value, APIError>>,
) -> Result<(), NodeError> {
// Validate the bearer token
if Self::validate_bearer_token(&bearer, db.clone(), &res).await.is_err() {
return Ok(());
}
let app_folder_path = get_app_folder_path(node_env, app_id);
let file_path = app_folder_path.join(file_name.clone());
if !file_path.exists() {
let api_error = APIError {
code: StatusCode::NOT_FOUND.as_u16(),
error: "Not Found".to_string(),
message: format!("File {} not found", file_name),
};
let _ = res.send(Err(api_error)).await;
return Ok(());
}
std::fs::remove_file(&file_path)?;
let _ = res.send(Ok(Value::String(file_name))).await;
Ok(())
}
}
1 change: 1 addition & 0 deletions shinkai-bin/shinkai-node/src/network/v2_api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod api_v2_commands;
pub mod api_v2_commands_app_files;
pub mod api_v2_commands_cron;
pub mod api_v2_commands_ext_agent_offers;
pub mod api_v2_commands_jobs;
Expand Down
5 changes: 3 additions & 2 deletions shinkai-bin/shinkai-node/tests/it/utils/node_test_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,10 +683,10 @@ pub async fn api_execute_tool(
app_id: String,
llm_provider: String,
extra_config: Map<String, Value>,
oauth: Map<String, Value>,
_oauth: Map<String, Value>,
) -> Result<Value, APIError> {
let (res_sender, res_receiver) = async_channel::bounded(1);

let mounts = None;
node_commands_sender
.send(NodeCommand::V2ApiExecuteTool {
bearer,
Expand All @@ -696,6 +696,7 @@ pub async fn api_execute_tool(
app_id,
llm_provider,
extra_config,
mounts,
res: res_sender,
})
.await
Expand Down
Loading

0 comments on commit 5c32ef5

Please sign in to comment.