diff --git a/crates/starknet_batcher/src/batcher_test.rs b/crates/starknet_batcher/src/batcher_test.rs index efdb8117eb..30cb2e8c91 100644 --- a/crates/starknet_batcher/src/batcher_test.rs +++ b/crates/starknet_batcher/src/batcher_test.rs @@ -264,6 +264,8 @@ trait ProposalManagerTraitWrapper: Send + Sync { &self, proposal_id: ProposalId, ) -> BoxFuture<'_, ProposalResult>; + + fn wrap_abort_proposal(&mut self, proposal_id: ProposalId) -> BoxFuture<'_, ()>; } #[async_trait] @@ -307,6 +309,10 @@ impl ProposalManagerTrait for T { ) -> ProposalResult { self.wrap_executed_proposal_commitment(proposal_id).await } + + async fn abort_proposal(&mut self, proposal_id: ProposalId) { + self.wrap_abort_proposal(proposal_id).await + } } fn test_tx_hashes(range: std::ops::Range) -> HashSet { diff --git a/crates/starknet_batcher/src/proposal_manager.rs b/crates/starknet_batcher/src/proposal_manager.rs index 399243358d..e7896c268b 100644 --- a/crates/starknet_batcher/src/proposal_manager.rs +++ b/crates/starknet_batcher/src/proposal_manager.rs @@ -98,6 +98,15 @@ pub trait ProposalManagerTrait: Send + Sync { &mut self, proposal_id: ProposalId, ) -> ProposalResult; + + #[allow(dead_code)] + async fn abort_proposal(&mut self, proposal_id: ProposalId); +} + +// Represents a spawned task of building new block proposal. +struct ProposalTask { + abort_signal_sender: tokio::sync::oneshot::Sender<()>, + join_handle: tokio::task::JoinHandle<()>, } /// Main struct for handling block proposals. @@ -110,17 +119,18 @@ pub trait ProposalManagerTrait: Send + Sync { pub(crate) struct ProposalManager { storage_reader: Arc, active_height: Option, - /// The block proposal that is currently being proposed, if any. + + /// The block proposal that is currently being built, if any. /// At any given time, there can be only one proposal being actively executed (either proposed /// or validated). active_proposal: Arc>>, - active_proposal_handle: Option, + active_proposal_task: Option, + // Use a factory object, to be able to mock BlockBuilder in tests. block_builder_factory: Arc, executed_proposals: Arc>>>, } -type ActiveTaskHandle = tokio::task::JoinHandle<()>; pub type ProposalResult = Result; #[derive(Debug, PartialEq)] @@ -177,7 +187,7 @@ impl ProposalManagerTrait for ProposalManager { info!("Starting generation of a new proposal with id {}.", proposal_id); // Create the block builder, and a channel to allow aborting the block building task. - let (_abort_signal_sender, abort_signal_receiver) = tokio::sync::oneshot::channel(); + let (abort_signal_sender, abort_signal_receiver) = tokio::sync::oneshot::channel(); let height = self.active_height.expect("No active height."); let block_builder = self.block_builder_factory.create_block_builder( @@ -188,7 +198,8 @@ impl ProposalManagerTrait for ProposalManager { abort_signal_receiver, )?; - self.spawn_build_block_task(proposal_id, block_builder).await; + let join_handle = self.spawn_build_block_task(proposal_id, block_builder).await; + self.active_proposal_task = Some(ProposalTask { abort_signal_sender, join_handle }); Ok(()) } @@ -223,7 +234,9 @@ impl ProposalManagerTrait for ProposalManager { &mut self, proposal_id: ProposalId, ) -> ProposalResult { - self.await_proposal_completion(proposal_id).await; + if self.active_proposal.lock().await.is_some_and(|id| id == proposal_id) { + self.await_active_proposal().await; + } let proposals = self.executed_proposals.lock().await; let output = proposals .get(&proposal_id) @@ -233,6 +246,12 @@ impl ProposalManagerTrait for ProposalManager { Err(e) => Err(e.clone()), } } + + async fn abort_proposal(&mut self, proposal_id: ProposalId) { + if self.active_proposal.lock().await.is_some_and(|id| id == proposal_id) { + self.abort_active_proposal().await; + } + } } impl ProposalManager { @@ -244,7 +263,7 @@ impl ProposalManager { storage_reader, active_proposal: Arc::new(Mutex::new(None)), block_builder_factory, - active_proposal_handle: None, + active_proposal_task: None, active_height: None, executed_proposals: Arc::new(Mutex::new(HashMap::new())), } @@ -254,11 +273,11 @@ impl ProposalManager { &mut self, proposal_id: ProposalId, mut block_builder: Box, - ) { + ) -> tokio::task::JoinHandle<()> { let active_proposal = self.active_proposal.clone(); let executed_proposals = self.executed_proposals.clone(); - self.active_proposal_handle = Some(tokio::spawn( + tokio::spawn( async move { let result = block_builder .build_block() @@ -270,14 +289,12 @@ impl ProposalManager { // The proposal is done, clear the active proposal. let mut active_proposal = active_proposal.lock().await; - if let Some(current_active_proposal_id) = *active_proposal { - if current_active_proposal_id == proposal_id { - active_proposal.take(); - } + if active_proposal.is_some_and(|id| id == proposal_id) { + active_proposal.take(); } } .in_current_span(), - )); + ) } async fn reset_active_height(&mut self) { @@ -314,30 +331,24 @@ impl ProposalManager { Ok(()) } - // This function assumes there are not requests processed in parallel by the batcher, otherwise - // there is a race conditon between creating the active_proposal_handle and awaiting on it. - pub async fn await_proposal_completion(&mut self, proposal_id: ProposalId) { - if self.active_proposal.lock().await.as_ref() == Some(&proposal_id) { - let _ = self - .active_proposal_handle - .take() - .expect("Active proposal handle should exist.") - .await; - } - } - - // A helper function for testing purposes (to be able to await the active proposal). - // Returns true if there was an active porposal, and false otherwise. - // TODO: Consider making the tests a nested module to allow them to access private members. - // TODO(yael 5/1/2024): use wait_for_proposal_completion instead of this function. - #[cfg(test)] + // Awaits the active proposal. + // Returns true if there was an active proposal, and false otherwise. pub async fn await_active_proposal(&mut self) -> bool { - if let Some(handle) = self.active_proposal_handle.take() { - handle.await.unwrap(); + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.join_handle.await.ok(); return true; } false } + + // Ends the current active proposal. + // This call is non-blocking. + async fn abort_active_proposal(&mut self) { + self.active_proposal.lock().await.take(); + if let Some(proposal_task) = self.active_proposal_task.take() { + proposal_task.abort_signal_sender.send(()).ok(); + } + } } impl From for ProposalOutput { diff --git a/crates/starknet_batcher/src/proposal_manager_test.rs b/crates/starknet_batcher/src/proposal_manager_test.rs index 58a3d744ca..9c7ada4197 100644 --- a/crates/starknet_batcher/src/proposal_manager_test.rs +++ b/crates/starknet_batcher/src/proposal_manager_test.rs @@ -111,7 +111,7 @@ fn proposal_deadline() -> tokio::time::Instant { tokio::time::Instant::now() + BLOCK_GENERATION_TIMEOUT } -async fn build_and_await_block_proposal( +async fn build_block_proposal( proposal_manager: &mut ProposalManager, tx_provider: ProposeTransactionProvider, proposal_id: ProposalId, @@ -121,7 +121,14 @@ async fn build_and_await_block_proposal( .build_block_proposal(proposal_id, None, proposal_deadline(), output_sender, tx_provider) .await .unwrap(); +} +async fn build_and_await_block_proposal( + proposal_manager: &mut ProposalManager, + tx_provider: ProposeTransactionProvider, + proposal_id: ProposalId, +) { + build_block_proposal(proposal_manager, tx_provider, proposal_id).await; assert!(proposal_manager.await_active_proposal().await); } @@ -280,6 +287,25 @@ async fn test_take_proposal_result_no_active_proposal(mut mock_dependencies: Moc ); } +#[rstest] +#[tokio::test] +async fn test_abort_active_proposal(mut mock_dependencies: MockDependencies) { + mock_dependencies.expect_long_build_block(1); + + let tx_provider = propose_tx_provider(&mock_dependencies); + let mut proposal_manager = init_proposal_manager(mock_dependencies); + + proposal_manager.start_height(INITIAL_HEIGHT).await.unwrap(); + + // Start a new proposal, which will remain active. + build_block_proposal(&mut proposal_manager, tx_provider, ProposalId(0)).await; + + proposal_manager.abort_proposal(ProposalId(0)).await; + + // Make sure there is no active proposal. + assert!(!proposal_manager.await_active_proposal().await); +} + #[rstest] #[tokio::test] async fn test_abort_and_restart_height(mut mock_dependencies: MockDependencies) { @@ -287,7 +313,6 @@ async fn test_abort_and_restart_height(mut mock_dependencies: MockDependencies) mock_dependencies.expect_long_build_block(1); // Start a new height and create a proposal. - let (output_tx_sender, _receiver) = output_streaming(); let tx_provider_0 = propose_tx_provider(&mock_dependencies); let tx_provider_1 = propose_tx_provider(&mock_dependencies); let mut proposal_manager = init_proposal_manager(mock_dependencies); @@ -295,18 +320,7 @@ async fn test_abort_and_restart_height(mut mock_dependencies: MockDependencies) build_and_await_block_proposal(&mut proposal_manager, tx_provider_0, ProposalId(0)).await; // Start a new proposal, which will remain active. - assert!( - proposal_manager - .build_block_proposal( - ProposalId(1), - None, - proposal_deadline(), - output_tx_sender, - tx_provider_1 - ) - .await - .is_ok() - ); + build_block_proposal(&mut proposal_manager, tx_provider_1, ProposalId(1)).await; // Restart the same height. This should abort and delete all existing proposals. assert!(proposal_manager.start_height(INITIAL_HEIGHT).await.is_ok());