Skip to content

Commit

Permalink
cp
Browse files Browse the repository at this point in the history
  • Loading branch information
nicarq committed Jan 16, 2024
1 parent 3f2e59c commit 6b83402
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 73 deletions.
2 changes: 1 addition & 1 deletion src/cron_tasks/cron_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ impl CronManager {
loop {
let jobs_to_process: HashMap<String, Vec<(String, CronTask)>> = {
let mut db_lock = db.lock().await;
db_lock.get_all_cron_tasks_from_all_profiles().unwrap_or(HashMap::new())
db_lock.get_all_cron_tasks_from_all_profiles(node_profile_name.clone()).unwrap_or(HashMap::new())
};
if !jobs_to_process.is_empty() {
shinkai_log(
Expand Down
174 changes: 103 additions & 71 deletions src/db/db_cron_task.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::{cmp::Ordering, collections::{HashMap, HashSet}};
use std::{
cmp::Ordering,
collections::{HashMap, HashSet},
};

use super::{db_errors::ShinkaiDBError, ShinkaiDB, Topic};
use super::{db::ProfileBoundWriteBatch, db_errors::ShinkaiDBError, ShinkaiDB, Topic};
use chrono::Utc;
use rocksdb::{IteratorMode, Options};
use serde::{Deserialize, Serialize};
use shinkai_message_primitives::schemas::shinkai_name::ShinkaiName;
use shinkai_message_primitives::{schemas::shinkai_name::ShinkaiName, shinkai_utils::shinkai_logging::{shinkai_log, ShinkaiLogOption, ShinkaiLogLevel}};

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CronTask {
Expand Down Expand Up @@ -42,15 +45,17 @@ impl ShinkaiDB {
crawl_links: bool,
agent_id: String,
) -> Result<(), ShinkaiDBError> {
let profile = profile.get_profile_name().ok_or(ShinkaiDBError::InvalidProfileName("Invalid profile name".to_string()))?;

let cf_name_schedule = format!("{}_cron_task_schedule", profile);
let cf_name_prompt = format!("{}_cron_task_prompt", profile);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile);
let cf_name_url = format!("{}_cron_task_url", profile);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile);
let cf_name_created_at = format!("{}_cron_task_created_at", profile);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile);
let profile_name = profile
.get_profile_name()
.ok_or(ShinkaiDBError::InvalidProfileName("Invalid profile name".to_string()))?;

let cf_name_schedule = format!("{}_cron_task_schedule", profile_name);
let cf_name_prompt = format!("{}_cron_task_prompt", profile_name);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile_name);
let cf_name_url = format!("{}_cron_task_url", profile_name);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile_name);
let cf_name_created_at = format!("{}_cron_task_created_at", profile_name);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile_name);

let mut cf_opts = Options::default();
cf_opts.create_if_missing(true);
Expand Down Expand Up @@ -142,30 +147,36 @@ impl ShinkaiDB {

let cf_cron_queues = self.get_cf_handle(Topic::CronQueues)?;

let mut batch = rocksdb::WriteBatch::default();
batch.put_cf(cf_schedule, &task_id, &cron);
batch.put_cf(cf_prompt, &task_id, &prompt);
batch.put_cf(cf_subprompt, &task_id, &subprompt);
batch.put_cf(cf_url, &task_id, &url);
batch.put_cf(cf_crawl_links, &task_id, &crawl_links.to_string());
let mut pb_batch = ProfileBoundWriteBatch::new(&profile)?;

pb_batch.put_cf_pb(cf_schedule, &task_id, &cron);
pb_batch.put_cf_pb(cf_prompt, &task_id, &prompt);
pb_batch.put_cf_pb(cf_subprompt, &task_id, &subprompt);
pb_batch.put_cf_pb(cf_url, &task_id, &url);
pb_batch.put_cf_pb(cf_crawl_links, &task_id, &crawl_links.to_string());

let created_at = Utc::now().to_rfc3339();
batch.put_cf(cf_created_at, &task_id, &created_at);
batch.put_cf(cf_agent_id, &task_id, &agent_id);
batch.put_cf(cf_cron_queues, &task_id, &profile);
pb_batch.put_cf_pb(cf_created_at, &task_id, &created_at);
pb_batch.put_cf_pb(cf_agent_id, &task_id, &agent_id);
pb_batch.put_cf_pb(cf_cron_queues, &task_id, &profile_name);

self.write_pb(pb_batch)?;

self.db.write(batch)?;
Ok(())
}

pub fn remove_cron_task(&mut self, profile: String, task_id: String) -> Result<(), ShinkaiDBError> {
let cf_name_schedule = format!("{}_cron_task_schedule", profile);
let cf_name_prompt = format!("{}_cron_task_prompt", profile);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile);
let cf_name_url = format!("{}_cron_task_url", profile);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile);
let cf_name_created_at = format!("{}_cron_task_created_at", profile);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile);
pub fn remove_cron_task(&mut self, profile: ShinkaiName, task_id: String) -> Result<(), ShinkaiDBError> {
let profile_name = profile
.get_profile_name()
.ok_or(ShinkaiDBError::InvalidProfileName("Invalid profile name".to_string()))?;

let cf_name_schedule = format!("{}_cron_task_schedule", profile_name);
let cf_name_prompt = format!("{}_cron_task_prompt", profile_name);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile_name);
let cf_name_url = format!("{}_cron_task_url", profile_name);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile_name);
let cf_name_created_at = format!("{}_cron_task_created_at", profile_name);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile_name);
let cf_cron_queues = self.get_cf_handle(Topic::CronQueues)?;

let cf_schedule = self
Expand Down Expand Up @@ -224,109 +235,126 @@ impl ShinkaiDB {
task_id
)))?;

let mut batch = rocksdb::WriteBatch::default();
batch.delete_cf(cf_schedule, &task_id);
batch.delete_cf(cf_prompt, &task_id);
batch.delete_cf(cf_subprompt, &task_id);
batch.delete_cf(cf_url, &task_id);
batch.delete_cf(cf_crawl_links, &task_id);
batch.delete_cf(cf_created_at, &task_id);
batch.delete_cf(cf_agent_id, &task_id);
batch.delete_cf(cf_cron_queues, &task_id);

self.db.write(batch)?;
let mut pb_batch = ProfileBoundWriteBatch::new(&profile)?;
pb_batch.delete_cf_pb(cf_schedule, &task_id);
pb_batch.delete_cf_pb(cf_prompt, &task_id);
pb_batch.delete_cf_pb(cf_subprompt, &task_id);
pb_batch.delete_cf_pb(cf_url, &task_id);
pb_batch.delete_cf_pb(cf_crawl_links, &task_id);
pb_batch.delete_cf_pb(cf_created_at, &task_id);
pb_batch.delete_cf_pb(cf_agent_id, &task_id);
pb_batch.delete_cf_pb(cf_cron_queues, &task_id);

self.write_pb(pb_batch)?;
Ok(())
}

pub fn get_all_cron_tasks_from_all_profiles(&self) -> Result<HashMap<String, Vec<(String, CronTask)>>, ShinkaiDBError> {
pub fn get_all_cron_tasks_from_all_profiles(
&self,
node_name: ShinkaiName,
) -> Result<HashMap<String, Vec<(String, CronTask)>>, ShinkaiDBError> {
let cf_cron_queues = self.get_cf_handle(Topic::CronQueues)?;

let mut all_profiles = HashSet::new();
for result in self.db.iterator_cf(cf_cron_queues, IteratorMode::Start) {
match result {
Ok((_, value)) => {
let profile = String::from_utf8(value.to_vec()).unwrap();
eprintln!("get_all_cron_tasks_from_all_profiles profile: {:?}", profile);
shinkai_log(
ShinkaiLogOption::CronExecution,
ShinkaiLogLevel::Debug,
format!("get_all_cron_tasks_from_all_profiles profile: {:?}", profile).as_str(),
);
all_profiles.insert(profile);
}
Err(e) => return Err(e.into()),
}
}

let mut all_tasks = HashMap::new();
for profile in all_profiles.clone() {
let tasks = self.get_all_cron_tasks_for_profile(profile.clone())?;
let shinkai_profile = ShinkaiName::from_node_and_profile(node_name.get_node_name(), profile.clone())?;
let tasks = self.get_all_cron_tasks_for_profile(shinkai_profile)?;
for (task_id, task) in tasks {
all_tasks.entry(profile.clone()).or_insert_with(Vec::new).push((task_id, task));
all_tasks
.entry(profile.clone())
.or_insert_with(Vec::new)
.push((task_id, task));
}
}

Ok(all_tasks)
}

pub fn get_all_cron_tasks_for_profile(&self, profile: String) -> Result<HashMap<String, CronTask>, ShinkaiDBError> {
let cf_name_schedule = format!("{}_cron_task_schedule", profile);
let cf_name_prompt = format!("{}_cron_task_prompt", profile);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile);
let cf_name_url = format!("{}_cron_task_url", profile);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile);
let cf_name_created_at = format!("{}_cron_task_created_at", profile);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile);
pub fn get_all_cron_tasks_for_profile(
&self,
profile: ShinkaiName,
) -> Result<HashMap<String, CronTask>, ShinkaiDBError> {
let profile_name = profile
.get_profile_name()
.ok_or(ShinkaiDBError::InvalidProfileName("Invalid profile name".to_string()))?;
let cf_name_schedule = format!("{}_cron_task_schedule", profile_name);
let cf_name_prompt = format!("{}_cron_task_prompt", profile_name);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile_name);
let cf_name_url = format!("{}_cron_task_url", profile_name);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile_name);
let cf_name_created_at = format!("{}_cron_task_created_at", profile_name);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile_name);

let cf_schedule = self
.db
.cf_handle(&cf_name_schedule)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (name_schedule) not found for profile: {}",
profile
profile_name
)))?;

let cf_prompt = self
.db
.cf_handle(&cf_name_prompt)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (name_prompt) not found for profile: {}",
profile
profile_name
)))?;

let cf_subprompt = self
.db
.cf_handle(&cf_name_subprompt)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (name_prompt) not found for profile: {}",
profile
profile_name
)))?;

let cf_url = self
.db
.cf_handle(&cf_name_url)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (name_url) not found for profile: {}",
profile
profile_name
)))?;

let cf_crawl_links = self
.db
.cf_handle(&cf_name_crawl_links)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (crawl_links) not found for profile: {}",
profile
profile_name
)))?;

let cf_created_at = self
.db
.cf_handle(&cf_name_created_at)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (created_at) not found for profile: {}",
profile
profile_name
)))?;

let cf_agent_id = self
.db
.cf_handle(&cf_name_agent_id)
.ok_or(ShinkaiDBError::CronTaskNotFound(format!(
"Cron tasks (agent_id) not found for profile: {}",
profile
profile_name
)))?;

let mut tasks = HashMap::new();
Expand Down Expand Up @@ -369,14 +397,18 @@ impl ShinkaiDB {
Ok(tasks)
}

pub fn get_cron_task(&self, profile: String, task_id: String) -> Result<CronTask, ShinkaiDBError> {
let cf_name_schedule = format!("{}_cron_task_schedule", profile);
let cf_name_prompt = format!("{}_cron_task_prompt", profile);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile);
let cf_name_url = format!("{}_cron_task_url", profile);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile);
let cf_name_created_at = format!("{}_cron_task_created_at", profile);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile);
pub fn get_cron_task(&self, profile: ShinkaiName, task_id: String) -> Result<CronTask, ShinkaiDBError> {
let profile_name = profile
.get_profile_name()
.ok_or(ShinkaiDBError::InvalidProfileName("Invalid profile name".to_string()))?;

let cf_name_schedule = format!("{}_cron_task_schedule", profile_name);
let cf_name_prompt = format!("{}_cron_task_prompt", profile_name);
let cf_name_subprompt = format!("{}_cron_task_subprompt", profile_name);
let cf_name_url = format!("{}_cron_task_url", profile_name);
let cf_name_crawl_links = format!("{}_cron_task_crawl_links", profile_name);
let cf_name_created_at = format!("{}_cron_task_created_at", profile_name);
let cf_name_agent_id = format!("{}_cron_task_agent_id", profile_name);

let cf_schedule = self
.db
Expand Down
2 changes: 1 addition & 1 deletion src/network/node_devops_api_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use reqwest::StatusCode;
impl Node {
pub async fn api_private_devops_cron_list(&self, res: Sender<Result<String, APIError>>) -> Result<(), NodeError> {
// Call the get_all_cron_tasks_from_all_profiles function
match self.db.lock().await.get_all_cron_tasks_from_all_profiles() {
match self.db.lock().await.get_all_cron_tasks_from_all_profiles(self.node_profile_name.clone()) {
Ok(tasks) => {
eprintln!("Got {} cron tasks", tasks.len());
// If everything went well, send the tasks back as a JSON string
Expand Down

0 comments on commit 6b83402

Please sign in to comment.