Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin' into feature/mounts
Browse files Browse the repository at this point in the history
  • Loading branch information
acedward committed Dec 22, 2024
2 parents 8d89a06 + 096ad6f commit 20ff2f6
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 411 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ chrono = "0.4"
serde_json = "1.0.117"
anyhow = "1.0"
blake3 = "1.2.0"
shinkai_tools_runner = "0.9.5"
shinkai_tools_runner = "0.9.6"
serde = "1.0.188"
base64 = "0.22.0"
reqwest = "0.11.27"
Expand Down
51 changes: 35 additions & 16 deletions shinkai-bin/shinkai-node/src/cron_tasks/cron_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
sync::{Arc, Weak},
};

use chrono::{Local, TimeZone, Utc};
use chrono::{Local, Utc};
use ed25519_dalek::SigningKey;
use futures::Future;
use shinkai_message_primitives::{
Expand Down Expand Up @@ -326,10 +326,16 @@ impl CronManager {
job_creation_info,
llm_provider,
} => {
// Clone job_creation_info and set is_hidden to true if not defined
let mut job_creation_info_clone = job_creation_info.clone();
if job_creation_info_clone.is_hidden.is_none() {
job_creation_info_clone.is_hidden = Some(true);
}

let job_id = job_manager
.lock()
.await
.process_job_creation(job_creation_info, &shinkai_profile, &llm_provider)
.process_job_creation(job_creation_info_clone, &shinkai_profile, &llm_provider)
.await?;

// Update the job configuration
Expand Down Expand Up @@ -411,10 +417,10 @@ impl CronManager {
result
}

async fn log_success_to_sqlite(db: &Arc<SqliteManager>, task_id: i64) {
async fn log_success_to_sqlite(db: &Arc<SqliteManager>, task_id: i64, job_id: Option<String>) {
let execution_time = Local::now().to_rfc3339();
let db = db;
if let Err(err) = db.add_cron_task_execution(task_id, &execution_time, true, None) {
if let Err(err) = db.add_cron_task_execution(task_id, &execution_time, true, None, job_id) {
eprintln!("Failed to log success to SQLite: {}", err);
}
}
Expand All @@ -434,11 +440,12 @@ impl CronManager {
let bearer = match db.read_api_v2_key() {
Ok(Some(token)) => token,
Ok(None) => {
Self::log_error_to_sqlite(&db, task_id, "Bearer token not found").await;
Self::log_error_to_sqlite(&db, task_id, "Bearer token not found", None).await;
return Ok(());
}
Err(err) => {
Self::log_error_to_sqlite(&db, task_id, &format!("Failed to retrieve bearer token: {}", err)).await;
Self::log_error_to_sqlite(&db, task_id, &format!("Failed to retrieve bearer token: {}", err), None)
.await;
return Ok(());
}
};
Expand All @@ -453,33 +460,45 @@ impl CronManager {
identity_manager_clone,
job_manager_clone,
bearer,
job_message_clone,
job_message_clone.clone(),
encryption_secret_key_clone,
encryption_public_key_clone,
signing_secret_key_clone,
res_tx,
)
.await
{
Self::log_error_to_sqlite(&db, task_id, &format!("Failed to send job message: {}", err)).await;
Self::log_error_to_sqlite(
&db,
task_id,
&format!("Failed to send job message: {}", err),
Some(job_message_clone.job_id),
)
.await;
return Ok(());
}

// Handle the response only if sending was successful
if let Err(err) = res_rx.recv().await {
Self::log_error_to_sqlite(&db, task_id, &format!("Failed to receive response: {}", err)).await;
Self::log_error_to_sqlite(
&db,
task_id,
&format!("Failed to receive response: {}", err),
Some(job_message_clone.job_id),
)
.await;
} else {
// Log success if the response is received successfully
Self::log_success_to_sqlite(&db, task_id).await;
Self::log_success_to_sqlite(&db, task_id, Some(job_message_clone.job_id)).await;
}

Ok(())
}

async fn log_error_to_sqlite(db: &Arc<SqliteManager>, task_id: i64, error_message: &str) {
async fn log_error_to_sqlite(db: &Arc<SqliteManager>, task_id: i64, error_message: &str, job_id: Option<String>) {
let execution_time = Local::now().to_rfc3339();
let db = db;
if let Err(err) = db.add_cron_task_execution(task_id, &execution_time, false, Some(error_message)) {
if let Err(err) = db.add_cron_task_execution(task_id, &execution_time, false, Some(error_message), job_id) {
eprintln!("Failed to log error to SQLite: {}", err);
}
}
Expand Down Expand Up @@ -631,7 +650,7 @@ mod tests {

#[test]
fn test_should_execute_specific_minute() {
let now = Utc::now();
let now = Local::now();
let next_minute = (now.minute() + 1) % 60;
let cron = format!("{} * * * *", next_minute);
let task = create_test_cron_task(&cron);
Expand All @@ -642,7 +661,7 @@ mod tests {

#[test]
fn test_should_not_execute_past_time() {
let now = Utc::now();
let now = Local::now();
let past_minute = if now.minute() == 0 { 59 } else { now.minute() - 1 };
let cron = format!("{} * * * *", past_minute);
let task = create_test_cron_task(&cron);
Expand All @@ -667,7 +686,7 @@ mod tests {

#[test]
fn test_should_execute_within_interval() {
let now = Utc::now();
let now = Local::now();
let next_minute = (now.minute() + 1) % 60;

// Create a cron expression for the next minute, any hour/day/month
Expand All @@ -688,7 +707,7 @@ mod tests {

#[test]
fn test_should_not_execute_outside_interval() {
let now = Utc::now();
let now = Local::now();
let future_minute = (now.minute() + 2) % 60;
let cron = format!("{} * * * *", future_minute);
let task = create_test_cron_task(&cron);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl LLMService for Claude {

// Print payload as a pretty JSON string
match serde_json::to_string_pretty(&payload) {
Ok(pretty_json) => eprintln!("Payload: {}", pretty_json),
Ok(pretty_json) => eprintln!("cURL Payload: {}", pretty_json),
Err(e) => eprintln!("Failed to serialize payload: {:?}", e),
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl LLMService for Gemini {
|| std::env::var("LOG_ALL").unwrap_or_default() == "1"
{
match serde_json::to_string_pretty(&payload) {
Ok(pretty_json) => eprintln!("Payload: {}", pretty_json),
Ok(pretty_json) => eprintln!("cURL Payload: {}", pretty_json),
Err(e) => eprintln!("Failed to serialize payload: {:?}", e),
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ impl LLMService for OpenAI {

// Print payload as a pretty JSON string
match serde_json::to_string_pretty(&payload) {
Ok(pretty_json) => eprintln!("Payload: {}", pretty_json),
Ok(pretty_json) => eprintln!("cURL Payload: {}", pretty_json),
Err(e) => eprintln!("Failed to serialize payload: {:?}", e),
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl LLMService for OpenRouter {

// Print payload as a pretty JSON string
match serde_json::to_string_pretty(&payload) {
Ok(pretty_json) => eprintln!("Payload: {}", pretty_json),
Ok(pretty_json) => eprintln!("cURL Payload: {}", pretty_json),
Err(e) => eprintln!("Failed to serialize payload: {:?}", e),
};

Expand Down
40 changes: 40 additions & 0 deletions shinkai-bin/shinkai-node/src/managers/tool_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,40 @@ impl ToolRouter {
Ok(())
}

async fn import_tools_from_directory(db: Arc<SqliteManager>) -> Result<(), ToolError> {
let url = env::var("SHINKAI_TOOLS_DIRECTORY_URL")
.map_err(|_| ToolError::MissingConfigError("SHINKAI_TOOLS_DIRECTORY_URL not set".to_string()))?;

let response = reqwest::get(url).await.map_err(|e| ToolError::RequestError(e))?;

if response.status() != 200 {
return Err(ToolError::ExecutionError(format!(
"Import tools request returned a non OK status: {}",
response.status()
)));
}

let tools: Vec<serde_json::Value> = response
.json()
.await
.map_err(|e| ToolError::ParseError(format!("Failed to parse tools directory: {}", e)))?;

for tool in tools {
let tool_url = tool["file"]
.as_str()
.ok_or_else(|| ToolError::ParseError("Missing or invalid file URL in tool definition".to_string()))?;

let tool_name = tool["name"].as_str().unwrap_or("unknown");

match Node::v2_api_import_tool_internal(db.clone(), fetch_node_environment(), tool_url.to_string()).await {
Ok(_) => println!("Successfully imported tool {}", tool_name),
Err(e) => eprintln!("Failed to import tool {}: {:#?}", tool_name, e),
}
}

Ok(())
}

pub async fn add_static_prompts(&self, _generator: &Box<dyn EmbeddingGenerator>) -> Result<(), ToolError> {
// Check if ONLY_TESTING_PROMPTS is set
if env::var("ONLY_TESTING_PROMPTS").unwrap_or_default() == "1"
Expand Down Expand Up @@ -219,9 +253,15 @@ impl ToolRouter {

{
for (name, definition) in tools {
// Skip tools that start with "demo" if not only_testing_js_tools
if !only_testing_js_tools && name.starts_with("demo") {
continue;
}
// Skip tools that are not in the allowed list if only_testing_js_tools is true
if only_testing_js_tools && !allowed_tools.contains(&name.as_str()) {
continue; // Skip tools that are not in the allowed list
}

println!("Adding JS tool: {}", name);

let toolkit = JSToolkit::new(&name, vec![definition.clone()]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ impl Node {
.into_iter()
.map(|log| {
json!({
"job_id": log.3.as_ref().map_or("", |j| j),
"task_id": task_id.to_string(),
"execution_time": log.0,
"success": log.1,
Expand Down
Loading

0 comments on commit 20ff2f6

Please sign in to comment.