diff --git a/Cargo.toml b/Cargo.toml index 309ed31f..3008c9f4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ members = [ ] [workspace.package] -version = "0.13.0" +version = "0.14.0" edition = "2021" [workspace.dependencies] diff --git a/dfx.json b/dfx.json index 90916411..658bb22b 100644 --- a/dfx.json +++ b/dfx.json @@ -31,9 +31,9 @@ "type": "custom" }, "log_canister": { - "build": "bash scripts/build_log_canister.sh", + "build": "", "candid": "target/wasm32-unknown-unknown/release/log_canister.did", - "wasm": "target/wasm32-unknown-unknown/release/examples/log_canister.wasm", + "wasm": "target/wasm32-unknown-unknown/release/log_canister.wasm", "type": "custom" }, "dummy_scheduler_canister": { diff --git a/ic-task-scheduler/Cargo.toml b/ic-task-scheduler/Cargo.toml index 4298e4a3..99884916 100644 --- a/ic-task-scheduler/Cargo.toml +++ b/ic-task-scheduler/Cargo.toml @@ -5,8 +5,11 @@ edition.workspace = true [dependencies] bincode = { workspace = true } +candid = { workspace = true } +ic-cdk-timers = { workspace = true } ic-kit = { path = "../ic-kit" } ic-stable-structures = { path = "../ic-stable-structures" } +log = { workspace = true } parking_lot = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } @@ -14,6 +17,7 @@ thiserror = { workspace = true } [dev-dependencies] anyhow = { workspace = true } candid = { workspace = true } +ic-canister-client = { path = "../ic-canister-client", features = ["pocket-ic-client"]} ic-exports = { path = "../ic-exports", features = ["pocket-ic-tests-async"] } once_cell = { workspace = true } rand = { workspace = true } diff --git a/ic-task-scheduler/src/error.rs b/ic-task-scheduler/src/error.rs index 47608544..72af32df 100644 --- a/ic-task-scheduler/src/error.rs +++ b/ic-task-scheduler/src/error.rs @@ -1,9 +1,12 @@ +use candid::CandidType; +use serde::{Deserialize, Serialize}; use thiserror::Error; -#[derive(Debug, Error)] +#[derive(CandidType, Debug, Error, PartialEq, Eq, Serialize, 
Deserialize)] pub enum SchedulerError { #[error("TaskExecutionFailed: {0}")] TaskExecutionFailed(String), - #[error("storage error: {0}")] - StorageError(#[from] ic_stable_structures::Error), } + +/// Result type for the scheduler +pub type Result = std::result::Result; diff --git a/ic-task-scheduler/src/lib.rs b/ic-task-scheduler/src/lib.rs index 8bb55c75..15565b92 100644 --- a/ic-task-scheduler/src/lib.rs +++ b/ic-task-scheduler/src/lib.rs @@ -4,6 +4,4 @@ pub mod scheduler; pub mod task; mod time; -pub use error::SchedulerError; -/// Result type for the scheduler -pub type Result = std::result::Result; +pub use error::{Result, SchedulerError}; diff --git a/ic-task-scheduler/src/retry.rs b/ic-task-scheduler/src/retry.rs index 0f022d5e..09e00b01 100644 --- a/ic-task-scheduler/src/retry.rs +++ b/ic-task-scheduler/src/retry.rs @@ -1,10 +1,11 @@ use core::fmt::Debug; +use candid::CandidType; use serde::{Deserialize, Serialize}; /// Defines the strategy to apply in case of a failure. /// This is applied, for example, when a task execution fails -#[derive(Debug, Deserialize, Serialize, PartialEq, Eq, Clone)] +#[derive(CandidType, Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct RetryStrategy { pub retry_policy: RetryPolicy, pub backoff_policy: BackoffPolicy, @@ -30,7 +31,7 @@ impl RetryStrategy { } // Defines the retry policy of a RetryStrategy -#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)] +#[derive(CandidType, Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub enum RetryPolicy { /// No Retry attempts defined None, @@ -55,7 +56,7 @@ impl RetryPolicy { } // Defines the backoff policy of a RetryStrategy -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] +#[derive(CandidType, Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub enum BackoffPolicy { /// No backoff, the retry will be attempted without waiting None, diff --git a/ic-task-scheduler/src/scheduler.rs b/ic-task-scheduler/src/scheduler.rs index 
18b97bb3..3a7b039c 100644 --- a/ic-task-scheduler/src/scheduler.rs +++ b/ic-task-scheduler/src/scheduler.rs @@ -1,360 +1,286 @@ -use std::collections::HashSet; -use std::future::Future; -use std::pin::Pin; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use ic_stable_structures::IterableUnboundedMapStructure; +use log::{debug, warn}; use parking_lot::Mutex; -use crate::task::{ScheduledTask, Task}; +use crate::task::{InnerScheduledTask, ScheduledTask, Task, TaskStatus}; use crate::time::time_secs; -use crate::{Result, SchedulerError}; - -/// Internal type used to store tasks to be processed and processing tasks -type TaskQueue = Arc>>; - -/// The state of a task execution. -/// This is reported when the `SaveStateQueryCallback` is called. -#[derive(Debug)] -pub enum TaskExecutionState { - /// Reported when tasks to be executed are scheduled - Scheduled, - /// Reported when a task fails. - Failed(u32, SchedulerError), - /// Reported when a task starts executing. - Executing(u32), - /// Reported when a task completes successfully. - Completed(u32), - /// Reported when a task panics. - Panicked(u32), -} +use crate::SchedulerError; + +type TaskCompletionCallback = Box) + Send>; -type OnStateChangeCallback = - dyn Fn(TaskExecutionState) -> Pin>>> + Send + Sync; +const DEFAULT_RUNNING_TASK_TIMEOUT_SECS: u64 = 120; /// A scheduler is responsible for executing tasks. -pub struct Scheduler -where +pub struct Scheduler< T: 'static + Task, - P: 'static + IterableUnboundedMapStructure>, -{ + P: 'static + IterableUnboundedMapStructure>, +> { pending_tasks: Arc>, phantom: std::marker::PhantomData, - /// Queue containing tasks to be processed - tasks_to_be_processed: TaskQueue, - /// Tasks which are currently being processed - tasks_running: TaskQueue, - /// Callback to be called to save the current canister state to prevent panicking tasks. 
- on_execution_state_changed_callback: Option>>, + on_completion_callback: Arc>>, + running_task_timeout_secs: AtomicU64, } -impl Scheduler -where - T: 'static + Task, - P: 'static + IterableUnboundedMapStructure>, +impl>> + Scheduler { /// Create a new scheduler. - /// - /// A callback `on_execution_state_changed_callback` is called every time the state of a task is changed. - /// By performing an inter-canister call in the callback, you can force the state to be persisted even in case of - /// panics. This allows the scheduler to deal with panicking tasks. - pub fn new( - pending_tasks: P, - on_execution_state_changed_callback: Option>, - ) -> Result { - Ok(Self { + pub fn new(pending_tasks: P) -> Self { + Self { pending_tasks: Arc::new(Mutex::new(pending_tasks)), phantom: std::marker::PhantomData, - on_execution_state_changed_callback: on_execution_state_changed_callback.map(Arc::new), - tasks_to_be_processed: Arc::new(Mutex::new(HashSet::new())), - tasks_running: Arc::new(Mutex::new(HashSet::new())), - }) + on_completion_callback: Arc::new(None), + running_task_timeout_secs: AtomicU64::new(DEFAULT_RUNNING_TASK_TIMEOUT_SECS), + } + } + + /// Set the timeout of a running task. If a task is running for more time the timeout, it will be + /// considered as stuck or panicked. + /// The default value is 120 seconds. + pub fn set_running_task_timeout(&mut self, timeout_secs: u64) { + debug!("Setting running task timeout to {} seconds", timeout_secs); + self.running_task_timeout_secs + .store(timeout_secs, Ordering::Relaxed); + } + + /// Set a callback to be called when a task execution completes. + pub fn on_completion_callback)>(&mut self, cb: F) { + self.on_completion_callback = Arc::new(Some(Box::new(cb))); } /// Execute all pending tasks. /// Each task is executed asynchronously in a dedicated ic_cdk::spawn call. /// This function does not wait for the tasks to complete. /// Returns the number of tasks that have been launched. 
- pub async fn run(&mut self) -> Result { - self.run_with_timestamp(time_secs()).await + pub fn run(&self) -> Result { + self.run_with_timestamp(time_secs()) } - async fn run_with_timestamp(&mut self, now_timestamp_secs: u64) -> Result { - let mut task_execution_started = 0; - - // checks tasks that are still in tasks running (this could indicate a panic happened in the previous run) - let tasks_running_count = self.tasks_running.lock().len(); - match tasks_running_count { - 0 => { - // HAPPY PATH: if there are no processing tasks, initialize the tasks o be processed - self.init_tasks_to_be_processed(now_timestamp_secs).await?; - } - 1 => { - // if there is only one processing task, we can assume that it panicked - // delete that task and mark it as panicked - let task = *self.tasks_running.lock().iter().next().unwrap(); - self.delete_unprocessable_task(task).await?; - - // eventually reschedule the tasks to be processed - self.init_tasks_to_be_processed(now_timestamp_secs).await?; - } - _ => { - // if there is more than one task to be processed, keep only half of them - self.split_processing_tasks().await?; - } - } + fn run_with_timestamp(&self, now_timestamp_secs: u64) -> Result { + debug!("Scheduler - Running tasks"); + let mut to_be_scheduled_tasks = Vec::new(); + let mut out_of_time_tasks = Vec::new(); + let running_task_timeout_secs = self.running_task_timeout_secs.load(Ordering::Relaxed); - // iterate over tasks to be processed, and execute it one by one - let tasks_to_be_processed: Vec = - self.tasks_to_be_processed.lock().iter().copied().collect(); - for task_id in tasks_to_be_processed { + { let lock = self.pending_tasks.lock(); - let mut task = match lock.get(&task_id) { - Some(task) => task, - None => continue, - }; - drop(lock); - - task_execution_started += 1; - let key = task_id; - let mut task_scheduler = self.clone(); - Self::spawn(async move { - // put task to processing tasks - task_scheduler - .put_task_to_processing_tasks(key) - .await - 
.unwrap(); - - // execute the task - if let Err(err) = task.task.execute(Box::new(task_scheduler.clone())).await { - task.options.failures += 1; - let (should_retry, retry_delay) = task - .options - .retry_strategy - .should_retry(task.options.failures); - if should_retry { - // remove task from processing task, but don't report state - task_scheduler.remove_task_from_processing_tasks(key); - - // re-add task to the queue - task.options.execute_after_timestamp_in_secs = - now_timestamp_secs + (retry_delay as u64); - task_scheduler.append_task(task.clone()) - } else { - // remove task from processing and port its failure - task_scheduler - .remove_failed_task_from_processing_tasks(key, err) - .await - .unwrap(); + for (task_key, task) in lock.iter() { + match task.status { + TaskStatus::Waiting { .. } => { + if task.options.execute_after_timestamp_in_secs <= now_timestamp_secs { + debug!("Scheduler - Task {} scheduled to be processed", task_key); + to_be_scheduled_tasks.push(task_key); + } } - } else { - // in case of success, remove task from queue and report success - task_scheduler - .remove_completed_task_from_processing_tasks(key) - .await - .unwrap(); + TaskStatus::Running { timestamp_secs } + | TaskStatus::Scheduled { timestamp_secs } => { + warn!( + "Scheduler - Task {} was in Scheduled or Running status for more than {} seconds, it could be stuck or panicked. Removing it from the scheduler.", + task_key, running_task_timeout_secs + ); + if timestamp_secs + running_task_timeout_secs < now_timestamp_secs { + out_of_time_tasks.push(task_key); + } + } + TaskStatus::Completed { .. } + | TaskStatus::TimeoutOrPanic { .. } + | TaskStatus::Failed { .. } => (), } - }); + } } - Ok(task_execution_started) - } - - // We use tokio for testing instead of ic_kit::ic::spawn because the latter blocks the current thread - // waiting for the spawned futures to complete. - // This makes impossible to test concurrent behavior. 
- #[cfg(test)] - fn spawn>(future: F) { - tokio::task::spawn_local(future); - } - - #[cfg(not(test))] - #[inline(always)] - fn spawn>(future: F) { - ic_kit::ic::spawn(future); - } + // Process the tasks that are ready to be scheduled + for task_key in to_be_scheduled_tasks.iter() { + self.process_pending_task(*task_key, now_timestamp_secs); + } - /// Copy `tasks_being_processed` into `tasks_to_be_processed`, - /// - /// then keep in the hashset only the first half of the tasks. - async fn split_processing_tasks(&mut self) -> Result<()> { - // move tasks_being_processed to tasks_to_be_processed + // Remove the tasks that are out of time { - let mut tasks_running_lock = self.tasks_running.lock(); - let mut tasks_to_be_processed_lock = self.tasks_running.lock(); - tasks_to_be_processed_lock.clear(); - for key in tasks_running_lock.iter() { - tasks_to_be_processed_lock.insert(*key); - } - // clear tasks_running - tasks_running_lock.clear(); - drop(tasks_running_lock); - - // remove second half - let total_tasks_half = tasks_to_be_processed_lock.len() / 2; - let second_half: Vec = tasks_to_be_processed_lock - .iter() - .enumerate() - .filter_map(|(i, task)| { - if i > total_tasks_half { - Some(*task) - } else { - None + let mut lock = self.pending_tasks.lock(); + for task_key in out_of_time_tasks.into_iter() { + if let Some(mut task) = lock.remove(&task_key) { + task.status = TaskStatus::timeout_or_panic(now_timestamp_secs); + if let Some(cb) = &*self.on_completion_callback { + cb(task); } - }) - .collect(); - - for task in second_half { - tasks_to_be_processed_lock.remove(&task); + } } - drop(tasks_to_be_processed_lock); } - // save state - self.report_state(TaskExecutionState::Scheduled).await + Ok(to_be_scheduled_tasks.len()) } - /// Initialize tasks to be processed using the current timestamp and checking against the tasks which must - /// executed in this slot. 
- /// - /// Then reset processing tasks to an empty set - async fn init_tasks_to_be_processed(&mut self, timestamp: u64) -> Result<()> { - { - // save tasks to be executed at this time - let tasks_to_be_executed: Vec = self - .pending_tasks - .lock() - .iter() - .filter_map(|(key, task)| { - if task.options.execute_after_timestamp_in_secs <= timestamp { - Some(key) - } else { - None - } - }) - .collect(); + fn process_pending_task(&self, task_key: u32, now_timestamp_secs: u64) { + let task_scheduler = self.clone(); - // clear and insert tasks - let mut tasks_to_be_processed_lock = self.tasks_to_be_processed.lock(); - tasks_to_be_processed_lock.clear(); - for task in tasks_to_be_executed { - tasks_to_be_processed_lock.insert(task); + // Set the task as scheduled + { + let mut lock = task_scheduler.pending_tasks.lock(); + let task = lock.get(&task_key); + if let Some(mut task) = task { + if let TaskStatus::Waiting { .. } = task.status { + debug!( + "Scheduler - Task {} status changed: Waiting -> Scheduled", + task_key + ); + task.status = TaskStatus::scheduled(now_timestamp_secs); + lock.insert(&task_key, &task); + } } - drop(tasks_to_be_processed_lock); - - self.tasks_running.lock().clear(); } - // save state - self.report_state(TaskExecutionState::Scheduled).await - } - - /// Remove a task from `to_be_processed` and move it into the `processing` set. 
Then save the current task - async fn put_task_to_processing_tasks(&mut self, task: u32) -> Result<()> { - let task_removed = self.tasks_to_be_processed.lock().remove(&task); - if task_removed { - self.tasks_running.lock().insert(task); - // save state - self.report_state(TaskExecutionState::Executing(task)).await - } else { - Ok(()) - } - } + Self::spawn(async move { + let now_timestamp_secs = time_secs(); - /// Remove a task from the tasks queue and save the state marking it as completed - async fn remove_completed_task_from_processing_tasks(&mut self, task: u32) -> Result<()> { - self.remove_task_from_processing_tasks(task); - // save state - self.report_state(TaskExecutionState::Completed(task)).await - } + let task = task_scheduler.pending_tasks.lock().get(&task_key); + if let Some(mut task) = task { + if let TaskStatus::Scheduled { .. } = task.status { + debug!( + "Scheduler - Task {} status changed: Scheduled -> Running", + task_key + ); + task.status = TaskStatus::running(now_timestamp_secs); + task_scheduler.pending_tasks.lock().insert(&task_key, &task); - /// Remove a failed task (FAILED! NOT PANICKED) from the tasks queue and save the state marking it as completed - async fn remove_failed_task_from_processing_tasks( - &mut self, - task: u32, - error: SchedulerError, - ) -> Result<()> { - self.remove_task_from_processing_tasks(task); - - // save state - self.report_state(TaskExecutionState::Failed(task, error)) - .await - } + let completed_task = match task + .task + .execute(Box::new(task_scheduler.clone())) + .await + { + Ok(()) => { + debug!("Scheduler - Task {} execution succeeded. 
Status changed: Running -> Completed", task_key); + let mut lock = task_scheduler.pending_tasks.lock(); + let mut task = lock.remove(&task_key).unwrap(); + task.status = TaskStatus::completed(now_timestamp_secs); + Some(task) + } + Err(err) => { + let mut lock = task_scheduler.pending_tasks.lock(); + task.options.failures += 1; + let (should_retry, retry_delay) = task + .options + .retry_strategy + .should_retry(task.options.failures); + + if should_retry { + debug!("Scheduler - Task {} execution failed. Execution will be retried. Status changed: Running -> Waiting", task_key); + task.options.execute_after_timestamp_in_secs = + now_timestamp_secs + (retry_delay as u64); + task.status = TaskStatus::waiting(now_timestamp_secs); + lock.insert(&task_key, &task); + None + } else { + debug!("Scheduler - Task {} execution failed. Status changed: Running -> Failed", task_key); + let mut task = lock.remove(&task_key).unwrap(); + task.status = TaskStatus::failed(now_timestamp_secs, err); + Some(task) + } + } + }; - /// Remove a task from the tasks queue - fn remove_task_from_processing_tasks(&mut self, task: u32) { - // delete the task from tasks_running - self.tasks_running.lock().remove(&task); - // delete the task from pending_tasks - self.pending_tasks.lock().remove(&task); + if let Some(task) = completed_task { + if let Some(cb) = &*task_scheduler.on_completion_callback { + cb(task); + } + } + } + } + }); } - /// Remove a task from `tasks_running` and from `pending_tasks` - async fn delete_unprocessable_task(&mut self, task: u32) -> Result<()> { - // delete the task from tasks_running - self.tasks_running.lock().remove(&task); - // delete the task from pending_tasks - self.pending_tasks.lock().remove(&task); - // save state - self.report_state(TaskExecutionState::Panicked(task)).await + // We use tokio for testing instead of ic_kit::ic::spawn because the latter blocks the current thread + // waiting for the spawned futures to complete. 
+ // This makes impossible to test concurrent behavior. + #[cfg(test)] + fn spawn>(future: F) { + tokio::task::spawn_local(future); } - /// Save the current state of the scheduler. - async fn report_state(&self, state: TaskExecutionState) -> Result<()> { - if let Some(ref on_execution_state_changed_callback) = - self.on_execution_state_changed_callback - { - (*on_execution_state_changed_callback)(state).await?; - } - - Ok(()) + #[cfg(not(test))] + #[inline(always)] + fn spawn>(future: F) { + ic_cdk_timers::set_timer(std::time::Duration::from_millis(0), || { + ic_kit::ic::spawn(future); + }); } } pub trait TaskScheduler { - fn append_task(&self, task: ScheduledTask); - fn append_tasks(&self, tasks: Vec>); + /// Append a task to the scheduler and return the key of the task. + fn append_task(&self, task: ScheduledTask) -> u32; + /// Append a list of tasks to the scheduler and return the keys of the tasks. + fn append_tasks(&self, tasks: Vec>) -> Vec; + /// Get a task by its key. + fn get_task(&self, task_id: u32) -> Option>; } -impl Clone for Scheduler -where - T: 'static + Task, - P: 'static + IterableUnboundedMapStructure>, +impl>> + Clone for Scheduler { fn clone(&self) -> Self { Self { pending_tasks: self.pending_tasks.clone(), phantom: self.phantom, - tasks_running: self.tasks_running.clone(), - tasks_to_be_processed: self.tasks_to_be_processed.clone(), - on_execution_state_changed_callback: self.on_execution_state_changed_callback.clone(), + on_completion_callback: self.on_completion_callback.clone(), + running_task_timeout_secs: AtomicU64::new( + self.running_task_timeout_secs.load(Ordering::Relaxed), + ), } } } -impl TaskScheduler for Scheduler -where - T: 'static + Task, - P: 'static + IterableUnboundedMapStructure>, +impl>> + TaskScheduler for Scheduler { - fn append_task(&self, task: ScheduledTask) { + fn append_task(&self, task: ScheduledTask) -> u32 { + let time_secs = time_secs(); let mut lock = self.pending_tasks.lock(); let key = 
lock.last_key().map(|val| val + 1).unwrap_or_default(); - lock.insert(&key, &task); + lock.insert( + &key, + &InnerScheduledTask::with_status( + key, + task, + TaskStatus::Waiting { + timestamp_secs: time_secs, + }, + ), + ); + key } - fn append_tasks(&self, tasks: Vec>) { + fn append_tasks(&self, tasks: Vec>) -> Vec { if tasks.is_empty() { - return; + return vec![]; }; + let time_secs = time_secs(); let mut lock = self.pending_tasks.lock(); let mut key = lock.last_key().map(|val| val + 1).unwrap_or_default(); + let mut keys = Vec::with_capacity(tasks.len()); for task in tasks { - lock.insert(&key, &task); + lock.insert( + &key, + &InnerScheduledTask::with_status( + key, + task, + TaskStatus::Waiting { + timestamp_secs: time_secs, + }, + ), + ); + keys.push(key); key += 1; } + keys + } + + fn get_task(&self, task_id: u32) -> Option> { + self.pending_tasks.lock().get(&task_id) } } @@ -363,14 +289,12 @@ mod test { use super::*; - type SaveStateCb = Pin>>>; - mod test_execution { - use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; + use std::sync::atomic::AtomicBool; use std::time::Duration; use ic_stable_structures::{StableUnboundedMap, VectorMemory}; @@ -394,7 +318,7 @@ mod test { fn execute( &self, task_scheduler: Box>, - ) -> Pin>>> { + ) -> Pin>>> { match self { SimpleTaskSteps::One { id } => { let id = *id; @@ -448,57 +372,20 @@ mod test { } } - thread_local! 
{ - static REPORT_STATE_CB_CALLED: RefCell> = RefCell::new(None); - } - - #[tokio::test] - async fn test_should_call_report_state_cb() { - let scheduler = scheduler(); - - assert!(scheduler - .report_state(TaskExecutionState::Failed( - 1, - SchedulerError::TaskExecutionFailed("ciao".to_string()) - )) - .await - .is_ok()); - - REPORT_STATE_CB_CALLED.with_borrow(|state| { - assert!(matches!( - state.as_ref().unwrap(), - TaskExecutionState::Failed(1, _) - )); - }); - } - - async fn report_state(state: TaskExecutionState) -> Result<()> { - if let TaskExecutionState::Failed(id, err) = state { - REPORT_STATE_CB_CALLED.with(|called| { - called.replace(Some(TaskExecutionState::Failed(id, err))); - }); - } - - Ok(()) - } - - fn report_state_cb(state: TaskExecutionState) -> SaveStateCb { - Box::pin(async { report_state(state).await }) - } - #[tokio::test] async fn test_run_scheduler() { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let scheduler = Scheduler::new(map); let id = random(); scheduler.append_task(SimpleTaskSteps::One { id }.into()); let mut completed = false; while !completed { - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; STATE.with(|state| { let state = state.lock(); @@ -522,19 +409,26 @@ mod test { } #[tokio::test] - async fn test_error_cb_not_called_in_case_of_success() { + async fn test_error_cb_called_on_success() { let local = tokio::task::LocalSet::new(); - + let called = Arc::new(AtomicBool::new(false)); + let called_t = called.clone(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let mut scheduler = Scheduler::new(map); + scheduler.on_completion_callback(move |task| { + if let TaskStatus::Completed { .. 
} = task.status { + called_t.store(true, std::sync::atomic::Ordering::SeqCst); + } + }); let id = random(); scheduler.append_task(SimpleTaskSteps::One { id }.into()); let mut completed = false; while !completed { - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(100)).await; STATE.with(|state| { let state = state.lock(); @@ -556,39 +450,18 @@ mod test { }) .await; - REPORT_STATE_CB_CALLED.with_borrow(|state| { - assert!(state.is_none()); - }); - } - - type TestScheduler = Scheduler< - SimpleTaskSteps, - StableUnboundedMap< - u32, - ScheduledTask, - std::rc::Rc>>, - >, - >; - - fn scheduler() -> TestScheduler { - let map: StableUnboundedMap< - u32, - ScheduledTask, - std::rc::Rc>>, - > = StableUnboundedMap::new(VectorMemory::default()); - Scheduler::new(map, Some(Box::new(report_state_cb))).unwrap() + assert!(called.load(std::sync::atomic::Ordering::SeqCst)); } } mod test_delay { - use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::time::Duration; - use ic_stable_structures::{StableUnboundedMap, UnboundedMapStructure as _, VectorMemory}; + use ic_stable_structures::{StableUnboundedMap, UnboundedMapStructure, VectorMemory}; use rand::random; use serde::{Deserialize, Serialize}; @@ -596,9 +469,7 @@ mod test { use crate::task::TaskOptions; thread_local! 
{ - pub static STATE: Mutex>> = Mutex::new(HashMap::new()); - - static REPORT_STATE_CB_CALLED: RefCell> = RefCell::new(None); + pub static STATE: Mutex>> = Mutex::new(HashMap::new()) } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -610,7 +481,7 @@ mod test { fn execute( &self, _task_scheduler: Box>, - ) -> Pin>>> { + ) -> Pin>>> { match self { SimpleTask::StepOne { id } => { let id = *id; @@ -634,7 +505,8 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let scheduler = Scheduler::new(map); let id = random(); let timestamp: u64 = random(); @@ -648,7 +520,7 @@ mod test { for i in 0..10 { // Should not run the task because the execution timestamp is in the future - scheduler.run_with_timestamp(timestamp + i).await.unwrap(); + scheduler.run_with_timestamp(timestamp + i).unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; STATE.with(|state| { let state = state.lock(); @@ -657,7 +529,7 @@ mod test { }); } - scheduler.run_with_timestamp(timestamp + 11).await.unwrap(); + scheduler.run_with_timestamp(timestamp + 11).unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; STATE.with(|state| { let state = state.lock(); @@ -668,49 +540,16 @@ mod test { }) .await; } - - type TestScheduler = Scheduler< - SimpleTask, - StableUnboundedMap< - u32, - ScheduledTask, - std::rc::Rc>>, - >, - >; - - fn scheduler() -> TestScheduler { - let map: StableUnboundedMap< - u32, - ScheduledTask, - std::rc::Rc>>, - > = StableUnboundedMap::new(VectorMemory::default()); - - Scheduler::new(map, Some(Box::new(report_state_cb))).unwrap() - } - - async fn report_state(state: TaskExecutionState) -> Result<()> { - if let TaskExecutionState::Failed(id, err) = state { - REPORT_STATE_CB_CALLED.with(|called| { - called.replace(Some(TaskExecutionState::Failed(id, err))); - }); - } - Ok(()) - } - - fn report_state_cb(state: TaskExecutionState) -> 
SaveStateCb { - Box::pin(async { report_state(state).await }) - } } mod test_failure_and_retry { - use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; use std::time::Duration; - use ic_stable_structures::{StableUnboundedMap, UnboundedMapStructure as _, VectorMemory}; + use ic_stable_structures::{StableUnboundedMap, UnboundedMapStructure, VectorMemory}; use rand::random; use serde::{Deserialize, Serialize}; @@ -725,8 +564,6 @@ mod test { thread_local! { static STATE: Mutex> = Mutex::new(HashMap::new()); - - static REPORT_STATE_CB_CALLED: RefCell> = RefCell::new(None); } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -738,7 +575,7 @@ mod test { fn execute( &self, _task_scheduler: Box>, - ) -> Pin>>> { + ) -> Pin>>> { match self { SimpleTask::StepOne { id, fails } => { let id = *id; @@ -772,7 +609,8 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let scheduler = Scheduler::new(map); let id = random(); let fails = 10; let retries = 3; @@ -789,7 +627,7 @@ mod test { // beware that the the first execution is not a retry for i in 1..=retries { - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; STATE.with(|state| { let state = state.lock(); @@ -807,7 +645,7 @@ mod test { } // After the last retries the task is removed - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; STATE.with(|state| { @@ -834,7 +672,8 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let scheduler = Scheduler::new(map); let id = random(); let fails = 2; let retries = 4; @@ -851,7 +690,7 @@ mod test { // beware that the the first execution is not a retry for _ 
in 1..=retries { - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; } @@ -877,7 +716,8 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let scheduler = Scheduler::new(map); let id = random(); let fails = 10; let retries = 10; @@ -893,37 +733,33 @@ mod test { .into(), ); - let timestamp = random(); - assert_eq!(1, scheduler.run_with_timestamp(timestamp).await.unwrap()); + let timestamp = time_secs(); + assert_eq!(1, scheduler.run_with_timestamp(timestamp).unwrap()); tokio::time::sleep(Duration::from_millis(25)).await; { let pending_tasks = scheduler.pending_tasks.lock(); assert_eq!(pending_tasks.len(), 1); assert_eq!(pending_tasks.get(&0).unwrap().options.failures, 1); - assert_eq!( + assert!( pending_tasks .get(&0) .unwrap() .options - .execute_after_timestamp_in_secs, - timestamp + retry_delay_secs + .execute_after_timestamp_in_secs + >= timestamp + retry_delay_secs ); } // Should not run the task because the retry timestamp is in the future for i in 0..retry_delay_secs { - assert_eq!( - 0, - scheduler.run_with_timestamp(timestamp + i).await.unwrap() - ); + assert_eq!(0, scheduler.run_with_timestamp(timestamp + i).unwrap()); } assert_eq!( 1, scheduler .run_with_timestamp(timestamp + retry_delay_secs) - .await .unwrap() ); }) @@ -932,12 +768,23 @@ mod test { #[tokio::test] async fn test_should_call_error_cb() { + use std::sync::atomic::AtomicBool; + let local = tokio::task::LocalSet::new(); - let id = random(); + let called = Arc::new(AtomicBool::new(false)); + let called_t = called.clone(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let mut scheduler = Scheduler::new(map); + scheduler.on_completion_callback(move |task| { + if let TaskStatus::Failed { .. 
} = task.status { + called_t.store(true, std::sync::atomic::Ordering::SeqCst); + } + }); + + let id = random(); let fails = 10; scheduler.append_task( @@ -949,26 +796,32 @@ mod test { ); // beware that the the first execution is not a retry - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; let pending_tasks = scheduler.pending_tasks.lock(); assert_eq!(pending_tasks.len(), 0); }) .await; - REPORT_STATE_CB_CALLED.with_borrow(|state| { - assert!(matches!( - state.as_ref().unwrap(), - TaskExecutionState::Failed(_, _) - )) - }); + assert!(called.load(std::sync::atomic::Ordering::SeqCst)); } #[tokio::test] async fn test_should_not_call_error_cb_if_succeeds_after_retries() { + use std::sync::atomic::AtomicBool; + let local = tokio::task::LocalSet::new(); + let called = Arc::new(AtomicBool::new(false)); + let called_t = called.clone(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let mut scheduler = Scheduler::new(map); + + scheduler.on_completion_callback(move |task| { + if let TaskStatus::Completed { .. 
} = task.status { + called_t.store(true, std::sync::atomic::Ordering::SeqCst); + } + }); let id = random(); let fails = 2; @@ -986,7 +839,7 @@ mod test { // beware that the the first execution is not a retry for _ in 1..=retries { - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; } @@ -1005,20 +858,26 @@ mod test { }); }) .await; - - REPORT_STATE_CB_CALLED.with_borrow(|state| { - assert!(state.is_none()); - }); + assert!(called.load(std::sync::atomic::Ordering::SeqCst)); } #[tokio::test] async fn test_should_call_error_only_after_retries() { + use std::sync::atomic::AtomicU8; + let local = tokio::task::LocalSet::new(); - let id = random(); + let called = Arc::new(AtomicU8::new(0)); + let called_t = called.clone(); local .run_until(async move { - let mut scheduler = scheduler(); + let map = StableUnboundedMap::new(VectorMemory::default()); + let mut scheduler = Scheduler::new(map); + scheduler.on_completion_callback(move |_| { + called_t.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + }); + + let id = random(); let fails = 10; let retries = 3; @@ -1034,7 +893,7 @@ mod test { // beware that the the first execution is not a retry for i in 1..=retries { - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; STATE.with(|state| { let state = state.lock(); @@ -1048,18 +907,11 @@ mod test { }); let pending_tasks = scheduler.pending_tasks.lock(); assert_eq!(pending_tasks.len(), 1); - println!( - "{:?}", - pending_tasks - .iter() - .map(|(k, v)| (k, v.options.failures)) - .collect::>() - ); assert_eq!(pending_tasks.get(&0).unwrap().options.failures, i); } // After the last retries the task is removed - scheduler.run().await.unwrap(); + scheduler.run().unwrap(); tokio::time::sleep(Duration::from_millis(25)).await; STATE.with(|state| { @@ -1079,43 +931,7 @@ mod test { }); }) .await; - REPORT_STATE_CB_CALLED.with_borrow(|state| { - assert!(matches!( 
- state.as_ref().unwrap(), - TaskExecutionState::Failed(_, _) - )); - }); - } - - type TestScheduler = Scheduler< - SimpleTask, - StableUnboundedMap< - u32, - ScheduledTask, - std::rc::Rc>>, - >, - >; - - fn scheduler() -> TestScheduler { - let map: StableUnboundedMap< - u32, - ScheduledTask, - std::rc::Rc>>, - > = StableUnboundedMap::new(VectorMemory::default()); - Scheduler::new(map, Some(Box::new(report_state_cb))).unwrap() - } - - async fn report_state(state: TaskExecutionState) -> Result<()> { - if let TaskExecutionState::Failed(id, err) = state { - REPORT_STATE_CB_CALLED.with(|called| { - called.replace(Some(TaskExecutionState::Failed(id, err))); - }); - } - Ok(()) - } - - fn report_state_cb(state: TaskExecutionState) -> SaveStateCb { - Box::pin(async { report_state(state).await }) + assert_eq!(called.load(std::sync::atomic::Ordering::SeqCst), 1); } } } diff --git a/ic-task-scheduler/src/task.rs b/ic-task-scheduler/src/task.rs index 55f3825c..95f0b2e0 100644 --- a/ic-task-scheduler/src/task.rs +++ b/ic-task-scheduler/src/task.rs @@ -1,6 +1,7 @@ use std::future::Future; use std::pin::Pin; +use candid::CandidType; use ic_stable_structures::{Bound, ChunkSize, SlicedStorable, Storable}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; @@ -10,7 +11,7 @@ use crate::scheduler::TaskScheduler; use crate::SchedulerError; /// A sync task is a unit of work that can be executed by the scheduler. -pub trait Task: Clone { +pub trait Task { /// Execute the task and return the next task to execute. fn execute( &self, @@ -19,7 +20,7 @@ pub trait Task: Clone { } /// A scheduled task is a task that is ready to be executed. 
-#[derive(Default, Clone, Serialize, Deserialize, PartialEq, Eq, Debug)] +#[derive(CandidType, Serialize, Deserialize, PartialEq, Eq, Debug)] pub struct ScheduledTask { pub(crate) task: T, pub(crate) options: TaskOptions, @@ -50,7 +51,47 @@ impl From<(T, TaskOptions)> for ScheduledTask { } } -impl Storable for ScheduledTask { +#[derive(CandidType, Serialize, Deserialize, PartialEq, Eq, Debug)] +pub struct InnerScheduledTask { + pub(crate) id: u32, + pub(crate) task: T, + pub(crate) options: TaskOptions, + pub(crate) status: TaskStatus, +} + +impl InnerScheduledTask { + /// Creates a new InnerScheduledTask with the given status + pub fn with_status(id: u32, task: ScheduledTask, status: TaskStatus) -> Self { + Self { + id, + task: task.task, + options: task.options, + status, + } + } + + /// Returns the status of the task + pub fn status(&self) -> &TaskStatus { + &self.status + } + + /// Returns the options of the task + pub fn options(&self) -> &TaskOptions { + &self.options + } + + /// Returns the task + pub fn task(&self) -> &T { + &self.task + } + + /// Returns the task id + pub fn id(&self) -> u32 { + self.id + } +} + +impl Storable for InnerScheduledTask { fn to_bytes(&self) -> std::borrow::Cow<[u8]> { bincode::serialize(self) .expect("failed to serialize ScheduledTask") @@ -64,12 +105,79 @@ impl Storable for ScheduledTas const BOUND: Bound = Bound::Unbounded; } -impl SlicedStorable for ScheduledTask { +impl SlicedStorable for InnerScheduledTask { const CHUNK_SIZE: ChunkSize = 128; } +/// The status of a task in the scheduler +#[derive(CandidType, Serialize, Deserialize, PartialEq, Eq, Debug)] +pub enum TaskStatus { + /// The task is waiting to be executed + Waiting { timestamp_secs: u64 }, + /// The task execution completed successfully + Completed { timestamp_secs: u64 }, + /// The task was scheduled to be run + Scheduled { timestamp_secs: u64 }, + /// The task is running + Running { timestamp_secs: u64 }, + /// The task execution failed and no more retries 
are allowed + Failed { + timestamp_secs: u64, + error: SchedulerError, + }, + /// The task has been running for a long time. It could be stuck or panicking + TimeoutOrPanic { timestamp_secs: u64 }, +} + +impl TaskStatus { + /// Creates a new TaskStatus::Waiting with the given timestamp in seconds + pub fn waiting(timestamp_secs: u64) -> Self { + Self::Waiting { timestamp_secs } + } + + /// Creates a new TaskStatus::Completed with the given timestamp in seconds + pub fn completed(timestamp_secs: u64) -> Self { + Self::Completed { timestamp_secs } + } + + /// Creates a new TaskStatus::Failed with the given timestamp in seconds and error + pub fn failed(timestamp_secs: u64, error: SchedulerError) -> Self { + Self::Failed { + timestamp_secs, + error, + } + } + + /// Creates a new TaskStatus::Running with the given timestamp in seconds + pub fn running(timestamp_secs: u64) -> Self { + Self::Running { timestamp_secs } + } + + /// Creates a new TaskStatus::Scheduled with the given timestamp in seconds + pub fn scheduled(timestamp_secs: u64) -> Self { + Self::Scheduled { timestamp_secs } + } + + /// Creates a new TaskStatus::TimeoutOrPanic with the given timestamp in seconds + pub fn timeout_or_panic(timestamp_secs: u64) -> Self { + Self::TimeoutOrPanic { timestamp_secs } + } + + /// Returns the timestamp of the status + pub fn timestamp_secs(&self) -> u64 { + match self { + TaskStatus::Waiting { timestamp_secs } => *timestamp_secs, + TaskStatus::Completed { timestamp_secs } => *timestamp_secs, + TaskStatus::Running { timestamp_secs } => *timestamp_secs, + TaskStatus::TimeoutOrPanic { timestamp_secs } => *timestamp_secs, + TaskStatus::Failed { timestamp_secs, .. } => *timestamp_secs, + TaskStatus::Scheduled { timestamp_secs, .. 
} => *timestamp_secs, + } + } +} + /// Scheduling options for a task -#[derive(Default, Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] +#[derive(CandidType, Default, Serialize, Deserialize, PartialEq, Eq, Debug)] pub struct TaskOptions { pub(crate) failures: u32, pub(crate) execute_after_timestamp_in_secs: u64, @@ -120,7 +228,7 @@ mod test { use super::*; - #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] + #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] struct TestTask {} impl Task for TestTask { @@ -135,62 +243,74 @@ mod test { #[test] fn test_storable_task() { { - let task = ScheduledTask::with_options( - TestTask {}, - TaskOptions::new() + let task = InnerScheduledTask { + id: 0, + task: TestTask {}, + options: TaskOptions::new() .with_max_retries_policy(3) .with_fixed_backoff_policy(2), - ); + status: TaskStatus::Waiting { timestamp_secs: 0 }, + }; let serialized = task.to_bytes(); - let deserialized = ScheduledTask::::from_bytes(serialized); + let deserialized = InnerScheduledTask::::from_bytes(serialized); assert_eq!(task, deserialized); } { - let task = ScheduledTask::with_options( - TestTask {}, - TaskOptions::new() + let task = InnerScheduledTask { + id: 0, + task: TestTask {}, + options: TaskOptions::new() .with_retry_policy(RetryPolicy::None) .with_backoff_policy(BackoffPolicy::None), - ); + status: TaskStatus::Waiting { timestamp_secs: 0 }, + }; let serialized = task.to_bytes(); - let deserialized = ScheduledTask::::from_bytes(serialized); + let deserialized = InnerScheduledTask::::from_bytes(serialized); assert_eq!(task, deserialized); } { - let task = ScheduledTask::with_options( - TestTask {}, - TaskOptions::new() + let task = InnerScheduledTask { + id: 0, + task: TestTask {}, + options: TaskOptions::new() .with_retry_policy(RetryPolicy::None) .with_backoff_policy(BackoffPolicy::Exponential { secs: 2, multiplier: 2, }), - ); + status: TaskStatus::Completed { + timestamp_secs: 1230, + }, + }; let serialized = 
task.to_bytes(); - let deserialized = ScheduledTask::::from_bytes(serialized); + let deserialized = InnerScheduledTask::::from_bytes(serialized); assert_eq!(task, deserialized); } { - let task = ScheduledTask::with_options( - TestTask {}, - TaskOptions::new() + let task = InnerScheduledTask { + id: 0, + task: TestTask {}, + options: TaskOptions::new() .with_retry_policy(RetryPolicy::Infinite) .with_backoff_policy(BackoffPolicy::Variable { secs: vec![12, 56, 76], }), - ); + status: TaskStatus::Running { + timestamp_secs: 21230, + }, + }; let serialized = task.to_bytes(); - let deserialized = ScheduledTask::::from_bytes(serialized); + let deserialized = InnerScheduledTask::::from_bytes(serialized); assert_eq!(task, deserialized); } diff --git a/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs b/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs index 3702edcc..8f5c4aa7 100644 --- a/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs +++ b/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs @@ -3,16 +3,17 @@ use std::future::Future; use std::pin::Pin; use std::time::Duration; -use candid::Principal; +use candid::{CandidType, Principal}; use ic_canister::{generate_idl, init, post_upgrade, query, update, Canister, Idl, PreUpdate}; use ic_stable_structures::stable_structures::DefaultMemoryImpl; use ic_stable_structures::{IcMemoryManager, MemoryId, StableUnboundedMap, VirtualMemory}; -use ic_task_scheduler::scheduler::{Scheduler, TaskExecutionState, TaskScheduler}; -use ic_task_scheduler::task::{ScheduledTask, Task, TaskOptions}; +use ic_task_scheduler::scheduler::{Scheduler, TaskScheduler}; +use ic_task_scheduler::task::{InnerScheduledTask, ScheduledTask, Task, TaskStatus}; use ic_task_scheduler::SchedulerError; use serde::{Deserialize, Serialize}; -type Storage = StableUnboundedMap, VirtualMemory>; +type Storage = + StableUnboundedMap, VirtualMemory>; type PanickingScheduler = Scheduler; const 
SCHEDULER_STORAGE_MEMORY_ID: MemoryId = MemoryId::new(1); @@ -23,30 +24,25 @@ thread_local! { static SCHEDULER: RefCell = { let map: Storage = Storage::new(MEMORY_MANAGER.with(|mm| mm.get(SCHEDULER_STORAGE_MEMORY_ID))); - let scheduler = PanickingScheduler::new( + let mut scheduler = PanickingScheduler::new( map, - Some(Box::new(save_state_cb)) - ).unwrap(); + ); - scheduler.append_task((DummyTask::GoodTask, TaskOptions::new()).into()); - scheduler.append_task((DummyTask::Panicking, TaskOptions::new()).into()); - scheduler.append_task((DummyTask::GoodTask, TaskOptions::new()).into()); - scheduler.append_task((DummyTask::FailTask, TaskOptions::new()).into()); + scheduler.set_running_task_timeout(30); + scheduler.on_completion_callback(save_state_cb); RefCell::new(scheduler) }; - static SCHEDULED_STATE_CALLED: RefCell = RefCell::new(false); static COMPLETED_TASKS: RefCell> = RefCell::new(vec![]); static FAILED_TASKS: RefCell> = RefCell::new(vec![]); static PANICKED_TASKS : RefCell> = RefCell::new(vec![]); - static EXECUTING_TASKS : RefCell> = RefCell::new(vec![]); static PRINCIPAL : RefCell = RefCell::new(Principal::anonymous()); } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(CandidType, Serialize, Deserialize, Debug, Clone)] pub enum DummyTask { Panicking, GoodTask, @@ -98,24 +94,10 @@ impl DummyCanister { fn set_timers(&self) { ic_exports::ic_cdk_timers::set_timer_interval(Duration::from_millis(10), || { - ic_cdk::spawn(Self::do_run_scheduler()) + Self::do_run_scheduler() }); } - #[update] - pub fn save_state(&self) -> bool { - SCHEDULED_STATE_CALLED.with_borrow_mut(|called| { - *called = true; - }); - - true - } - - #[query] - pub fn scheduled_state_called(&self) -> bool { - SCHEDULED_STATE_CALLED.with_borrow(|called| *called) - } - #[query] pub fn panicked_tasks(&self) -> Vec { PANICKED_TASKS.with_borrow(|tasks| tasks.clone()) @@ -132,18 +114,26 @@ impl DummyCanister { } #[query] - pub fn executed_tasks(&self) -> Vec { - 
EXECUTING_TASKS.with_borrow(|tasks| tasks.clone()) + pub fn get_task(&self, task_id: u32) -> Option> { + let scheduler = SCHEDULER.with_borrow(|scheduler| scheduler.clone()); + scheduler.get_task(task_id) + } + + #[update] + pub fn schedule_tasks(&self, tasks: Vec) -> Vec { + let scheduler = SCHEDULER.with_borrow(|scheduler| scheduler.clone()); + let scheduled_tasks = tasks.into_iter().map(ScheduledTask::new).collect(); + scheduler.append_tasks(scheduled_tasks) } #[update] - pub async fn run_scheduler(&self) { - Self::do_run_scheduler().await + pub fn run_scheduler(&self) { + Self::do_run_scheduler(); } - async fn do_run_scheduler() { - let mut scheduler = SCHEDULER.with_borrow(|scheduler| scheduler.clone()); - scheduler.run().await.unwrap(); + fn do_run_scheduler() { + let scheduler = SCHEDULER.with_borrow(|scheduler| scheduler.clone()); + scheduler.run().unwrap(); } pub fn idl() -> Idl { @@ -151,38 +141,25 @@ impl DummyCanister { } } -async fn save_state(state: TaskExecutionState) -> ic_task_scheduler::Result<()> { - let canister = PRINCIPAL.with_borrow(|principal| *principal); - match state { - TaskExecutionState::Completed(id) => { +fn save_state_cb(task: InnerScheduledTask) { + match task.status() { + TaskStatus::Waiting { .. } => {} + TaskStatus::Completed { .. } => { COMPLETED_TASKS.with_borrow_mut(|tasks| { - tasks.push(id); + tasks.push(task.id()); }); } - TaskExecutionState::Panicked(id) => { - PANICKED_TASKS.with_borrow_mut(|tasks| { - tasks.push(id); - }); - } - TaskExecutionState::Failed(id, _) => { + TaskStatus::Running { .. } => {} + TaskStatus::Failed { .. } => { FAILED_TASKS.with_borrow_mut(|tasks| { - tasks.push(id); + tasks.push(task.id()); }); } - TaskExecutionState::Executing(id) => { - EXECUTING_TASKS.with_borrow_mut(|tasks| { - tasks.push(id); + TaskStatus::TimeoutOrPanic { .. 
} => { + PANICKED_TASKS.with_borrow_mut(|tasks| { + tasks.push(task.id()); }); } - TaskExecutionState::Scheduled => {} - } - ic_exports::ic_cdk::call(canister, "save_state", ()) - .await - .map_err(|(_, msg)| ic_task_scheduler::SchedulerError::TaskExecutionFailed(msg)) -} - -type SaveStateCb = Pin>>>; - -fn save_state_cb(state: TaskExecutionState) -> SaveStateCb { - Box::pin(async { save_state(state).await }) + TaskStatus::Scheduled { .. } => {} + }; } diff --git a/ic-task-scheduler/tests/pocket_ic_tests/mod.rs b/ic-task-scheduler/tests/pocket_ic_tests/mod.rs index b00ebc91..54b620e2 100644 --- a/ic-task-scheduler/tests/pocket_ic_tests/mod.rs +++ b/ic-task-scheduler/tests/pocket_ic_tests/mod.rs @@ -1,18 +1,23 @@ mod scheduler; mod wasm_utils; +use std::future::Future; +use std::pin::Pin; use std::time::Duration; -use candid::{CandidType, Decode, Encode, Principal}; -use ic_exports::pocket_ic; +use candid::{CandidType, Encode, Principal}; +use ic_canister_client::PocketIcClient; use ic_exports::pocket_ic::nio::PocketIcAsync; use ic_kit::mock_principals::alice; -use pocket_ic::WasmResult; -use serde::Deserialize; +use ic_task_scheduler::scheduler::TaskScheduler; +use ic_task_scheduler::task::{InnerScheduledTask, Task}; +use ic_task_scheduler::SchedulerError; +use serde::{Deserialize, Serialize}; use wasm_utils::get_dummy_scheduler_canister_bytecode; #[derive(Clone)] pub struct PocketIcTestContext { + canister_client: PocketIcClient, client: PocketIcAsync, pub dummy_scheduler_canister: Principal, } @@ -23,77 +28,39 @@ impl PocketIcTestContext { &self.client } - async fn query_as( - &self, - sender: Principal, - canister_id: Principal, - method: &str, - payload: Vec, - ) -> Result - where - for<'a> Result: CandidType + Deserialize<'a>, - { - let res = match self - .client - .query_call(canister_id, sender, method.to_string(), payload) + pub async fn get_task(&self, task_id: u32) -> Option> { + self.canister_client + .query("get_task", (task_id,)) .await .unwrap() - { - 
WasmResult::Reply(bytes) => bytes, - WasmResult::Reject(e) => panic!("Unexpected reject: {:?}", e), - }; - - Decode!(&res, Result).expect("failed to decode item from candid") - } - - pub async fn scheduled_state_called(&self) -> bool { - let args = Encode!(&()).unwrap(); - self.query_as( - alice(), - self.dummy_scheduler_canister, - "scheduled_state_called", - args, - ) - .await } pub async fn completed_tasks(&self) -> Vec { - let args = Encode!(&()).unwrap(); - self.query_as( - alice(), - self.dummy_scheduler_canister, - "completed_tasks", - args, - ) - .await + self.canister_client + .query("completed_tasks", ()) + .await + .unwrap() } pub async fn panicked_tasks(&self) -> Vec { - let args = Encode!(&()).unwrap(); - self.query_as( - alice(), - self.dummy_scheduler_canister, - "panicked_tasks", - args, - ) - .await + self.canister_client + .query("panicked_tasks", ()) + .await + .unwrap() } pub async fn failed_tasks(&self) -> Vec { - let args = Encode!(&()).unwrap(); - self.query_as(alice(), self.dummy_scheduler_canister, "failed_tasks", args) + self.canister_client + .query("failed_tasks", ()) .await + .unwrap() } - pub async fn executed_tasks(&self) -> Vec { - let args = Encode!(&()).unwrap(); - self.query_as( - alice(), - self.dummy_scheduler_canister, - "executed_tasks", - args, - ) - .await + pub async fn schedule_tasks(&self, tasks: Vec) -> Vec { + self.canister_client + .update("schedule_tasks", (tasks,)) + .await + .unwrap() } pub async fn run_scheduler(&self) { @@ -104,21 +71,26 @@ impl PocketIcTestContext { async fn deploy_dummy_scheduler_canister() -> anyhow::Result { let client = PocketIcAsync::init().await; - let dummy_wasm = get_dummy_scheduler_canister_bytecode(); println!("Creating dummy canister"); - let args = Encode!(&())?; - let sender = alice(); let canister = client.create_canister(Some(sender)).await; println!("Canister created with principal {}", canister); + + let canister_client = + 
ic_canister_client::PocketIcClient::from_client(client.clone(), canister, alice()); + let env = PocketIcTestContext { + canister_client, client, dummy_scheduler_canister: canister, }; - env.client().add_cycles(canister, 10_u128.pow(12)).await; + env.client().add_cycles(canister, 10_u128.pow(14)).await; println!("cycles added"); + + let dummy_wasm = get_dummy_scheduler_canister_bytecode(); + let args = Encode!(&())?; env.client() .install_canister(canister, dummy_wasm.to_vec(), args, Some(sender)) .await; @@ -127,3 +99,19 @@ async fn deploy_dummy_scheduler_canister() -> anyhow::Result>, + ) -> Pin>>> { + Box::pin(async move { Ok(()) }) + } +} diff --git a/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs b/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs index 278d82ab..de0c848b 100644 --- a/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs +++ b/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs @@ -1,6 +1,9 @@ +use std::collections::BTreeMap; + use candid::Principal; +use rand::Rng; -use crate::pocket_ic_tests::deploy_dummy_scheduler_canister; +use crate::pocket_ic_tests::{deploy_dummy_scheduler_canister, DummyTask}; thread_local! { static CANISTER: std::cell::RefCell = std::cell::RefCell::new(Principal::anonymous()); @@ -8,16 +11,91 @@ thread_local! 
{ #[tokio::test] async fn test_should_remove_panicking_task() { + // Arrange let test_ctx = deploy_dummy_scheduler_canister().await.unwrap(); CANISTER.with_borrow_mut(|principal| *principal = test_ctx.dummy_scheduler_canister); + let mut tasks = vec![ + DummyTask::GoodTask, + DummyTask::FailTask, + DummyTask::Panicking, + DummyTask::GoodTask, + DummyTask::GoodTask, + DummyTask::GoodTask, + DummyTask::GoodTask, + DummyTask::GoodTask, + DummyTask::GoodTask, + DummyTask::Panicking, + DummyTask::Panicking, + DummyTask::Panicking, + DummyTask::FailTask, + DummyTask::FailTask, + DummyTask::FailTask, + ]; + + for _ in 0..1000 { + // Append a randomly selected task to the tasks vector + let task = match rand::thread_rng().gen_range(0..=2) { + // rand 0.8 + 0 => DummyTask::Panicking, + 1 => DummyTask::FailTask, + _ => DummyTask::GoodTask, + }; + tasks.push(task); + } + + let task_ids = test_ctx.schedule_tasks(tasks.clone()).await; + + let tasks_map: BTreeMap = task_ids + .into_iter() + .enumerate() + .map(|(id, key)| (key, tasks[id])) + .collect::>(); + assert_eq!(tasks.len(), tasks_map.len()); + + // Act for _ in 0..10 { test_ctx.run_scheduler().await; + println!("Get task 0: {:?}", test_ctx.get_task(0).await); + println!("Get task 1: {:?}", test_ctx.get_task(1).await); + println!("Get task 2: {:?}", test_ctx.get_task(2).await); } - assert!(test_ctx.scheduled_state_called().await); - assert_eq!(test_ctx.executed_tasks().await, vec![3, 0, 1, 2]); - assert_eq!(test_ctx.panicked_tasks().await, vec![1]); - assert_eq!(test_ctx.completed_tasks().await, vec![0, 2]); - assert_eq!(test_ctx.failed_tasks().await, vec![3]); + // Assert + let panicked_tasks = test_ctx.panicked_tasks().await; + let completed_tasks = test_ctx.completed_tasks().await; + let failed_tasks = test_ctx.failed_tasks().await; + + assert_eq!( + panicked_tasks.len() + completed_tasks.len() + failed_tasks.len(), + tasks_map.len() + ); + + compare(panicked_tasks, &tasks_map, DummyTask::Panicking); + 
compare(completed_tasks, &tasks_map, DummyTask::GoodTask); + compare(failed_tasks, &tasks_map, DummyTask::FailTask); +} + +fn compare(mut found: Vec, tasks_map: &BTreeMap, expected_task: DummyTask) { + let mut expected = tasks_map + .iter() + .filter(|(_, task)| task == &&expected_task) + .map(|(id, _)| *id) + .collect::>(); + + assert_eq!( + expected.len(), + found.len(), + "Task: {:?}, Expected: {:?}, Found: {:?}", + expected_task, + expected, + found + ); + expected.sort(); + found.sort(); + assert_eq!( + expected, found, + "Task: {:?}, Expected: {:?}, Found: {:?}", + expected_task, expected, found + ); }