Skip to content

Commit

Permalink
feat(resharding) - congestion info computation (#12581)
Browse files Browse the repository at this point in the history
Implementing the congestion info computation based on the parent
congestion info and the receipt groups info from the parent trie. The
seed used for the allowed shard is a bit hacky - please let me know if
this makes sense.

I added an assertion checking that the buffered gas in congestion info is
zero iff the buffers are empty. With this in place the
`test_resharding_v3_buffered_receipts_towards_splitted_shard` tests fail
without the updated congestion info and pass with it in place.
  • Loading branch information
wacban authored Jan 7, 2025
1 parent 9d535a8 commit e3be29b
Show file tree
Hide file tree
Showing 16 changed files with 301 additions and 71 deletions.
5 changes: 4 additions & 1 deletion chain/chain/src/flat_storage_resharder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,10 @@ fn shard_split_handle_key_value(
| col::BANDWIDTH_SCHEDULER_STATE => {
copy_kv_to_all_children(&split_params, key, value, store_update)
}
col::BUFFERED_RECEIPT_INDICES | col::BUFFERED_RECEIPT => {
col::BUFFERED_RECEIPT_INDICES
| col::BUFFERED_RECEIPT
| col::BUFFERED_RECEIPT_GROUPS_QUEUE_DATA
| col::BUFFERED_RECEIPT_GROUPS_QUEUE_ITEM => {
copy_kv_to_left_child(&split_params, key, value, store_update)
}
_ => unreachable!("key: {:?} should not appear in flat store!", key),
Expand Down
143 changes: 136 additions & 7 deletions chain/chain/src/resharding/manager.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cell::RefCell;
use std::io;
use std::sync::Arc;

Expand All @@ -6,11 +7,13 @@ use super::types::ReshardingSender;
use crate::flat_storage_resharder::{FlatStorageResharder, FlatStorageResharderController};
use crate::types::RuntimeAdapter;
use crate::ChainStoreUpdate;
use itertools::Itertools;
use near_chain_configs::{MutableConfigValue, ReshardingConfig, ReshardingHandle};
use near_chain_primitives::Error;
use near_epoch_manager::EpochManagerAdapter;
use near_primitives::block::Block;
use near_primitives::challenge::PartialState;
use near_primitives::congestion_info::CongestionInfo;
use near_primitives::hash::CryptoHash;
use near_primitives::shard_layout::{get_block_shard_uid, ShardLayout};
use near_primitives::types::chunk_extra::ChunkExtra;
Expand All @@ -19,8 +22,9 @@ use near_store::adapter::{StoreAdapter, StoreUpdateAdapter};
use near_store::flat::BlockInfo;
use near_store::trie::mem::mem_trie_update::TrackingMode;
use near_store::trie::ops::resharding::RetainMode;
use near_store::trie::outgoing_metadata::ReceiptGroupsQueue;
use near_store::trie::TrieRecorder;
use near_store::{DBCol, ShardTries, ShardUId, Store};
use near_store::{DBCol, ShardTries, ShardUId, Store, TrieAccess};

pub struct ReshardingManager {
store: Store,
Expand Down Expand Up @@ -187,7 +191,7 @@ impl ReshardingManager {
// blocks, the second finalization will crash.
tries.freeze_mem_tries(parent_shard_uid, split_shard_event.children_shards())?;

let chunk_extra = self.get_chunk_extra(block_hash, &parent_shard_uid)?;
let parent_chunk_extra = self.get_chunk_extra(block_hash, &parent_shard_uid)?;
let boundary_account = split_shard_event.boundary_account;

let mut trie_store_update = self.store.store_update();
Expand All @@ -214,20 +218,46 @@ impl ReshardingManager {
let mut mem_tries = mem_tries.write().unwrap();
let mut trie_recorder = TrieRecorder::new();
let mode = TrackingMode::RefcountsAndAccesses(&mut trie_recorder);
let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), mode)?;
let mem_trie_update = mem_tries.update(*parent_chunk_extra.state_root(), mode)?;

let trie_changes = mem_trie_update.retain_split_shard(&boundary_account, retain_mode);
let partial_storage = trie_recorder.recorded_storage();
let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap();
let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes);
drop(mem_tries);

// Get the congestion info for the child.
let parent_epoch_id = block.header().epoch_id();
let parent_shard_layout = self.epoch_manager.get_shard_layout(&parent_epoch_id)?;
let parent_state_root = *parent_chunk_extra.state_root();
let parent_trie = tries.get_trie_for_shard(parent_shard_uid, parent_state_root);

let trie_recorder = RefCell::new(trie_recorder);
let parent_trie = parent_trie.recording_reads_with_recorder(trie_recorder);

let child_epoch_id = self.epoch_manager.get_next_epoch_id(block.hash())?;
let child_shard_layout = self.epoch_manager.get_shard_layout(&child_epoch_id)?;
let child_congestion_info = Self::get_child_congestion_info(
&parent_trie,
&parent_shard_layout,
&parent_chunk_extra,
&child_shard_layout,
new_shard_uid,
retain_mode,
)?;

let trie_recorder = parent_trie.take_recorder().unwrap();
let partial_storage = trie_recorder.borrow_mut().recorded_storage();
let partial_state_len = match &partial_storage.nodes {
PartialState::TrieValues(values) => values.len(),
};
let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap();
let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes);

// TODO(resharding): set all fields of `ChunkExtra`. Consider stronger
// typing. Clarify where it should happen when `State` and
// `FlatState` update is implemented.
let mut child_chunk_extra = ChunkExtra::clone(&chunk_extra);
let mut child_chunk_extra = ChunkExtra::clone(&parent_chunk_extra);
*child_chunk_extra.state_root_mut() = new_state_root;
*child_chunk_extra.congestion_info_mut().expect("The congestion info must exist!") =
child_congestion_info;

chain_store_update.save_chunk_extra(block_hash, &new_shard_uid, child_chunk_extra);
chain_store_update.save_state_transition_data(
Expand Down Expand Up @@ -260,6 +290,105 @@ impl ReshardingManager {
Ok(())
}

/// Builds the complete congestion info for a child shard created by a split.
///
/// This runs in two steps:
/// 1. derive the child's congestion info from the parent's congestion info
///    and the parent trie (see `get_child_congestion_info_not_finalized`),
/// 2. set the allowed shard from the child's own shard layout (see
///    `finalize_allowed_shard`).
///
/// # Errors
///
/// Propagates any error returned by either of the two steps.
///
/// # Panics
///
/// Panics if the parent chunk extra carries no congestion info.
fn get_child_congestion_info(
    parent_trie: &dyn TrieAccess,
    parent_shard_layout: &ShardLayout,
    parent_chunk_extra: &Arc<ChunkExtra>,
    child_shard_layout: &ShardLayout,
    child_shard_uid: ShardUId,
    retain_mode: RetainMode,
) -> Result<CongestionInfo, Error> {
    // Step 1: everything that can be inferred from the parent shard alone.
    let mut congestion_info = Self::get_child_congestion_info_not_finalized(
        parent_trie,
        parent_shard_layout,
        parent_chunk_extra.congestion_info().expect("The congestion info must exist!"),
        retain_mode,
    )?;

    // Step 2: the allowed shard depends on the child shard, so it is set last.
    Self::finalize_allowed_shard(child_shard_layout, child_shard_uid, &mut congestion_info)?;

    Ok(congestion_info)
}

/// Computes the congestion info of a child shard, without finalizing the
/// allowed shard.
///
/// The result is inferred from the parent shard's congestion info combined
/// with the receipt group metadata stored in the parent shard's trie, so the
/// individual receipts do not have to be read.
///
/// * `RetainMode::Left`: the left child keeps all delayed and buffered
///   receipts of the parent, so the parent's congestion info is returned
///   unchanged.
/// * otherwise (the right child): the child keeps the delayed receipts but
///   none of the buffered ones, so the gas and bytes of every buffered
///   receipt group found in the parent trie are subtracted from the parent's
///   congestion info.
///
/// # Errors
///
/// Returns an error if loading a receipt groups queue from the parent trie
/// fails.
///
/// # Panics
///
/// Panics if the buffered gas or bytes recorded in the receipt groups exceed
/// what the congestion info accounts for, or if any buffered gas is left in
/// the congestion info after all groups were subtracted.
pub fn get_child_congestion_info_not_finalized(
    parent_trie: &dyn TrieAccess,
    parent_shard_layout: &ShardLayout,
    parent_congestion_info: CongestionInfo,
    retain_mode: RetainMode,
) -> Result<CongestionInfo, Error> {
    tracing::debug!(target: "resharding", "Getting child congestion info.");

    // Left child: identical set of delayed and buffered receipts, hence
    // identical congestion info.
    if retain_mode == RetainMode::Left {
        return Ok(parent_congestion_info);
    }

    // Right child: start from the parent's numbers and take out everything
    // that belongs to the parent's outgoing buffers.
    let mut child_info = parent_congestion_info;
    for target_shard in parent_shard_layout.shard_ids() {
        // A missing queue means there is nothing to subtract for that shard.
        if let Some(groups) = ReceiptGroupsQueue::load(parent_trie, target_shard)? {
            let buffered_bytes = groups.total_size();
            let buffered_gas = groups.total_gas();

            child_info
                .remove_buffered_receipt_gas(buffered_gas)
                .expect("Buffered gas must not exceed congestion info buffered gas");
            child_info
                .remove_receipt_bytes(buffered_bytes)
                .expect("Buffered size must not exceed congestion info buffered size");
        }
    }

    // Invariant: with no buffered receipts inherited, no buffered gas may
    // remain in the congestion info.
    assert_eq!(child_info.buffered_receipts_gas(), 0);

    tracing::debug!(target: "resharding", "Getting child congestion info done.");
    Ok(child_info)
}

/// Sets the allowed shard of `congestion_info` for a freshly created child
/// shard.
///
/// The congestion seed used here is the index of the child shard in
/// `child_shard_layout`. This differs from normal operation, where the
/// runtime uses the sum of the shard index and the block height. The block
/// height is not easily available at every call site, so the simplified seed
/// is used instead. That is acceptable because it is deterministic and
/// resharding is a very rare event, though ideally both seeds would match.
///
/// # Errors
///
/// Returns an error if the child shard has no index in `child_shard_layout`.
///
/// # Panics
///
/// Panics if the shard index cannot be converted to the seed's integer type.
pub fn finalize_allowed_shard(
    child_shard_layout: &ShardLayout,
    child_shard_uid: ShardUId,
    congestion_info: &mut CongestionInfo,
) -> Result<(), Error> {
    let own_shard = child_shard_uid.shard_id();

    // TODO - Use proper congestion control seed during resharding.
    let congestion_seed = child_shard_layout
        .get_shard_index(own_shard)?
        .try_into()
        .expect("ShardIndex must fit in u64");

    let all_shards = child_shard_layout.shard_ids().collect::<Vec<_>>();
    congestion_info.finalize_allowed_shard(own_shard, &all_shards, congestion_seed);
    Ok(())
}

// TODO(store): Use proper store interface
fn get_chunk_extra(
&self,
Expand Down
6 changes: 3 additions & 3 deletions chain/chain/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ impl RuntimeAdapter for NightshadeRuntime {
if ProtocolFeature::StatelessValidation.enabled(next_protocol_version)
|| cfg!(feature = "shadow_chunk_validation")
{
trie = trie.recording_reads();
trie = trie.recording_reads_new_recorder();
}
let mut state_update = TrieUpdate::new(trie);

Expand Down Expand Up @@ -799,7 +799,7 @@ impl RuntimeAdapter for NightshadeRuntime {
}
}
}
debug!(target: "runtime", "Transaction filtering results {} valid out of {} pulled from the pool", result.transactions.len(), num_checked_transactions);
debug!(target: "runtime", limited_by=?result.limited_by, "Transaction filtering results {} valid out of {} pulled from the pool", result.transactions.len(), num_checked_transactions);
let shard_label = shard_id.to_string();
metrics::PREPARE_TX_SIZE.with_label_values(&[&shard_label]).observe(total_size as f64);
metrics::PREPARE_TX_REJECTED
Expand Down Expand Up @@ -882,7 +882,7 @@ impl RuntimeAdapter for NightshadeRuntime {
if ProtocolFeature::StatelessValidation.enabled(next_protocol_version)
|| cfg!(feature = "shadow_chunk_validation")
{
trie = trie.recording_reads();
trie = trie.recording_reads_new_recorder();
}

match self.process_state_update(
Expand Down
38 changes: 32 additions & 6 deletions chain/chain/src/stateless_validation/chunk_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::chain::{
};
use crate::rayon_spawner::RayonAsyncComputationSpawner;
use crate::resharding::event_type::ReshardingEventType;
use crate::resharding::manager::ReshardingManager;
use crate::sharding::shuffle_receipt_proofs;
use crate::stateless_validation::processing_tracker::ProcessingDoneTracker;
use crate::store::filter_incoming_receipts_for_shard;
Expand Down Expand Up @@ -96,6 +97,7 @@ pub fn validate_prepared_transactions(
) -> Result<PreparedTransactions, Error> {
let parent_block = chain.chain_store().get_block(chunk_header.prev_block_hash())?;
let last_chunk_transactions_size = borsh::to_vec(last_chunk_transactions)?.len();
tracing::info!("boom prepare_transactions from validator");
runtime_adapter.prepare_transactions(
storage_config,
crate::types::PrepareTransactionsChunkContext {
Expand Down Expand Up @@ -718,12 +720,36 @@ pub fn validate_chunk_state_witness(
child_shard_uid,
) => {
let old_root = *chunk_extra.state_root();
let trie = Trie::from_recorded_storage(
PartialStorage { nodes: transition.base_state },
old_root,
true,
);
let new_root = trie.retain_split_shard(&boundary_account, retain_mode)?;
let partial_storage = PartialStorage { nodes: transition.base_state };
let parent_trie = Trie::from_recorded_storage(partial_storage, old_root, true);

// Update the congestion info based on the parent shard. It's
// important to do this step before the `retain_split_shard`
// because only the parent has the needed information.
if let Some(congestion_info) = chunk_extra.congestion_info_mut() {
// Get the congestion info based on the parent shard.
let epoch_id = epoch_manager.get_epoch_id(&block_hash)?;
let parent_shard_layout = epoch_manager.get_shard_layout(&epoch_id)?;
let parent_congestion_info = *congestion_info;
*congestion_info = ReshardingManager::get_child_congestion_info_not_finalized(
&parent_trie,
&parent_shard_layout,
parent_congestion_info,
retain_mode,
)?;

// Set the allowed shard based on the child shard.
let next_epoch_id = epoch_manager.get_next_epoch_id(&block_hash)?;
let next_shard_layout = epoch_manager.get_shard_layout(&next_epoch_id)?;
ReshardingManager::finalize_allowed_shard(
&next_shard_layout,
child_shard_uid,
congestion_info,
)?;
}

let new_root = parent_trie.retain_split_shard(&boundary_account, retain_mode)?;

(child_shard_uid, new_root)
}
};
Expand Down
4 changes: 3 additions & 1 deletion chain/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,7 @@ impl Client {
} else {
0
};
tracing::info!("boom prepare_transactions from client");
runtime.prepare_transactions(
storage_config,
PrepareTransactionsChunkContext {
Expand Down Expand Up @@ -2255,7 +2256,8 @@ impl Client {
validators.remove(account_id);
}
for validator in validators {
trace!(target: "client", me = ?signer.as_ref().map(|bp| bp.validator_id()), ?tx, ?validator, ?shard_id, "Routing a transaction");
let tx_hash = tx.get_hash();
trace!(target: "client", me = ?signer.as_ref().map(|bp| bp.validator_id()), ?tx_hash, ?validator, ?shard_id, "Routing a transaction");

// Send message to network to actually forward transaction.
self.network_adapter.send(PeerManagerMessageRequest::NetworkRequests(
Expand Down
29 changes: 25 additions & 4 deletions core/primitives-core/src/version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,18 @@ impl ProtocolFeature {
// TODO(#11201): When stabilizing this feature in mainnet, also remove the temporary code
// that always enables this for mocknet (see config_mocknet function).
ProtocolFeature::ShuffleShardAssignments => 143,
// CurrentEpochStateSync must be enabled before ReshardingV3! When
// releasing this feature please make sure to schedule separate
// protocol upgrades for those features!
ProtocolFeature::CurrentEpochStateSync => 144,
ProtocolFeature::SimpleNightshadeV4 => 145,
// BandwidthScheduler must be enabled before ReshardingV3! When
// releasing this feature please make sure to schedule separate
// protocol upgrades for those features!
ProtocolFeature::BandwidthScheduler => 145,
ProtocolFeature::SimpleNightshadeV4 => 146,
#[cfg(feature = "protocol_feature_relaxed_chunk_validation")]
ProtocolFeature::RelaxedChunkValidation => 146,
ProtocolFeature::ExcludeExistingCodeFromWitnessForCodeLen => 147,
ProtocolFeature::BandwidthScheduler => 148,
ProtocolFeature::RelaxedChunkValidation => 147,
ProtocolFeature::ExcludeExistingCodeFromWitnessForCodeLen => 148,
ProtocolFeature::BlockHeightForReceiptId => 149,
// Place features that are not yet in Nightly below this line.
}
Expand Down Expand Up @@ -345,3 +351,18 @@ macro_rules! checked_feature {
}
}};
}

#[cfg(test)]
mod tests {
    use super::ProtocolFeature;

    /// ReshardingV3 (`SimpleNightshadeV4`) must be assigned a strictly higher
    /// protocol version than each feature it depends on.
    #[test]
    fn test_resharding_dependencies() {
        let resharding_v3 = ProtocolFeature::SimpleNightshadeV4.protocol_version();

        // Prerequisite: state sync for the current epoch.
        assert!(ProtocolFeature::CurrentEpochStateSync.protocol_version() < resharding_v3);
        // Prerequisite: the bandwidth scheduler.
        assert!(ProtocolFeature::BandwidthScheduler.protocol_version() < resharding_v3);
    }
}
10 changes: 6 additions & 4 deletions core/primitives/src/congestion_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,11 @@ impl CongestionInfo {
Ok(())
}

pub fn remove_buffered_receipt_gas(&mut self, gas: Gas) -> Result<(), RuntimeError> {
pub fn remove_buffered_receipt_gas(&mut self, gas: u128) -> Result<(), RuntimeError> {
match self {
CongestionInfo::V1(inner) => {
inner.buffered_receipts_gas =
inner.buffered_receipts_gas.checked_sub(gas as u128).ok_or_else(|| {
inner.buffered_receipts_gas.checked_sub(gas).ok_or_else(|| {
RuntimeError::UnexpectedIntegerOverflow(
"remove_buffered_receipt_gas".into(),
)
Expand Down Expand Up @@ -730,7 +730,8 @@ mod tests {
assert_eq!(config.max_tx_gas, control.process_tx_limit());

// remove halve the congestion
info.remove_buffered_receipt_gas(config.max_congestion_outgoing_gas / 2).unwrap();
let gas_diff = config.max_congestion_outgoing_gas / 2;
info.remove_buffered_receipt_gas(gas_diff.into()).unwrap();
let control = CongestionControl::new(config, info, 0);
assert_eq!(0.5, control.congestion_level());
assert_eq!(
Expand All @@ -741,7 +742,8 @@ mod tests {
assert!(control.shard_accepts_transactions().is_no());

// reduce congestion to 1/8
info.remove_buffered_receipt_gas(3 * config.max_congestion_outgoing_gas / 8).unwrap();
let gas_diff = 3 * config.max_congestion_outgoing_gas / 8;
info.remove_buffered_receipt_gas(gas_diff.into()).unwrap();
let control = CongestionControl::new(config, info, 0);
assert_eq!(0.125, control.congestion_level());
assert_eq!(
Expand Down
Loading

0 comments on commit e3be29b

Please sign in to comment.