diff --git a/src/job_manager.rs b/src/job_manager.rs index 3d08f9aa..581c5455 100644 --- a/src/job_manager.rs +++ b/src/job_manager.rs @@ -194,7 +194,7 @@ impl ProtocolWorkManager { pub fn can_submit_more_tasks(&self) -> bool { let lock = self.inner.read(); lock.enqueued_tasks.len() + lock.active_tasks.len() - <= *self.max_enqueued_tasks + *self.max_tasks + < *self.max_enqueued_tasks + *self.max_tasks } // Only relevant for keygen @@ -1012,4 +1012,66 @@ mod tests { // The message should be enqueued for future use assert_eq!(delivery_type, DeliveryType::EnqueuedMessage); } + + #[tokio::test] + async fn test_can_submit_more_tasks() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 0, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + + assert!(work_manager.can_submit_more_tasks()); + work_manager + .push_task([1; 32], false, remote, task) + .unwrap(); + assert!(!work_manager.can_submit_more_tasks()); + } + + #[tokio::test] + async fn test_multiple_messages_single_task() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 10, 10, PollMethod::Manual); + let (remote, task, mut rx) = generate_async_protocol(1, 1, 0); + + work_manager + .push_task([1; 32], true, remote.clone(), task) + .unwrap(); + + let message = TestMessage { + message: "test".to_string(), + associated_block_id: 0, + associated_session_id: 1, + associated_ssid: 1, + }; + + for _ in 0..10 { + assert_ne!( + DeliveryType::EnqueuedMessage, + work_manager + .deliver_message(message.clone(), [1; 32]) + .unwrap() + ); + } + + for _ in 0..10 { + let next_message = rx.recv().await.unwrap(); + assert_eq!(next_message, message); + } + } + + #[tokio::test] + async fn test_task_not_started() { + let work_manager = ProtocolWorkManager::new(TestWorkManager, 1, 0, PollMethod::Manual); + let (remote, task, _rx) = generate_async_protocol(1, 1, 0); + + work_manager + .push_task([1; 32], false, remote, task) + .unwrap(); // task is enqueued but not started + assert!(!work_manager + .get_active_sessions_metadata(0) + .contains(&JobMetadata { + session_id: 1, + is_stalled: false, + is_finished: false, + has_started: false, + is_active: false, + })); + } }