Skip to content

Commit

Permalink
feat(batcher): add abort to the proposal manager
Browse files Browse the repository at this point in the history
  • Loading branch information
dafnamatsry committed Nov 14, 2024
1 parent 71da94a commit 1054fbc
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 49 deletions.
1 change: 1 addition & 0 deletions crates/starknet_batcher/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ impl From<GetProposalResultError> for BatcherError {
GetProposalResultError::ProposalDoesNotExist { proposal_id } => {
BatcherError::ExecutedProposalNotFound { proposal_id }
}
GetProposalResultError::Aborted => BatcherError::ProposalAborted,
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions crates/starknet_batcher/src/batcher_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ trait ProposalManagerTraitWrapper: Send + Sync {
&self,
proposal_id: ProposalId,
) -> BoxFuture<'_, ProposalResult<ProposalCommitment>>;

fn wrap_abort_proposal(&mut self, proposal_id: ProposalId) -> BoxFuture<'_, ()>;
}

#[async_trait]
Expand Down Expand Up @@ -307,6 +309,10 @@ impl<T: ProposalManagerTraitWrapper> ProposalManagerTrait for T {
) -> ProposalResult<ProposalCommitment> {
self.wrap_executed_proposal_commitment(proposal_id).await
}

async fn abort_proposal(&mut self, proposal_id: ProposalId) {
self.wrap_abort_proposal(proposal_id).await
}
}

fn test_tx_hashes(range: std::ops::Range<u128>) -> HashSet<TransactionHash> {
Expand Down
87 changes: 52 additions & 35 deletions crates/starknet_batcher/src/proposal_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub enum GetProposalResultError {
BlockBuilderError(Arc<BlockBuilderError>),
#[error("Proposal with id {proposal_id} does not exist.")]
ProposalDoesNotExist { proposal_id: ProposalId },
#[error("Proposal was aborted")]
Aborted,
}

pub enum ProposalStatus {
Expand Down Expand Up @@ -98,6 +100,15 @@ pub trait ProposalManagerTrait: Send + Sync {
&mut self,
proposal_id: ProposalId,
) -> ProposalResult<ProposalCommitment>;

#[allow(dead_code)]
async fn abort_proposal(&mut self, proposal_id: ProposalId);
}

// Represents a spawned task of building new block proposal.
struct ProposalTask {
abort_signal_sender: tokio::sync::oneshot::Sender<()>,
join_handle: tokio::task::JoinHandle<()>,
}

/// Main struct for handling block proposals.
Expand All @@ -110,17 +121,18 @@ pub trait ProposalManagerTrait: Send + Sync {
pub(crate) struct ProposalManager {
storage_reader: Arc<dyn BatcherStorageReaderTrait>,
active_height: Option<BlockNumber>,
/// The block proposal that is currently being proposed, if any.

/// The block proposal that is currently being built, if any.
/// At any given time, there can be only one proposal being actively executed (either proposed
/// or validated).
active_proposal: Arc<Mutex<Option<ProposalId>>>,
active_proposal_handle: Option<ActiveTaskHandle>,
active_proposal_task: Option<ProposalTask>,

// Use a factory object, to be able to mock BlockBuilder in tests.
block_builder_factory: Arc<dyn BlockBuilderFactoryTrait + Send + Sync>,
executed_proposals: Arc<Mutex<HashMap<ProposalId, ProposalResult<ProposalOutput>>>>,
}

type ActiveTaskHandle = tokio::task::JoinHandle<()>;
pub type ProposalResult<T> = Result<T, GetProposalResultError>;

#[derive(Debug, PartialEq)]
Expand Down Expand Up @@ -177,7 +189,7 @@ impl ProposalManagerTrait for ProposalManager {
info!("Starting generation of a new proposal with id {}.", proposal_id);

// Create the block builder, and a channel to allow aborting the block building task.
let (_abort_signal_sender, abort_signal_receiver) = tokio::sync::oneshot::channel();
let (abort_signal_sender, abort_signal_receiver) = tokio::sync::oneshot::channel();
let height = self.active_height.expect("No active height.");

let block_builder = self.block_builder_factory.create_block_builder(
Expand All @@ -188,7 +200,8 @@ impl ProposalManagerTrait for ProposalManager {
abort_signal_receiver,
)?;

self.spawn_build_block_task(proposal_id, block_builder).await;
let join_handle = self.spawn_build_block_task(proposal_id, block_builder).await;
self.active_proposal_task = Some(ProposalTask { abort_signal_sender, join_handle });

Ok(())
}
Expand Down Expand Up @@ -223,7 +236,9 @@ impl ProposalManagerTrait for ProposalManager {
&mut self,
proposal_id: ProposalId,
) -> ProposalResult<ProposalCommitment> {
self.await_proposal_completion(proposal_id).await;
if self.active_proposal.lock().await.is_some_and(|id| id == proposal_id) {
self.await_active_proposal().await;
}
let proposals = self.executed_proposals.lock().await;
let output = proposals
.get(&proposal_id)
Expand All @@ -233,6 +248,16 @@ impl ProposalManagerTrait for ProposalManager {
Err(e) => Err(e.clone()),
}
}

async fn abort_proposal(&mut self, proposal_id: ProposalId) {
if self.active_proposal.lock().await.is_some_and(|id| id == proposal_id) {
self.abort_active_proposal().await;
self.executed_proposals
.lock()
.await
.insert(proposal_id, Err(GetProposalResultError::Aborted));
}
}
}

impl ProposalManager {
Expand All @@ -244,7 +269,7 @@ impl ProposalManager {
storage_reader,
active_proposal: Arc::new(Mutex::new(None)),
block_builder_factory,
active_proposal_handle: None,
active_proposal_task: None,
active_height: None,
executed_proposals: Arc::new(Mutex::new(HashMap::new())),
}
Expand All @@ -254,30 +279,28 @@ impl ProposalManager {
&mut self,
proposal_id: ProposalId,
mut block_builder: Box<dyn BlockBuilderTrait>,
) {
) -> tokio::task::JoinHandle<()> {
let active_proposal = self.active_proposal.clone();
let executed_proposals = self.executed_proposals.clone();

self.active_proposal_handle = Some(tokio::spawn(
tokio::spawn(
async move {
let result = block_builder
.build_block()
.await
.map(ProposalOutput::from)
.map_err(|e| GetProposalResultError::BlockBuilderError(Arc::new(e)));

executed_proposals.lock().await.insert(proposal_id, result);

// The proposal is done, clear the active proposal.
// Keep the proposal result only if it is the same as the active proposal.
let mut active_proposal = active_proposal.lock().await;
if let Some(current_active_proposal_id) = *active_proposal {
if current_active_proposal_id == proposal_id {
active_proposal.take();
}
if active_proposal.is_some_and(|id| id == proposal_id) {
active_proposal.take();
executed_proposals.lock().await.insert(proposal_id, result);
}
}
.in_current_span(),
));
)
}

async fn reset_active_height(&mut self) {
Expand Down Expand Up @@ -314,30 +337,24 @@ impl ProposalManager {
Ok(())
}

// This function assumes there are not requests processed in parallel by the batcher, otherwise
// there is a race conditon between creating the active_proposal_handle and awaiting on it.
pub async fn await_proposal_completion(&mut self, proposal_id: ProposalId) {
if self.active_proposal.lock().await.as_ref() == Some(&proposal_id) {
let _ = self
.active_proposal_handle
.take()
.expect("Active proposal handle should exist.")
.await;
}
}

// A helper function for testing purposes (to be able to await the active proposal).
// Returns true if there was an active porposal, and false otherwise.
// TODO: Consider making the tests a nested module to allow them to access private members.
// TODO(yael 5/1/2024): use wait_for_proposal_completion instead of this function.
#[cfg(test)]
// Awaits the active proposal.
// Returns true if there was an active proposal, and false otherwise.
pub async fn await_active_proposal(&mut self) -> bool {
if let Some(handle) = self.active_proposal_handle.take() {
handle.await.unwrap();
if let Some(proposal_task) = self.active_proposal_task.take() {
proposal_task.join_handle.await.ok();
return true;
}
false
}

// Ends the current active proposal.
// This call is non-blocking.
async fn abort_active_proposal(&mut self) {
self.active_proposal.lock().await.take();
if let Some(proposal_task) = self.active_proposal_task.take() {
proposal_task.abort_signal_sender.send(()).ok();
}
}
}

impl From<BlockExecutionArtifacts> for ProposalOutput {
Expand Down
47 changes: 33 additions & 14 deletions crates/starknet_batcher/src/proposal_manager_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ fn proposal_deadline() -> tokio::time::Instant {
tokio::time::Instant::now() + BLOCK_GENERATION_TIMEOUT
}

async fn build_and_await_block_proposal(
async fn build_block_proposal(
proposal_manager: &mut ProposalManager,
tx_provider: ProposeTransactionProvider,
proposal_id: ProposalId,
Expand All @@ -121,7 +121,14 @@ async fn build_and_await_block_proposal(
.build_block_proposal(proposal_id, None, proposal_deadline(), output_sender, tx_provider)
.await
.unwrap();
}

async fn build_and_await_block_proposal(
proposal_manager: &mut ProposalManager,
tx_provider: ProposeTransactionProvider,
proposal_id: ProposalId,
) {
build_block_proposal(proposal_manager, tx_provider, proposal_id).await;
assert!(proposal_manager.await_active_proposal().await);
}

Expand Down Expand Up @@ -280,33 +287,45 @@ async fn test_take_proposal_result_no_active_proposal(mut mock_dependencies: Moc
);
}

#[rstest]
#[tokio::test]
async fn test_abort_active_proposal(mut mock_dependencies: MockDependencies) {
mock_dependencies.expect_long_build_block(1);

let tx_provider = propose_tx_provider(&mock_dependencies);
let mut proposal_manager = init_proposal_manager(mock_dependencies);

proposal_manager.start_height(INITIAL_HEIGHT).await.unwrap();

// Start a new proposal, which will remain active.
build_block_proposal(&mut proposal_manager, tx_provider, ProposalId(0)).await;

proposal_manager.abort_proposal(ProposalId(0)).await;

assert_matches!(
proposal_manager.take_proposal_result(ProposalId(0)).await,
Err(GetProposalResultError::Aborted)
);

// Make sure there is no active proposal.
assert!(!proposal_manager.await_active_proposal().await);
}

#[rstest]
#[tokio::test]
async fn test_abort_and_restart_height(mut mock_dependencies: MockDependencies) {
mock_dependencies.expect_build_block(1);
mock_dependencies.expect_long_build_block(1);

// Start a new height and create a proposal.
let (output_tx_sender, _receiver) = output_streaming();
let tx_provider_0 = propose_tx_provider(&mock_dependencies);
let tx_provider_1 = propose_tx_provider(&mock_dependencies);
let mut proposal_manager = init_proposal_manager(mock_dependencies);
proposal_manager.start_height(INITIAL_HEIGHT).await.unwrap();
build_and_await_block_proposal(&mut proposal_manager, tx_provider_0, ProposalId(0)).await;

// Start a new proposal, which will remain active.
assert!(
proposal_manager
.build_block_proposal(
ProposalId(1),
None,
proposal_deadline(),
output_tx_sender,
tx_provider_1
)
.await
.is_ok()
);
build_block_proposal(&mut proposal_manager, tx_provider_1, ProposalId(1)).await;

// Restart the same height. This should abort and delete all existing proposals.
assert!(proposal_manager.start_height(INITIAL_HEIGHT).await.is_ok());
Expand Down
2 changes: 2 additions & 0 deletions crates/starknet_batcher_types/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub enum BatcherError {
ProposalAlreadyFinished { proposal_id: ProposalId },
#[error("Proposal failed.")]
ProposalFailed,
#[error("Proposal aborted.")]
ProposalAborted,
#[error("Proposal with ID {proposal_id} not found.")]
ProposalNotFound { proposal_id: ProposalId },
#[error(
Expand Down

0 comments on commit 1054fbc

Please sign in to comment.