From 432d2544a48803e07cda285e3b4a3871b60dd8f4 Mon Sep 17 00:00:00 2001 From: veeso Date: Mon, 22 Jan 2024 15:33:06 +0100 Subject: [PATCH] fix: there's still something off --- ic-task-scheduler/src/scheduler.rs | 504 +++++++++++------- .../dummy_scheduler_canister/src/canister.rs | 175 ++++-- .../tests/pocket_ic_tests/mod.rs | 57 +- .../tests/pocket_ic_tests/scheduler.rs | 15 +- 4 files changed, 494 insertions(+), 257 deletions(-) diff --git a/ic-task-scheduler/src/scheduler.rs b/ic-task-scheduler/src/scheduler.rs index fc0b9893..848c65a4 100644 --- a/ic-task-scheduler/src/scheduler.rs +++ b/ic-task-scheduler/src/scheduler.rs @@ -1,135 +1,165 @@ -use std::borrow::Borrow; use std::cell::RefCell; +use std::collections::HashSet; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use ic_kit::RejectionCode; -use ic_stable_structures::stable_structures::Memory; -use ic_stable_structures::{ - BTreeMapStructure, MemoryId, MemoryManager, StableBTreeMap, UnboundedMapStructure, -}; +use ic_stable_structures::UnboundedMapStructure; use parking_lot::Mutex; use crate::task::{ScheduledTask, Task}; use crate::time::time_secs; use crate::{Result, SchedulerError}; -type SchedulerErrorCallback = Box, SchedulerError) + Send>; -type SaveStateQueryCallback = dyn Fn() -> Pin>>> +thread_local! { + /// Queue containing tasks to be processed + static TASKS_TO_BE_PROCESSED: RefCell> = RefCell::new(HashSet::new()); + /// Tasks which are currently being processed + static TASKS_RUNNING: RefCell> = RefCell::new(HashSet::new()); +} + +/// The state of a task execution. +/// This is reported when the `SaveStateQueryCallback` is called. +#[derive(Debug)] +pub enum TaskExecutionState { + /// Reported when tasks to be executed are scheduled + Scheduled, + /// Reported when a task fails. + Failed(u32, SchedulerError), + /// Reported when a task starts executing. + Executing(u32), + /// Reported when a task completes successfully. + Completed(u32), + /// Reported when a task panics. + Panicked(u32), +} + +type SaveStateQueryCallback = dyn Fn( + TaskExecutionState, + ) -> Pin>>> + Send + Sync; + /// A scheduler is responsible for executing tasks. -pub struct Scheduler +pub struct Scheduler where T: 'static + Task, P: 'static + UnboundedMapStructure>, - M: 'static + Memory, { pending_tasks: Arc>, phantom: std::marker::PhantomData, - /// Queue containing the tasks to be executed in a loop - tasks_queue: Arc>>, - /// Callback to be called when a task fails. - failed_task_callback: Arc>>, /// Callback to be called to save the current canister state to prevent panicking tasks. - save_state_query_callback: Arc>>, + on_execution_state_changed_callback: Arc>, } -impl Scheduler +impl Scheduler where T: 'static + Task, P: 'static + UnboundedMapStructure>, - M: 'static + Memory, { /// Create a new scheduler. + /// + /// A callback must be passed and it'll be called to save and report the current state of the scheduler. + /// + /// ATTENTION! In order to prevent scheduler panic this callback should always make a query to a canister endpoint. + /// This because the current scheduler state should be saved while running tasks pub fn new( pending_tasks: P, - memory_manager: &dyn MemoryManager, - task_queue_memory_id: MemoryId, + on_execution_state_changed_callback: Box, ) -> Result { Ok(Self { pending_tasks: Arc::new(Mutex::new(pending_tasks)), phantom: std::marker::PhantomData, - tasks_queue: Arc::new(RefCell::new(StableBTreeMap::new( - memory_manager.get(task_queue_memory_id), - ))), - failed_task_callback: Arc::new(None), - save_state_query_callback: Arc::new(None), + on_execution_state_changed_callback: Arc::new(on_execution_state_changed_callback), }) } - /// Set a callback to be called when a task fails. - pub fn set_failed_task_callback, SchedulerError)>( - &mut self, - cb: F, - ) { - self.failed_task_callback = Arc::new(Some(Box::new(cb))); - } - - /// Set a callback to be called to save the current canister state to prevent panicking tasks. - pub fn set_save_state_query_callback(&mut self, cb: Box) { - self.save_state_query_callback = Arc::new(Some(cb)); - } - /// Execute all pending tasks. /// Each task is executed asynchronously in a dedicated ic_cdk::spawn call. /// This function does not wait for the tasks to complete. /// Returns the number of tasks that have been launched. - pub async fn run(&self) -> Result { + pub async fn run(&mut self) -> Result { self.run_with_timestamp(time_secs()).await } - async fn run_with_timestamp(&self, now_timestamp_secs: u64) -> Result { + async fn run_with_timestamp(&mut self, now_timestamp_secs: u64) -> Result { let mut to_be_reprocessed = Vec::new(); let mut task_execution_started = 0; - // delete unprocessed tasks and push tasks to execute - self.delete_unprocessed_tasks()?; - self.push_tasks_to_execute_to_queue(now_timestamp_secs) - .await?; + let tasks_running_count = TASKS_RUNNING.with_borrow(|tasks| tasks.len()); - { - let mut lock = self.pending_tasks.lock(); - while let Some(key) = lock.first_key() { - let task = lock.remove(&key); - drop(lock); - if let Some(mut task) = task { - if task.options.execute_after_timestamp_in_secs > now_timestamp_secs { - to_be_reprocessed.push(task); + match tasks_running_count { + 0 => { + // if there are no processing tasks, initialize the tasks to be processed + self.init_tasks_to_be_processed(now_timestamp_secs).await?; + } + 1 => { + // if there are processing tasks with length 1, + // delete that task and mark it as panicked + let task = TASKS_RUNNING.with_borrow(|tasks| *tasks.iter().next().unwrap()); + self.delete_unprocessable_task(task).await?; + + // eventually reschedule the tasks to be processed + self.init_tasks_to_be_processed(now_timestamp_secs).await?; + } + _ => { + // if there is more than one task to be processed, keep only half of them + self.split_processing_tasks().await?; + } + } + + let tasks_to_be_processed = TASKS_TO_BE_PROCESSED.with_borrow(|tasks| tasks.clone()); + + for task_id in tasks_to_be_processed { + let lock = self.pending_tasks.lock(); + let mut task = match lock.get(&task_id) { + Some(task) => task, + None => continue, + }; + drop(lock); + + if task.options.execute_after_timestamp_in_secs > now_timestamp_secs { + to_be_reprocessed.push(task); + continue; + } + + task_execution_started += 1; + let key = task_id; + let mut task_scheduler = self.clone(); + Self::spawn(async move { + task_scheduler + .put_task_to_processing_tasks(key) + .await + .unwrap(); + if let Err(err) = task.task.execute(Box::new(task_scheduler.clone())).await { + task.options.failures += 1; + let (should_retry, retry_delay) = task + .options + .retry_strategy + .should_retry(task.options.failures); + if should_retry { + task.options.execute_after_timestamp_in_secs = + now_timestamp_secs + (retry_delay as u64); + task_scheduler.append_task(task.clone()) } else { - task_execution_started += 1; - let task_scheduler = self.clone(); - Self::spawn(async move { - if let Err(err) = - task.task.execute(Box::new(task_scheduler.clone())).await - { - task.options.failures += 1; - let (should_retry, retry_delay) = task - .options - .retry_strategy - .should_retry(task.options.failures); - if should_retry { - task.options.execute_after_timestamp_in_secs = - now_timestamp_secs + (retry_delay as u64); - task_scheduler.append_task(task.clone()) - } else if let Some(cb) = &*task_scheduler.failed_task_callback { - cb(task.clone(), err); - } - } - // remove task from queue - if let (Some(cb), Err(err)) = ( - &*task_scheduler.failed_task_callback, - task_scheduler.remove_task_from_queue(key).await, - ) { - cb(task, err); - } - }); + // report error + task_scheduler + .remove_failed_task_from_processing_tasks(key, err) + .await + .unwrap(); + + return; } } - lock = self.pending_tasks.lock(); - } + // remove task from queue + task_scheduler + .remove_task_from_processing_tasks(key) + .await + .unwrap(); + }); } + self.append_tasks(to_be_reprocessed); Ok(task_execution_started) } @@ -148,8 +178,48 @@ where ic_kit::ic::spawn(future); } - /// Append a task to the tasks queue and save the state - async fn push_tasks_to_execute_to_queue(&self, timestamp: u64) -> Result<()> { + /// Copy `tasks_being_processed` into `tasks_to_be_processed`, + /// + /// then keep in the hashset only the first half of the tasks. + async fn split_processing_tasks(&mut self) -> Result<()> { + // move tasks_being_processed to tasks_to_be_processed + let tasks_to_be_processed = TASKS_RUNNING.with_borrow_mut(|tasks| std::mem::take(tasks)); + let tasks_len = TASKS_TO_BE_PROCESSED.with_borrow_mut(|tasks| { + *tasks = tasks_to_be_processed; + tasks.len() + }); + + let total_tasks_half = tasks_len / 2; + + let second_half: Vec = TASKS_TO_BE_PROCESSED.with_borrow(|tasks| { + tasks + .iter() + .enumerate() + .filter_map(|(i, task)| { + if i > total_tasks_half { + Some(*task) + } else { + None + } + }) + .collect() + }); + + TASKS_TO_BE_PROCESSED.with_borrow_mut(|tasks| { + for task in second_half { + tasks.remove(&task); + } + }); + + // save state + self.report_state(TaskExecutionState::Scheduled).await + } + + /// Initialize tasks to be processed using the current timestamp and checking against the tasks which must + /// executed in this slot. + /// + /// Then reset processing tasks to an empty set + async fn init_tasks_to_be_processed(&mut self, timestamp: u64) -> Result<()> { // save tasks to be executed at this time let tasks_to_be_executed: Vec = self .pending_tasks @@ -164,40 +234,76 @@ where }) .collect(); - let mut lock = self.tasks_queue.borrow_mut(); - for task in tasks_to_be_executed { - lock.insert(task, ()); + // clear and insert tasks + TASKS_TO_BE_PROCESSED.with_borrow_mut(|tasks| { + tasks.clear(); + for task in tasks_to_be_executed { + tasks.insert(task); + } + }); + TASKS_RUNNING.with_borrow_mut(|tasks| tasks.clear()); + // save state + self.report_state(TaskExecutionState::Scheduled).await + } + + /// Remove a task from `to_be_processed` and move it into the `processing` set. Then save the current task + async fn put_task_to_processing_tasks(&mut self, task: u32) -> Result<()> { + let task_removed = TASKS_TO_BE_PROCESSED.with_borrow_mut(|tasks| tasks.remove(&task)); + if task_removed { + TASKS_RUNNING.with_borrow_mut(|tasks| { + tasks.insert(task); + }); + // save state + self.report_state(TaskExecutionState::Executing(task)).await + } else { + Ok(()) } + } + + /// Remove a task from the tasks queue and save the state marking it as completed + async fn remove_task_from_processing_tasks(&mut self, task: u32) -> Result<()> { + TASKS_RUNNING.with_borrow_mut(|tasks| { + tasks.remove(&task); + }); + let mut lock = self.pending_tasks.lock(); + lock.remove(&task); // save state - self.save_state().await + self.report_state(TaskExecutionState::Completed(task)).await } - /// Remove a task from the tasks queue and save the state - async fn remove_task_from_queue(&self, task: u32) -> Result<()> { - let mut lock = self.tasks_queue.borrow_mut(); + /// Remove a failed task (FAILED! NOT PANICKED) from the tasks queue and save the state marking it as completed + async fn remove_failed_task_from_processing_tasks( + &mut self, + task: u32, + error: SchedulerError, + ) -> Result<()> { + TASKS_RUNNING.with_borrow_mut(|tasks| { + tasks.remove(&task); + }); + + let mut lock = self.pending_tasks.lock(); lock.remove(&task); + // save state - self.save_state().await + self.report_state(TaskExecutionState::Failed(task, error)) + .await } - /// Save the current state of the scheduler. - async fn save_state(&self) -> Result<()> { - if let Some(cb) = &*self.save_state_query_callback { - cb().await?; - } - Ok(()) + /// Remove a task from `tasks_running` and from `pending_tasks` + async fn delete_unprocessable_task(&mut self, task: u32) -> Result<()> { + // delete the task from tasks_running + TASKS_RUNNING.with_borrow_mut(|tasks| { + tasks.remove(&task); + }); + // delete the task from pending_tasks + self.pending_tasks.lock().remove(&task); + // save state + self.report_state(TaskExecutionState::Panicked(task)).await } - /// Remove all the tasks in `pending_tasks` which are contained in `tasks_queue`. - /// The also clear the values in `tasks_queue` - fn delete_unprocessed_tasks(&self) -> Result<()> { - // delete the tasks in the pending tasks - let queue: &RefCell> = self.tasks_queue.borrow(); - for (task, _) in queue.borrow().iter() { - self.pending_tasks.lock().remove(&task); - } - // empty the task queue - self.tasks_queue.borrow_mut().clear(); + /// Save the current state of the scheduler. + async fn report_state(&self, state: TaskExecutionState) -> Result<()> { + (&*self.on_execution_state_changed_callback)(state).await?; Ok(()) } @@ -208,28 +314,24 @@ pub trait TaskScheduler { fn append_tasks(&self, tasks: Vec>); } -impl Clone for Scheduler +impl Clone for Scheduler where T: 'static + Task, P: 'static + UnboundedMapStructure>, - M: Memory, { fn clone(&self) -> Self { Self { pending_tasks: self.pending_tasks.clone(), phantom: self.phantom, - failed_task_callback: self.failed_task_callback.clone(), - save_state_query_callback: self.save_state_query_callback.clone(), - tasks_queue: self.tasks_queue.clone(), + on_execution_state_changed_callback: self.on_execution_state_changed_callback.clone(), } } } -impl TaskScheduler for Scheduler +impl TaskScheduler for Scheduler where T: 'static + Task, P: 'static + UnboundedMapStructure>, - M: Memory, { fn append_task(&self, task: ScheduledTask) { let mut lock = self.pending_tasks.lock(); @@ -255,17 +357,14 @@ where #[cfg(test)] mod test { - use ic_stable_structures::stable_structures::DefaultMemoryImpl; - use ic_stable_structures::{default_ic_memory_manager, VirtualMemory}; - use super::*; mod test_execution { + use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; - use std::sync::atomic::AtomicBool; use std::time::Duration; use ic_stable_structures::{StableUnboundedMap, VectorMemory}; @@ -344,33 +443,46 @@ mod test { } thread_local! { - static SAVE_STATE_CB_CALLED: AtomicBool = AtomicBool::new(false); + static REPORT_STATE_CB_CALLED: RefCell> = RefCell::new(None); } #[tokio::test] - async fn test_should_call_save_state_cb() { - let mut scheduler = scheduler(); - scheduler.set_save_state_query_callback(Box::new(save_state_cb)); - - assert!(scheduler.save_state_query_callback.is_some()); - assert!(scheduler.save_state().await.is_ok()); + async fn test_should_call_report_state_cb() { + let scheduler = scheduler(); + + assert!(scheduler + .report_state(TaskExecutionState::Failed( + 1, + SchedulerError::TaskExecutionFailed("ciao".to_string()) + )) + .await + .is_ok()); - SAVE_STATE_CB_CALLED.with(|called| { - assert!(called.load(std::sync::atomic::Ordering::SeqCst)); + REPORT_STATE_CB_CALLED.with_borrow(|state| { + assert!(matches!( + state.as_ref().unwrap(), + TaskExecutionState::Failed(1, _) + )); }); } - async fn save_state() -> std::result::Result<(), (RejectionCode, String)> { - SAVE_STATE_CB_CALLED.with(|called| { - called.store(true, std::sync::atomic::Ordering::SeqCst); - }); + async fn report_state( + state: TaskExecutionState, + ) -> std::result::Result<(), (RejectionCode, String)> { + if let TaskExecutionState::Failed(id, err) = state { + REPORT_STATE_CB_CALLED.with(|called| { + called.replace(Some(TaskExecutionState::Failed(id, err))); + }); + } + Ok(()) } - fn save_state_cb( + fn report_state_cb( + state: TaskExecutionState, ) -> Pin>>> { - Box::pin(async { save_state().await }) + Box::pin(async { report_state(state).await }) } #[tokio::test] @@ -378,7 +490,7 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let scheduler = scheduler(); + let mut scheduler = scheduler(); let id = random(); scheduler.append_task(SimpleTaskSteps::One { id }.into()); @@ -411,14 +523,10 @@ mod test { #[tokio::test] async fn test_error_cb_not_called_in_case_of_success() { let local = tokio::task::LocalSet::new(); - let called = Arc::new(AtomicBool::new(false)); - let called_t = called.clone(); + local .run_until(async move { let mut scheduler = scheduler(); - scheduler.set_failed_task_callback(move |_, _| { - called_t.store(true, std::sync::atomic::Ordering::SeqCst); - }); let id = random(); scheduler.append_task(SimpleTaskSteps::One { id }.into()); @@ -447,7 +555,9 @@ mod test { }) .await; - assert!(!called.load(std::sync::atomic::Ordering::SeqCst)); + REPORT_STATE_CB_CALLED.with_borrow(|state| { + assert!(state.is_none()); + }); } fn scheduler() -> Scheduler< @@ -457,22 +567,19 @@ mod test { ScheduledTask, std::rc::Rc>>, >, - VirtualMemory, > { let map: StableUnboundedMap< u32, ScheduledTask, std::rc::Rc>>, > = StableUnboundedMap::new(VectorMemory::default()); - let memory_manager: ic_stable_structures::IcMemoryManager< - std::rc::Rc>>, - > = default_ic_memory_manager(); - Scheduler::new(map, &memory_manager, MemoryId::new(1)).unwrap() + Scheduler::new(map, Box::new(report_state_cb)).unwrap() } } mod test_delay { + use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -486,7 +593,9 @@ mod test { use crate::task::TaskOptions; thread_local! { - pub static STATE: Mutex>> = Mutex::new(HashMap::new()) + pub static STATE: Mutex>> = Mutex::new(HashMap::new()); + + static REPORT_STATE_CB_CALLED: RefCell> = RefCell::new(None); } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -522,7 +631,7 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let scheduler = scheduler(); + let mut scheduler = scheduler(); let id = random(); let timestamp: u64 = random(); @@ -564,22 +673,38 @@ mod test { ScheduledTask, std::rc::Rc>>, >, - VirtualMemory, > { let map: StableUnboundedMap< u32, ScheduledTask, std::rc::Rc>>, > = StableUnboundedMap::new(VectorMemory::default()); - let memory_manager: ic_stable_structures::IcMemoryManager< - std::rc::Rc>>, - > = default_ic_memory_manager(); - Scheduler::new(map, &memory_manager, MemoryId::new(1)).unwrap() + + Scheduler::new(map, Box::new(report_state_cb)).unwrap() + } + + async fn report_state( + state: TaskExecutionState, + ) -> std::result::Result<(), (RejectionCode, String)> { + if let TaskExecutionState::Failed(id, err) = state { + REPORT_STATE_CB_CALLED.with(|called| { + called.replace(Some(TaskExecutionState::Failed(id, err))); + }); + } + Ok(()) + } + + fn report_state_cb( + state: TaskExecutionState, + ) -> Pin>>> + { + Box::pin(async { report_state(state).await }) } } mod test_failure_and_retry { + use std::cell::RefCell; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -600,6 +725,8 @@ mod test { thread_local! { static STATE: Mutex> = Mutex::new(HashMap::new()); + + static REPORT_STATE_CB_CALLED: RefCell> = RefCell::new(None); } #[derive(Serialize, Deserialize, Debug, Clone)] @@ -645,7 +772,7 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let scheduler = scheduler(); + let mut scheduler = scheduler(); let id = random(); let fails = 10; let retries = 3; @@ -707,7 +834,7 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let scheduler = scheduler(); + let mut scheduler = scheduler(); let id = random(); let fails = 2; let retries = 4; @@ -750,7 +877,7 @@ mod test { let local = tokio::task::LocalSet::new(); local .run_until(async move { - let scheduler = scheduler(); + let mut scheduler = scheduler(); let id = random(); let fails = 10; let retries = 10; @@ -805,20 +932,12 @@ mod test { #[tokio::test] async fn test_should_call_error_cb() { - use std::sync::atomic::AtomicBool; - let local = tokio::task::LocalSet::new(); - let called = Arc::new(AtomicBool::new(false)); - let called_t = called.clone(); + let id = random(); local .run_until(async move { let mut scheduler = scheduler(); - scheduler.set_failed_task_callback(move |_, _| { - called_t.store(true, std::sync::atomic::Ordering::SeqCst); - }); - - let id = random(); let fails = 10; scheduler.append_task( @@ -836,24 +955,21 @@ mod test { assert_eq!(pending_tasks.len(), 0); }) .await; - assert!(called.load(std::sync::atomic::Ordering::SeqCst)); + REPORT_STATE_CB_CALLED.with_borrow(|state| { + assert!(matches!( + state.as_ref().unwrap(), + TaskExecutionState::Failed(_, _) + )) + }); } #[tokio::test] async fn test_should_not_call_error_cb_if_succeeds_after_retries() { - use std::sync::atomic::AtomicBool; - let local = tokio::task::LocalSet::new(); - let called = Arc::new(AtomicBool::new(false)); - let called_t = called.clone(); local .run_until(async move { let mut scheduler = scheduler(); - scheduler.set_failed_task_callback(move |_, _| { - called_t.store(true, std::sync::atomic::Ordering::SeqCst); - }); - let id = random(); let fails = 2; let retries = 4; @@ -889,25 +1005,20 @@ mod test { }); }) .await; - assert!(!called.load(std::sync::atomic::Ordering::SeqCst)); + + REPORT_STATE_CB_CALLED.with_borrow(|state| { + assert!(state.is_none()); + }); } #[tokio::test] async fn test_should_call_error_only_after_retries() { - use std::sync::atomic::AtomicU8; - let local = tokio::task::LocalSet::new(); - let called = Arc::new(AtomicU8::new(0)); - let called_t = called.clone(); + let id = random(); local .run_until(async move { let mut scheduler = scheduler(); - scheduler.set_failed_task_callback(move |_, _| { - called_t.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - }); - - let id = random(); let fails = 10; let retries = 3; @@ -961,7 +1072,12 @@ mod test { }); }) .await; - assert_eq!(called.load(std::sync::atomic::Ordering::SeqCst), 1); + REPORT_STATE_CB_CALLED.with_borrow(|state| { + assert!(matches!( + state.as_ref().unwrap(), + TaskExecutionState::Failed(_, _) + )); + }); } fn scheduler() -> Scheduler< @@ -971,17 +1087,31 @@ mod test { ScheduledTask, std::rc::Rc>>, >, - VirtualMemory, > { let map: StableUnboundedMap< u32, ScheduledTask, std::rc::Rc>>, > = StableUnboundedMap::new(VectorMemory::default()); - let memory_manager: ic_stable_structures::IcMemoryManager< - std::rc::Rc>>, - > = default_ic_memory_manager(); - Scheduler::new(map, &memory_manager, MemoryId::new(1)).unwrap() + Scheduler::new(map, Box::new(report_state_cb)).unwrap() + } + + async fn report_state( + state: TaskExecutionState, + ) -> std::result::Result<(), (RejectionCode, String)> { + if let TaskExecutionState::Failed(id, err) = state { + REPORT_STATE_CB_CALLED.with(|called| { + called.replace(Some(TaskExecutionState::Failed(id, err))); + }); + } + Ok(()) + } + + fn report_state_cb( + state: TaskExecutionState, + ) -> Pin>>> + { + Box::pin(async { report_state(state).await }) } } } diff --git a/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs b/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs index fcb53847..a38e566a 100644 --- a/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs +++ b/ic-task-scheduler/tests/dummy_scheduler_canister/src/canister.rs @@ -1,75 +1,76 @@ -use std::{cell::RefCell, future::Future, pin::Pin, sync::Arc, time::Duration}; +use std::cell::RefCell; +use std::future::Future; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::time::Duration; use candid::Principal; -use ic_canister::{generate_idl, init, query, Canister, Idl, PreUpdate}; +use ic_canister::{generate_idl, init, post_upgrade, query, update, Canister, Idl, PreUpdate}; use ic_exports::ic_kit::RejectionCode; -use ic_stable_structures::{ - default_ic_memory_manager, stable_structures::DefaultMemoryImpl, IcMemoryManager, MemoryId, - StableUnboundedMap, VirtualMemory, -}; -use ic_task_scheduler::{ - scheduler::{Scheduler, TaskScheduler}, - task::{ScheduledTask, Task, TaskOptions}, - SchedulerError, -}; +use ic_stable_structures::stable_structures::DefaultMemoryImpl; +use ic_stable_structures::{IcMemoryManager, MemoryId, StableUnboundedMap, VirtualMemory}; +use ic_task_scheduler::scheduler::{Scheduler, TaskExecutionState, TaskScheduler}; +use ic_task_scheduler::task::{ScheduledTask, Task, TaskOptions}; +use ic_task_scheduler::SchedulerError; use serde::{Deserialize, Serialize}; -type Storage = StableUnboundedMap, VirtualMemory>; -type PanickingScheduler = Scheduler; +type Storage = StableUnboundedMap, VirtualMemory>; +type PanickingScheduler = Scheduler; const SCHEDULER_STORAGE_MEMORY_ID: MemoryId = MemoryId::new(1); -const TASK_QUEUE_MEMORY_ID: MemoryId = MemoryId::new(2); thread_local! { pub static MEMORY_MANAGER: IcMemoryManager = IcMemoryManager::init(DefaultMemoryImpl::default()); - static SCHEDULER: RefCell> = { - let mut map: Storage = Storage::new(MEMORY_MANAGER.with(|mm| mm.get(SCHEDULER_STORAGE_MEMORY_ID))); + static SCHEDULER: RefCell>> = { + let map: Storage = Storage::new(MEMORY_MANAGER.with(|mm| mm.get(SCHEDULER_STORAGE_MEMORY_ID))); - let memory_manager = default_ic_memory_manager(); - - let mut scheduler = PanickingScheduler::new( + let scheduler = PanickingScheduler::new( map, - &memory_manager, - TASK_QUEUE_MEMORY_ID + Box::new(save_state_cb) ).unwrap(); - scheduler.set_failed_task_callback(move |_, _| { - FAILED_TASK_CALLED.with_borrow_mut(|called| { - *called = true; - }); - }); - scheduler.set_save_state_query_callback(Box::new(save_state_cb)); - scheduler.append_task( - ( - PanicTask::StepOne { id: 1 }, - TaskOptions::new() - .with_max_retries_policy(0) - .with_fixed_backoff_policy(0), - ) - .into(), - ); - RefCell::new(Arc::new(scheduler)) - }; - static SAVE_STATE_CALLED: RefCell = RefCell::new(false); + scheduler.append_task((DummyTask::GoodTask, TaskOptions::new()).into()); + scheduler.append_task((DummyTask::Panicking, TaskOptions::new()).into()); + scheduler.append_task((DummyTask::GoodTask, TaskOptions::new()).into()); + scheduler.append_task((DummyTask::FailTask, TaskOptions::new()).into()); - static FAILED_TASK_CALLED : RefCell = RefCell::new(false); + RefCell::new(Arc::new(Mutex::new(scheduler))) + }; + + static SCHEDULED_STATE_CALLED: RefCell = RefCell::new(false); + static COMPLETED_TASKS: RefCell> = RefCell::new(vec![]); + static FAILED_TASKS: RefCell> = RefCell::new(vec![]); + static PANICKED_TASKS : RefCell> = RefCell::new(vec![]); + static EXECUTING_TASKS : RefCell> = RefCell::new(vec![]); static PRINCIPAL : RefCell = RefCell::new(Principal::anonymous()); } #[derive(Serialize, Deserialize, Debug, Clone)] -pub enum PanicTask { - StepOne { id: u32 }, +pub enum DummyTask { + Panicking, + GoodTask, + FailTask, } -impl Task for PanicTask { +impl Task for DummyTask { fn execute( &self, _task_scheduler: Box>, ) -> Pin>>> { - panic!("PanicTask::execute") + match self { + Self::GoodTask => Box::pin(async move { Ok(()) }), + Self::Panicking => Box::pin(async move { + panic!("PanicTask::execute"); + }), + Self::FailTask => Box::pin(async move { + Err(SchedulerError::TaskExecutionFailed( + "i dunno why".to_string(), + )) + }), + } } } @@ -84,9 +85,7 @@ impl PreUpdate for DummyCanister {} impl DummyCanister { #[init] pub fn init(&self) { - ic_exports::ic_cdk_timers::set_timer(Duration::from_millis(10), || { - ic_cdk::spawn(Self::run_scheduler()) - }); + self.set_timers(); // set principal PRINCIPAL.with_borrow_mut(|principal| { @@ -94,9 +93,20 @@ impl DummyCanister { }); } - #[query] + #[post_upgrade] + pub fn post_upgrade(&self) { + self.set_timers(); + } + + fn set_timers(&self) { + ic_exports::ic_cdk_timers::set_timer_interval(Duration::from_millis(10), || { + ic_cdk::spawn(Self::do_run_scheduler()) + }); + } + + #[update] pub fn save_state(&self) -> bool { - SAVE_STATE_CALLED.with_borrow_mut(|called| { + SCHEDULED_STATE_CALLED.with_borrow_mut(|called| { *called = true; }); @@ -104,18 +114,40 @@ impl DummyCanister { } #[query] - pub fn save_state_called(&self) -> bool { - SAVE_STATE_CALLED.with_borrow(|called| *called) + pub fn scheduled_state_called(&self) -> bool { + SCHEDULED_STATE_CALLED.with_borrow(|called| *called) + } + + #[query] + pub fn panicked_tasks(&self) -> Vec { + PANICKED_TASKS.with_borrow(|tasks| tasks.clone()) + } + + #[query] + pub fn completed_tasks(&self) -> Vec { + COMPLETED_TASKS.with_borrow(|tasks| tasks.clone()) } #[query] - pub fn failed_task_called(&self) -> bool { - FAILED_TASK_CALLED.with_borrow(|called| *called) + pub fn failed_tasks(&self) -> Vec { + FAILED_TASKS.with_borrow(|tasks| tasks.clone()) } - async fn run_scheduler() { - let scheduler = SCHEDULER.with_borrow_mut(|scheduler| scheduler.clone()); - scheduler.run().await.unwrap(); + #[query] + pub fn executed_tasks(&self) -> Vec { + EXECUTING_TASKS.with_borrow(|tasks| tasks.clone()) + } + + #[update] + pub async fn run_scheduler(&self) { + Self::do_run_scheduler().await + } + + async fn do_run_scheduler() { + let scheduler = SCHEDULER.with_borrow(|scheduler| scheduler.clone()); + let mut lock = scheduler.lock().unwrap(); + + lock.run().await.unwrap(); } pub fn idl() -> Idl { @@ -123,11 +155,36 @@ impl DummyCanister { } } -async fn save_state() -> Result<(), (RejectionCode, String)> { +async fn save_state(state: TaskExecutionState) -> Result<(), (RejectionCode, String)> { let canister = PRINCIPAL.with_borrow(|principal| *principal); + match state { + TaskExecutionState::Completed(id) => { + COMPLETED_TASKS.with_borrow_mut(|tasks| { + tasks.push(id); + }); + } + TaskExecutionState::Panicked(id) => { + PANICKED_TASKS.with_borrow_mut(|tasks| { + tasks.push(id); + }); + } + TaskExecutionState::Failed(id, _) => { + FAILED_TASKS.with_borrow_mut(|tasks| { + tasks.push(id); + }); + } + TaskExecutionState::Executing(id) => { + EXECUTING_TASKS.with_borrow_mut(|tasks| { + tasks.push(id); + }); + } + TaskExecutionState::Scheduled => {} + } ic_exports::ic_cdk::call(canister, "save_state", ()).await } -fn save_state_cb() -> Pin>>> { - Box::pin(async { save_state().await }) +fn save_state_cb( + state: TaskExecutionState, +) -> Pin>>> { + Box::pin(async { save_state(state).await }) } diff --git a/ic-task-scheduler/tests/pocket_ic_tests/mod.rs b/ic-task-scheduler/tests/pocket_ic_tests/mod.rs index 93e680c4..7c179845 100644 --- a/ic-task-scheduler/tests/pocket_ic_tests/mod.rs +++ b/ic-task-scheduler/tests/pocket_ic_tests/mod.rs @@ -1,6 +1,8 @@ mod scheduler; mod wasm_utils; +use std::time::Duration; + use candid::{CandidType, Decode, Encode, Principal}; use ic_exports::ic_kit::ic; use ic_exports::pocket_ic; @@ -45,13 +47,27 @@ impl PocketIcTestContext { Decode!(&res, Result).expect("failed to decode item from candid") } - pub async fn save_state_called(&self) -> bool { + pub async fn scheduled_state_called(&self) -> bool { + let args = Encode!(&()).unwrap(); + let res = self + .query_as( + ic::caller(), + self.dummy_scheduler_canister, + "scheduled_state_called", + args, + ) + .await; + + res + } + + pub async fn completed_tasks(&self) -> Vec { let args = Encode!(&()).unwrap(); let res = self .query_as( ic::caller(), self.dummy_scheduler_canister, - "save_state_called", + "completed_tasks", args, ) .await; @@ -59,19 +75,52 @@ impl PocketIcTestContext { res } - pub async fn failed_task_called(&self) -> bool { + pub async fn panicked_tasks(&self) -> Vec { let args = Encode!(&()).unwrap(); let res = self .query_as( ic::caller(), self.dummy_scheduler_canister, - "failed_task_called", + "panicked_tasks", args, ) .await; res } + + pub async fn failed_tasks(&self) -> Vec { + let args = Encode!(&()).unwrap(); + let res = self + .query_as( + ic::caller(), + self.dummy_scheduler_canister, + "failed_tasks", + args, + ) + .await; + + res + } + + pub async fn executed_tasks(&self) -> Vec { + let args = Encode!(&()).unwrap(); + let res = self + .query_as( + ic::caller(), + self.dummy_scheduler_canister, + "executed_tasks", + args, + ) + .await; + + res + } + + pub async fn run_scheduler(&self) { + self.client.advance_time(Duration::from_millis(5000)).await; + self.client.tick().await; + } } async fn deploy_dummy_scheduler_canister() -> anyhow::Result { diff --git a/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs b/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs index f717f5f3..1b5b3585 100644 --- a/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs +++ b/ic-task-scheduler/tests/pocket_ic_tests/scheduler.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use candid::Principal; use ic_kit::mock_principals::alice; @@ -19,10 +17,13 @@ async fn test_should_remove_panicking_task() { let test_ctx = deploy_dummy_scheduler_canister().await.unwrap(); CANISTER.with_borrow_mut(|principal| *principal = test_ctx.dummy_scheduler_canister); - // set error callback - std::thread::sleep(Duration::from_millis(500)); + for _ in 0..10 { + test_ctx.run_scheduler().await; + } - // check states - assert!(test_ctx.save_state_called().await); - assert!(test_ctx.failed_task_called().await); + assert!(test_ctx.scheduled_state_called().await); + assert_eq!(test_ctx.executed_tasks().await, vec![0, 2, 1, 3, 3, 2, 1]); + assert_eq!(test_ctx.panicked_tasks().await, vec![1]); + assert_eq!(test_ctx.completed_tasks().await, vec![0, 2]); + assert_eq!(test_ctx.failed_tasks().await, vec![3]); }