From 1704c55154171a63c711353d20e401e6af5c68ef Mon Sep 17 00:00:00 2001 From: Marcelo Diop-Gonzalez Date: Mon, 6 Jan 2025 23:52:06 -0500 Subject: [PATCH] feat(state-dumper): rewrite the state dumper logic (#12492) The current state dumper code is sort of difficult to follow, and doesn't make good use of the available cores to obtain and upload parts. It starts one thread per shard that dumps one part on each iteration of a big loop (that includes a good amount of unnecessary/redundant lookups and calculations). So here we rewrite the logic so that instead of starting one thread per shard and looping over part IDs like that, we just figure out what parts need to be dumped when we see a new epoch, and then spawn futures to obtain and upload the parts. Now the part upload speed will be limited by the number of allowed "obtain part" tasks (4), and the speed of generating those parts. This has the advantage of not needing to change anything to work with dynamic resharding, and the part upload is much faster. On a forknet run with recent mainnet state, the old dumper takes around an hour and a half to dump all the parts, and this version takes around half an hour (could maybe be improved by tweaking/making configurable the number of allowed tasks obtaining parts at a time) This could be refactored further because there's still some leftover structures from the previous implementation that don't fit super cleanly, but this can be done in a future PR. --- Cargo.lock | 1 + chain/chain/src/store/mod.rs | 39 +- chain/client/src/sync/state/shard.rs | 20 +- core/async/src/futures.rs | 18 + integration-tests/src/test_loop/builder.rs | 1 + .../src/tests/client/state_dump.rs | 3 + nearcore/Cargo.toml | 1 + nearcore/src/lib.rs | 4 +- nearcore/src/state_sync.rs | 1329 +++++++++++------ 9 files changed, 917 insertions(+), 499 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f67435739dc..8c2bf77bf3b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5504,6 +5504,7 @@ dependencies = [ "testlib", "thiserror 2.0.0", "tokio", + "tokio-stream", "tracing", "xz2", ] diff --git a/chain/chain/src/store/mod.rs b/chain/chain/src/store/mod.rs index 0a59b46c640..2affea8801e 100644 --- a/chain/chain/src/store/mod.rs +++ b/chain/chain/src/store/mod.rs @@ -1025,16 +1025,35 @@ impl ChainStore { key } - /// Retrieves STATE_SYNC_DUMP for the given shard. - pub fn get_state_sync_dump_progress( - &self, - shard_id: ShardId, - ) -> Result { - option_to_not_found( - self.store - .get_ser(DBCol::BlockMisc, &ChainStore::state_sync_dump_progress_key(shard_id)), - format!("STATE_SYNC_DUMP:{}", shard_id), - ) + /// For each value stored, this returs an (EpochId, bool), where the bool tells whether it's finished + /// because those are the only fields we really care about. + pub fn iter_state_sync_dump_progress<'a>( + &'a self, + ) -> impl Iterator> + 'a { + self.store + .iter_prefix_ser::(DBCol::BlockMisc, STATE_SYNC_DUMP_KEY) + .map(|item| { + item.and_then(|(key, progress)| { + // + 1 for the ':' + let prefix_len = STATE_SYNC_DUMP_KEY.len() + 1; + let int_part = &key[prefix_len..]; + let int_part = int_part.try_into().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Bad StateSyncDump columnn key length: {}", key.len()), + ) + })?; + let shard_id = ShardId::from_le_bytes(int_part); + Ok(( + shard_id, + match progress { + StateSyncDumpProgress::AllDumped { epoch_id, .. } => (epoch_id, true), + StateSyncDumpProgress::InProgress { epoch_id, .. } => (epoch_id, false), + StateSyncDumpProgress::Skipped { epoch_id, .. } => (epoch_id, true), + }, + )) + }) + }) } /// Updates STATE_SYNC_DUMP for the given shard. diff --git a/chain/client/src/sync/state/shard.rs b/chain/client/src/sync/state/shard.rs index 12ca94715d3..ae65276bea8 100644 --- a/chain/client/src/sync/state/shard.rs +++ b/chain/client/src/sync/state/shard.rs @@ -3,7 +3,7 @@ use super::task_tracker::TaskTracker; use crate::metrics; use crate::sync::state::chain_requests::ChainFinalizationRequest; use futures::{StreamExt, TryStreamExt}; -use near_async::futures::{FutureSpawner, FutureSpawnerExt}; +use near_async::futures::{respawn_for_parallelism, FutureSpawner}; use near_async::messaging::AsyncSender; use near_chain::types::RuntimeAdapter; use near_chain::BlockHeader; @@ -280,21 +280,3 @@ async fn apply_state_part( )?; Ok(()) } - -/// Given a future, respawn it as an equivalent future but which does not block the -/// driver of the future. For example, if the given future directly performs -/// computation, normally the whoever drives the future (such as a buffered_unordered) -/// would be blocked by the computation, thereby not allowing computation of other -/// futures driven by the same driver to proceed. This function respawns the future -/// onto the FutureSpawner, so the driver of the returned future would not be blocked. -fn respawn_for_parallelism( - future_spawner: &dyn FutureSpawner, - name: &'static str, - f: impl std::future::Future + Send + 'static, -) -> impl std::future::Future + Send + 'static { - let (sender, receiver) = tokio::sync::oneshot::channel(); - future_spawner.spawn(name, async move { - sender.send(f.await).ok(); - }); - async move { receiver.await.unwrap() } -} diff --git a/core/async/src/futures.rs b/core/async/src/futures.rs index 196a2086da0..14bb3f1a3a5 100644 --- a/core/async/src/futures.rs +++ b/core/async/src/futures.rs @@ -42,6 +42,24 @@ impl FutureSpawnerExt for dyn FutureSpawner + '_ { } } +/// Given a future, respawn it as an equivalent future but which does not block the +/// driver of the future. For example, if the given future directly performs +/// computation, normally the whoever drives the future (such as a buffered_unordered) +/// would be blocked by the computation, thereby not allowing computation of other +/// futures driven by the same driver to proceed. This function respawns the future +/// onto the FutureSpawner, so the driver of the returned future would not be blocked. +pub fn respawn_for_parallelism( + future_spawner: &dyn FutureSpawner, + name: &'static str, + f: impl std::future::Future + Send + 'static, +) -> impl std::future::Future + Send + 'static { + let (sender, receiver) = tokio::sync::oneshot::channel(); + future_spawner.spawn(name, async move { + sender.send(f.await).ok(); + }); + async move { receiver.await.unwrap() } +} + /// A FutureSpawner that hands over the future to Actix. pub struct ActixFutureSpawner; diff --git a/integration-tests/src/test_loop/builder.rs b/integration-tests/src/test_loop/builder.rs index 377d2fee833..144a7677c00 100644 --- a/integration-tests/src/test_loop/builder.rs +++ b/integration-tests/src/test_loop/builder.rs @@ -756,6 +756,7 @@ impl TestLoopBuilder { future_spawner.spawn_boxed("state_sync_dumper", future); Box::new(|| {}) }), + future_spawner: Arc::new(self.test_loop.future_spawner()), handle: None, }; let state_sync_dumper_handle = self.test_loop.data.register_data(state_sync_dumper); diff --git a/integration-tests/src/tests/client/state_dump.rs b/integration-tests/src/tests/client/state_dump.rs index 5ca06f2d4bb..e3db99a5a9c 100644 --- a/integration-tests/src/tests/client/state_dump.rs +++ b/integration-tests/src/tests/client/state_dump.rs @@ -1,5 +1,6 @@ use assert_matches::assert_matches; +use near_async::futures::ActixFutureSpawner; use near_async::time::{Clock, Duration}; use near_chain::near_chain_primitives::error::QueryError; use near_chain::{ChainGenesis, ChainStoreAccess, Provenance}; @@ -66,6 +67,7 @@ fn slow_test_state_dump() { runtime, validator, dump_future_runner: StateSyncDumper::arbiter_dump_future_runner(), + future_spawner: Arc::new(ActixFutureSpawner), handle: None, }; state_sync_dumper.start().unwrap(); @@ -171,6 +173,7 @@ fn run_state_sync_with_dumped_parts( runtime, validator, dump_future_runner: StateSyncDumper::arbiter_dump_future_runner(), + future_spawner: Arc::new(ActixFutureSpawner), handle: None, }; state_sync_dumper.start().unwrap(); diff --git a/nearcore/Cargo.toml b/nearcore/Cargo.toml index b99d5574ced..931f63bdd7f 100644 --- a/nearcore/Cargo.toml +++ b/nearcore/Cargo.toml @@ -44,6 +44,7 @@ strum.workspace = true tempfile.workspace = true thiserror.workspace = true tokio.workspace = true +tokio-stream.workspace = true tracing.workspace = true xz2.workspace = true diff --git a/nearcore/src/lib.rs b/nearcore/src/lib.rs index 0b255586fa1..f132a7a4174 100644 --- a/nearcore/src/lib.rs +++ b/nearcore/src/lib.rs @@ -389,6 +389,7 @@ pub fn start_with_config_and_synchronization( let state_sync_runtime = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()); + let state_sync_spawner = Arc::new(TokioRuntimeFutureSpawner(state_sync_runtime.clone())); let StartClientResult { client_actor, client_arbiter_handle, resharding_handle } = start_client( Clock::real(), config.client_config.clone(), @@ -397,7 +398,7 @@ pub fn start_with_config_and_synchronization( shard_tracker.clone(), runtime.clone(), node_id, - Arc::new(TokioRuntimeFutureSpawner(state_sync_runtime.clone())), + state_sync_spawner.clone(), network_adapter.as_multi_sender(), shards_manager_adapter.as_sender(), config.validator_signer.clone(), @@ -434,6 +435,7 @@ pub fn start_with_config_and_synchronization( runtime, validator: config.validator_signer.clone(), dump_future_runner: StateSyncDumper::arbiter_dump_future_runner(), + future_spawner: state_sync_spawner, handle: None, }; state_sync_dumper.start()?; diff --git a/nearcore/src/state_sync.rs b/nearcore/src/state_sync.rs index 481fc530207..dc824731bd1 100644 --- a/nearcore/src/state_sync.rs +++ b/nearcore/src/state_sync.rs @@ -4,11 +4,11 @@ use actix_rt::Arbiter; use anyhow::Context; use borsh::BorshSerialize; use futures::future::BoxFuture; -use futures::FutureExt; -use itertools::Itertools; -use near_async::time::{Clock, Duration, Instant}; +use futures::{FutureExt, StreamExt}; +use near_async::futures::{respawn_for_parallelism, FutureSpawner}; +use near_async::time::{Clock, Duration, Interval}; use near_chain::types::RuntimeAdapter; -use near_chain::{Chain, ChainGenesis, DoomslugThresholdMode, Error}; +use near_chain::{Chain, ChainGenesis, DoomslugThresholdMode}; use near_chain_configs::{ClientConfig, ExternalStorageLocation, MutableValidatorSigner}; use near_client::sync::external::{ create_bucket_readwrite, external_storage_location, StateFileType, @@ -19,20 +19,25 @@ use near_client::sync::external::{ }; use near_epoch_manager::shard_tracker::ShardTracker; use near_epoch_manager::EpochManagerAdapter; +use near_primitives::block::BlockHeader; use near_primitives::hash::CryptoHash; use near_primitives::state_part::PartId; use near_primitives::state_sync::StateSyncDumpProgress; -use near_primitives::types::{AccountId, EpochHeight, EpochId, ShardId, StateRoot}; -use near_primitives::version::PROTOCOL_VERSION; -use rand::{thread_rng, Rng}; -use std::collections::HashSet; -use std::sync::atomic::AtomicBool; -use std::sync::Arc; +use near_primitives::types::{EpochHeight, EpochId, ShardId, StateRoot}; +use rand::seq::SliceRandom; +use rand::thread_rng; +use std::collections::{HashMap, HashSet}; +use std::i64; +use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; +use std::sync::{Arc, RwLock}; +use tokio::sync::oneshot; +use tokio::sync::Semaphore; /// Time limit per state dump iteration. /// A node must check external storage for parts to dump again once time is up. pub const STATE_DUMP_ITERATION_TIME_LIMIT_SECS: u64 = 300; +// TODO: could refactor this further and just have one "Dumper" struct here pub struct StateSyncDumper { pub clock: Clock, pub client_config: ClientConfig, @@ -45,41 +50,13 @@ pub struct StateSyncDumper { /// Please note that the locked value should not be stored anywhere or passed through the thread boundary. pub validator: MutableValidatorSigner, pub dump_future_runner: Box) -> Box>, + pub future_spawner: Arc, pub handle: Option, } impl StateSyncDumper { - /// Returns all current ShardIDs, plus any that may belong to a future epoch after a protocol upgrade - /// For now we start a thread for each shard ID even if it won't be needed for a long time. - /// TODO(resharding): fix that, and handle the dynamic resharding case. - fn get_all_shard_ids(&self) -> anyhow::Result> { - let chain = Chain::new_for_view_client( - self.clock.clone(), - self.epoch_manager.clone(), - self.shard_tracker.clone(), - self.runtime.clone(), - &self.chain_genesis, - DoomslugThresholdMode::TwoThirds, - false, - ) - .context("failed creating Chain")?; - let epoch_id = chain.head().context("failed getting chain head")?.epoch_id; - let head_protocol_version = self - .epoch_manager - .get_epoch_protocol_version(&epoch_id) - .context("failed getting epoch protocol version")?; - - let mut shard_ids = HashSet::new(); - for protocol_version in head_protocol_version..=PROTOCOL_VERSION { - let shard_layout = - self.epoch_manager.get_shard_layout_from_protocol_version(protocol_version); - shard_ids.extend(shard_layout.shard_ids()); - } - Ok(shard_ids) - } - - /// Starts one a thread per tracked shard. - /// Each started thread will be dumping state parts of a single epoch to external storage. + /// Starts a thread that periodically checks whether any new parts need to be uploaded, and then spawns + /// futures to generate and upload them pub fn start(&mut self) -> anyhow::Result<()> { assert!(self.handle.is_none(), "StateSyncDumper already started"); @@ -115,49 +92,43 @@ impl StateSyncDumper { }, }; - // Determine how many threads to start. - let shard_ids = self.get_all_shard_ids()?; - let chain_id = self.client_config.chain_id.clone(); let keep_running = Arc::new(AtomicBool::new(true)); - // Start a thread for each shard. - let handles = shard_ids - .into_iter() - .map(|shard_id| { - let runtime = self.runtime.clone(); - let chain_genesis = self.chain_genesis.clone(); - // Sadly, `Chain` is not `Send` and each thread needs to create its own `Chain` instance. - let chain = Chain::new_for_view_client( - self.clock.clone(), - self.epoch_manager.clone(), - self.shard_tracker.clone(), - runtime.clone(), - &chain_genesis, - DoomslugThresholdMode::TwoThirds, - false, - ) - .unwrap(); - (self.dump_future_runner)( - state_sync_dump( - self.clock.clone(), - shard_id, - chain, - self.epoch_manager.clone(), - self.shard_tracker.clone(), - runtime.clone(), - chain_id.clone(), - dump_config.restart_dump_for_shards.clone().unwrap_or_default(), - external.clone(), - dump_config.iteration_delay.unwrap_or(Duration::seconds(10)), - self.validator.clone(), - keep_running.clone(), - ) - .boxed(), - ) - }) - .collect(); - self.handle = Some(StateSyncDumpHandle { handles, keep_running }); + let chain = Chain::new_for_view_client( + self.clock.clone(), + self.epoch_manager.clone(), + self.shard_tracker.clone(), + self.runtime.clone(), + &self.chain_genesis, + DoomslugThresholdMode::TwoThirds, + false, + ) + .unwrap(); + if let Some(shards) = dump_config.restart_dump_for_shards.as_ref() { + for shard_id in shards { + chain.chain_store().set_state_sync_dump_progress(*shard_id, None).unwrap(); + tracing::debug!(target: "state_sync_dump", ?shard_id, "Dropped existing progress"); + } + } + let handle = (self.dump_future_runner)( + do_state_sync_dump( + self.clock.clone(), + chain, + self.epoch_manager.clone(), + self.shard_tracker.clone(), + self.runtime.clone(), + chain_id, + external, + dump_config.iteration_delay.unwrap_or(Duration::seconds(10)), + self.validator.clone(), + keep_running.clone(), + self.future_spawner.clone(), + ) + .boxed(), + ); + + self.handle = Some(StateSyncDumpHandle { handle: Some(handle), keep_running }); Ok(()) } @@ -179,7 +150,7 @@ impl StateSyncDumper { /// Holds arbiter handles controlling the lifetime of the spawned threads. pub struct StateSyncDumpHandle { - pub handles: Vec>, + pub handle: Option>, keep_running: Arc, } @@ -192,25 +163,11 @@ impl Drop for StateSyncDumpHandle { impl StateSyncDumpHandle { fn stop(&mut self) { tracing::warn!(target: "state_sync_dump", "Stopping state dumper"); - self.keep_running.store(false, std::sync::atomic::Ordering::Relaxed); - self.handles.drain(..).for_each(|dropper| { - dropper(); - }); + self.keep_running.store(false, Ordering::Relaxed); + self.handle.take().unwrap()() } } -/// Fetches the state sync header from DB and serializes it. -fn get_serialized_header( - shard_id: ShardId, - sync_hash: CryptoHash, - chain: &Chain, -) -> anyhow::Result> { - let header = chain.get_state_response_header(shard_id, sync_hash)?; - let mut buffer: Vec = Vec::new(); - header.serialize(&mut buffer)?; - Ok(buffer) -} - pub fn extract_part_id_from_part_file_name(file_name: &String) -> u64 { assert!(is_part_filename(file_name)); return get_part_id_from_filename(file_name).unwrap(); @@ -218,12 +175,15 @@ pub fn extract_part_id_from_part_file_name(file_name: &String) -> u64 { async fn get_missing_part_ids_for_epoch( shard_id: ShardId, - chain_id: &String, + chain_id: &str, epoch_id: &EpochId, epoch_height: u64, total_parts: u64, external: &ExternalConnection, -) -> Result, anyhow::Error> { +) -> Result, anyhow::Error> { + if total_parts == 0 { + return Ok(HashSet::new()); + } let directory_path = external_storage_location_directory( chain_id, epoch_id, @@ -237,432 +197,863 @@ async fn get_missing_part_ids_for_epoch( .iter() .map(|file_name| extract_part_id_from_part_file_name(file_name)) .collect(); - let missing_nums: Vec = + let missing_nums: HashSet<_> = (0..total_parts).filter(|i| !existing_nums.contains(i)).collect(); let num_missing = missing_nums.len(); tracing::debug!(target: "state_sync_dump", ?num_missing, ?directory_path, "Some parts have already been dumped."); Ok(missing_nums) } else { tracing::debug!(target: "state_sync_dump", ?total_parts, ?directory_path, "No part has been dumped."); - let missing_nums = (0..total_parts).collect::>(); + let missing_nums = (0..total_parts).collect(); Ok(missing_nums) } } -fn select_random_part_id_with_index(parts_to_be_dumped: &Vec) -> (u64, usize) { - let mut rng = thread_rng(); - let selected_idx = rng.gen_range(0..parts_to_be_dumped.len()); - let selected_element = parts_to_be_dumped[selected_idx]; - tracing::debug!(target: "state_sync_dump", ?selected_element, "selected parts to dump: "); - (selected_element, selected_idx) +// State associated with dumping a shard's state +struct ShardDump { + state_root: StateRoot, + // None if it's already been dumped + header_to_dump: Option>, + num_parts: u64, + parts_dumped: Arc, + // This is the set of parts who have an associated file stored in the ExternalConnection, + // meaning they've already been dumped. We periodically check this (since other processes/machines + // might have uploaded parts that we didn't) and avoid duplicating work for those parts that have already been updated. + parts_missing: Arc>>, + // This will give Ok(()) when they're all done, or Err() when one gives an error + // For now the tasks never fail, since we just retry all errors like the old implementation did, + // but we probably want to make a change to distinguish which errors are actually retriable + // (e.g. the state snapshot isn't ready yet) + upload_parts: oneshot::Receiver>, } -enum StateDumpAction { - Wait, - Dump { epoch_id: EpochId, epoch_height: EpochHeight, sync_hash: CryptoHash }, +// State associated with dumping an epoch's state +struct DumpState { + epoch_id: EpochId, + epoch_height: EpochHeight, + sync_prev_prev_hash: CryptoHash, + // Contains state for each shard we need to dump. We remove shard IDs from + // this map as we finish them. + dump_state: HashMap, + canceled: Arc, } -fn get_current_state( - chain: &Chain, - shard_id: &ShardId, - shard_tracker: &ShardTracker, - account_id: &Option, - epoch_manager: Arc, -) -> Result { - let was_last_epoch_done = match chain.chain_store().get_state_sync_dump_progress(*shard_id) { - Ok(StateSyncDumpProgress::AllDumped { epoch_id, .. }) => Some(epoch_id), - Ok(StateSyncDumpProgress::Skipped { epoch_id, .. }) => Some(epoch_id), - _ => None, - }; - - let maybe_latest_epoch_info = get_latest_epoch(shard_id, &chain, epoch_manager.clone()); - let latest_epoch_info = match maybe_latest_epoch_info { - Ok(latest_epoch_info) => latest_epoch_info, - Err(err) => { - tracing::error!(target: "state_sync_dump", ?shard_id, ?err, "Failed to get the latest epoch"); - return Err(err); +impl DumpState { + /// For each shard, checks the filenames that exist in `external` and sets the corresponding `parts_missing` fields + /// to contain the parts that haven't yet been uploaded, so that we only try to generate those. + async fn set_missing_parts(&self, external: &ExternalConnection, chain_id: &str) { + for (shard_id, s) in self.dump_state.iter() { + match get_missing_part_ids_for_epoch( + *shard_id, + chain_id, + &self.epoch_id, + self.epoch_height, + s.num_parts, + external, + ) + .await + { + Ok(missing) => { + *s.parts_missing.write().unwrap() = missing; + } + Err(error) => { + tracing::error!(target: "state_sync_dump", ?error, ?shard_id, "Failed to list stored state parts."); + } + } } - }; - let Some(LatestEpochInfo { - epoch_id: new_epoch_id, - epoch_height: new_epoch_height, - sync_hash: new_sync_hash, - }) = latest_epoch_info - else { - return Ok(StateDumpAction::Wait); - }; - - if Some(&new_epoch_id) == was_last_epoch_done.as_ref() { - return Ok(StateDumpAction::Wait); } +} - let shard_layout = epoch_manager.get_shard_layout(&new_epoch_id)?; +// Represents the state of the current epoch's state part dump +enum CurrentDump { + None, + InProgress(DumpState), + Done(EpochId), +} - if shard_layout.shard_ids().contains(shard_id) - && cares_about_shard(chain, shard_id, &new_sync_hash, &shard_tracker, &account_id)? - { - Ok(StateDumpAction::Dump { - epoch_id: new_epoch_id, - epoch_height: new_epoch_height, - sync_hash: new_sync_hash, - }) - } else { - Ok(StateDumpAction::Wait) - } +// Helper type used as an intermediate return value where the caller will want the sender only +// if there's something to do +enum NewDump { + Dump(DumpState, HashMap>>), + NoTrackedShards, } -/// Uploads header to external storage. -/// Returns true if it was successful. -async fn upload_state_header( - chain_id: &String, - epoch_id: &EpochId, - epoch_height: u64, +/// State associated with dumps for all shards responsible for checking when we need to dump for a new epoch +/// The basic flow is as follows: +/// +/// At startup or when we enter a new epoch, we initialize the `current_dump` field to represent the current epoch's state dump. +/// Then for each shard that we track and want to dump state for, we'll have one `ShardDump` struct representing it stored in the +/// `DumpState` struct that holds the global state. First we upload headers if they're not already present in the external storage, and +/// then we start the part uploading by calling `start_upload_parts()`. This initializes one `PartUploader` struct for each shard_id and part_id, +/// and spawns a PartUploader::upload_state_part() future for each, that will be responsible for generating and uploading that part if it's not +/// already uploaded. When all the parts for a shard have been uploaded, we'll be notified by the `upload_parts` field of the associated +/// `ShardDump` struct, which we check in `check_parts_upload()`. +/// +/// Separately, every so often we check whether there's a new epoch to dump state for (in `check_head()`) and whether other processes +/// have uploaded some state parts that we can therfore skip (in `check_stored_parts()`). +struct StateDumper { + clock: Clock, + chain_id: String, + validator: MutableValidatorSigner, + shard_tracker: ShardTracker, + chain: Chain, + epoch_manager: Arc, + runtime: Arc, + // State associated with dumping the current epoch + current_dump: CurrentDump, + external: ExternalConnection, + future_spawner: Arc, + // Used to limit how many tasks can be doing the computation-heavy state part generation at a time + obtain_parts: Arc, +} + +// Stores needed data for use in part upload futures +struct PartUploader { + clock: Clock, + external: ExternalConnection, + runtime: Arc, + chain_id: String, + epoch_id: EpochId, + epoch_height: EpochHeight, + sync_prev_prev_hash: CryptoHash, shard_id: ShardId, - state_sync_header: anyhow::Result>, - external: &ExternalConnection, -) -> bool { - match state_sync_header { - Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to serialize header."); - false + state_root: StateRoot, + num_parts: u64, + // Used for setting the num_parts_dumped gauge metric (which is an i64) + // When part upload tasks are cancelled on a new epoch, this is set to -1 so tasks + // know not to touch that metric anymore. + parts_dumped: Arc, + parts_missing: Arc>>, + obtain_parts: Arc, + canceled: Arc, +} + +impl PartUploader { + /// Increment the STATE_SYNC_DUMP_NUM_PARTS_DUMPED metric + fn inc_parts_dumped(&self) { + match self.parts_dumped.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |prev| { + if prev >= 0 { + Some(prev + 1) + } else { + None + } + }) { + Ok(prev_parts_dumped) => { + metrics::STATE_SYNC_DUMP_NUM_PARTS_DUMPED + .with_label_values(&[&self.shard_id.to_string()]) + .set(prev_parts_dumped + 1); + } + Err(_) => {} + }; + } + + /// Attempt to generate the state part for `self.epoch_id`, `self.shard_id` and `part_idx`, and upload it to + /// the external storage. The state part generation is limited by the number of permits allocated to the `obtain_parts` + /// Semaphore. For now, this always returns OK(()) (loops forever retrying in case of errors), but this should be changed + /// to return Err() if the error is not going to be retriable. + async fn upload_state_part(self: Arc, part_idx: u64) -> anyhow::Result<()> { + if !self.parts_missing.read().unwrap().contains(&part_idx) { + self.inc_parts_dumped(); + return Ok(()); } - Ok(header) => { - let file_type = StateFileType::StateHeader; - let location = - external_storage_location(&chain_id, &epoch_id, epoch_height, shard_id, &file_type); - match external.put_file(file_type, &header, shard_id, &location).await { - Err(err) => { - tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, ?err, "Failed to put header into external storage. Will retry next iteration."); - false + let part_id = PartId::new(part_idx, self.num_parts); + + let state_part = loop { + if self.canceled.load(Ordering::Relaxed) { + return Ok(()); + } + let _timer = metrics::STATE_SYNC_DUMP_ITERATION_ELAPSED + .with_label_values(&[&self.shard_id.to_string()]) + .start_timer(); + let state_part = { + let _permit = self.obtain_parts.acquire().await.unwrap(); + self.runtime.obtain_state_part( + self.shard_id, + &self.sync_prev_prev_hash, + &self.state_root, + part_id, + ) + }; + match state_part { + Ok(state_part) => { + break state_part; } - Ok(_) => { - tracing::trace!(target: "state_sync_dump", ?shard_id, epoch_height, "Header saved to external storage."); - true + Err(error) => { + // TODO: return non retriable errors. + tracing::warn!( + target: "state_sync_dump", + shard_id = %self.shard_id, epoch_height=%self.epoch_height, epoch_id=?&self.epoch_id, ?part_id, ?error, + "Failed to obtain state part. Retrying in 200 millis." + ); + self.clock.sleep(Duration::milliseconds(200)).await; + continue; + } + } + }; + + let file_type = StateFileType::StatePart { part_id: part_idx, num_parts: self.num_parts }; + let location = external_storage_location( + &self.chain_id, + &self.epoch_id, + self.epoch_height, + self.shard_id, + &file_type, + ); + loop { + if self.canceled.load(Ordering::Relaxed) { + return Ok(()); + } + match self + .external + .put_file(file_type.clone(), &state_part, self.shard_id, &location) + .await + { + Ok(()) => { + self.inc_parts_dumped(); + metrics::STATE_SYNC_DUMP_SIZE_TOTAL + .with_label_values(&[ + &self.epoch_height.to_string(), + &self.shard_id.to_string(), + ]) + .inc_by(state_part.len() as u64); + tracing::debug!(target: "state_sync_dump", shard_id = %self.shard_id, epoch_height=%self.epoch_height, epoch_id=?&self.epoch_id, ?part_id, "Uploaded state part."); + return Ok(()); + } + Err(error) => { + tracing::warn!( + target: "state_sync_dump", shard_id = %self.shard_id, epoch_height=%self.epoch_height, epoch_id=?&self.epoch_id, ?part_id, ?error, + "Failed to upload state part. Retrying in 200 millis." + ); + self.clock.sleep(Duration::milliseconds(200)).await; + continue; } } } } } -const FAILURES_ALLOWED_PER_ITERATION: u32 = 10; - -async fn state_sync_dump( +// Stores needed data for use in header upload futures +struct HeaderUploader { clock: Clock, - shard_id: ShardId, - chain: Chain, - epoch_manager: Arc, - shard_tracker: ShardTracker, - runtime: Arc, - chain_id: String, - restart_dump_for_shards: Vec, external: ExternalConnection, - iteration_delay: Duration, - validator: MutableValidatorSigner, - keep_running: Arc, -) { - tracing::info!(target: "state_sync_dump", ?shard_id, "Running StateSyncDump loop"); + chain_id: String, + epoch_id: EpochId, + epoch_height: EpochHeight, +} - if restart_dump_for_shards.contains(&shard_id) { - tracing::debug!(target: "state_sync_dump", ?shard_id, "Dropped existing progress"); - chain.chain_store().set_state_sync_dump_progress(shard_id, None).unwrap(); +impl HeaderUploader { + /// Attempt to generate the state header for `self.epoch_id` and `self.shard_id`, and upload it to + /// the external storage. For now, this always returns OK(()) (loops forever retrying in case of errors), + /// but this should be changed to return Err() if the error is not going to be retriable. + async fn upload_header(self: Arc, shard_id: ShardId, header: Option>) { + let Some(header) = header else { + return; + }; + let file_type = StateFileType::StateHeader; + let location = external_storage_location( + &self.chain_id, + &self.epoch_id, + self.epoch_height, + shard_id, + &file_type, + ); + loop { + match self.external.put_file(file_type.clone(), &header, shard_id, &location).await { + Ok(_) => { + tracing::info!( + target: "state_sync_dump", %shard_id, epoch_height = %self.epoch_height, + "Header saved to external storage." + ); + return; + } + Err(err) => { + tracing::warn!( + target: "state_sync_dump", %shard_id, epoch_height = %self.epoch_height, ?err, + "Failed to put header into external storage. Will retry next iteration." + ); + self.clock.sleep(Duration::seconds(5)).await; + continue; + } + }; + } } - // Stop if the node is stopped. - // Note that without this check the state dumping thread is unstoppable, i.e. non-interruptable. - while keep_running.load(std::sync::atomic::Ordering::Relaxed) { - tracing::debug!(target: "state_sync_dump", ?shard_id, "Running StateSyncDump loop iteration"); - let account_id = validator.get().map(|v| v.validator_id().clone()); - let current_state = get_current_state( - &chain, - &shard_id, - &shard_tracker, - &account_id, - epoch_manager.clone(), - ); - let next_state = match current_state { + /// Returns whether the state sync header for `self.epoch_id` and `self.shard_id` is already uploaded to the + /// external storage + async fn header_stored(self: Arc, shard_id: ShardId) -> bool { + match self + .external + .is_state_sync_header_stored_for_epoch( + shard_id, + &self.chain_id, + &self.epoch_id, + self.epoch_height, + ) + .await + { + Ok(stored) => stored, Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to get the current state"); - None + tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to determine header presence in external storage."); + false } - Ok(StateDumpAction::Wait) => None, - Ok(StateDumpAction::Dump { epoch_id, epoch_height, sync_hash }) => { - let in_progress_data = get_in_progress_data(shard_id, sync_hash, &chain); - match in_progress_data { - Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to get in progress data"); - None - } - Ok((state_root, num_parts, sync_prev_prev_hash)) => { - // Upload header - let header_in_external_storage = match external - .is_state_sync_header_stored_for_epoch( - shard_id, - &chain_id, - &epoch_id, - epoch_height, - ) - .await - { - Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to determine header presence in external storage."); - false - } - // Header is already stored - Ok(true) => true, - // Header is missing - Ok(false) => { - upload_state_header( - &chain_id, - &epoch_id, - epoch_height, - shard_id, - get_serialized_header(shard_id, sync_hash, &chain), - &external, - ) - .await - } - }; - - let header_upload_status = if header_in_external_storage { - None - } else { - Some(StateSyncDumpProgress::InProgress { - epoch_id: epoch_id, - epoch_height, - sync_hash, - }) - }; - - // Upload parts - let parts_upload_status = match get_missing_part_ids_for_epoch( - shard_id, - &chain_id, - &epoch_id, - epoch_height, - num_parts, - &external, - ) - .await - { - Err(err) => { - tracing::error!(target: "state_sync_dump", ?err, ?shard_id, "Failed to determine missing parts"); - None - } - Ok(missing_parts) if missing_parts.is_empty() => { - update_dumped_size_and_cnt_metrics( - &shard_id, - epoch_height, - None, - num_parts, - num_parts, - ); - Some(StateSyncDumpProgress::AllDumped { epoch_id, epoch_height }) - } - Ok(missing_parts) => { - let mut parts_to_dump = missing_parts.clone(); - let timer = Instant::now(); - let mut dumped_any_state_part = false; - let mut failures_cnt = 0; - // Stop if the node is stopped. - // Note that without this check the state dumping thread is unstoppable, i.e. non-interruptable. - while keep_running.load(std::sync::atomic::Ordering::Relaxed) - && timer.elapsed().as_secs() - <= STATE_DUMP_ITERATION_TIME_LIMIT_SECS - && !parts_to_dump.is_empty() - && failures_cnt < FAILURES_ALLOWED_PER_ITERATION - { - let _timer = metrics::STATE_SYNC_DUMP_ITERATION_ELAPSED - .with_label_values(&[&shard_id.to_string()]) - .start_timer(); - - let (part_id, selected_idx) = - select_random_part_id_with_index(&parts_to_dump); - - let state_part = runtime.obtain_state_part( - shard_id, - &sync_prev_prev_hash, - &state_root, - PartId::new(part_id, num_parts), - ); - let state_part = match state_part { - Ok(state_part) => state_part, - Err(err) => { - tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, part_id, ?err, "Failed to obtain and store part. Will skip this part."); - failures_cnt += 1; - continue; - } - }; - - let file_type = StateFileType::StatePart { part_id, num_parts }; - let location = external_storage_location( - &chain_id, - &epoch_id, - epoch_height, - shard_id, - &file_type, - ); - if let Err(err) = external - .put_file(file_type, &state_part, shard_id, &location) - .await - { - // no need to break if there's an error, we should keep dumping other parts. - // reason is we are dumping random selected parts, so it's fine if we are not able to finish all of them - tracing::warn!(target: "state_sync_dump", ?shard_id, epoch_height, part_id, ?err, "Failed to put a store part into external storage. Will skip this part."); - failures_cnt += 1; - continue; - } - - // Remove the dumped part from parts_to_dump so that we draw without replacement. - parts_to_dump.swap_remove(selected_idx); - update_dumped_size_and_cnt_metrics( - &shard_id, - epoch_height, - Some(state_part.len()), - num_parts.checked_sub(parts_to_dump.len() as u64).unwrap(), - num_parts, - ); - dumped_any_state_part = true; - } - if parts_to_dump.is_empty() { - Some(StateSyncDumpProgress::AllDumped { - epoch_id, - epoch_height, - }) - } else if dumped_any_state_part { - Some(StateSyncDumpProgress::InProgress { - epoch_id, - epoch_height, - sync_hash, - }) - } else { - // No progress made. Wait before retrying. - None - } - } - }; - match (&parts_upload_status, &header_upload_status) { - ( - Some(StateSyncDumpProgress::AllDumped { .. }), - Some(StateSyncDumpProgress::InProgress { .. }), - ) => header_upload_status, - _ => parts_upload_status, - } - } - } + } + } +} + +impl StateDumper { + fn new( + clock: Clock, + chain_id: String, + validator: MutableValidatorSigner, + shard_tracker: ShardTracker, + chain: Chain, + epoch_manager: Arc, + runtime: Arc, + external: ExternalConnection, + future_spawner: Arc, + ) -> Self { + Self { + clock, + chain_id, + validator, + shard_tracker, + chain, + epoch_manager, + runtime, + current_dump: CurrentDump::None, + external, + future_spawner, + obtain_parts: Arc::new(Semaphore::new(4)), + } + } + + fn get_block_header(&self, hash: &CryptoHash) -> anyhow::Result { + self.chain.get_block_header(hash).with_context(|| format!("Failed getting header {}", hash)) + } + + /// Reads the DB entries starting with `STATE_SYNC_DUMP_KEY`, and checks which ShardIds and EpochIds are indicated as + /// already having been fully dumped. For each shard ID whose state for `epoch_id` has already been dumped, we remove it + /// from `dump` and `senders` so that we don't start the state dump logic for it. + fn check_old_progress( + &mut self, + epoch_id: &EpochId, + dump: &mut DumpState, + senders: &mut HashMap>>, + ) -> anyhow::Result<()> { + for res in self.chain.chain_store().iter_state_sync_dump_progress() { + let (shard_id, (dumped_epoch_id, done)) = + res.context("failed iterating over stored dump progress")?; + if &dumped_epoch_id != epoch_id { + self.chain + .chain_store() + .set_state_sync_dump_progress(shard_id, None) + .context("failed setting state dump progress")?; + } else if done { + dump.dump_state.remove(&shard_id); + senders.remove(&shard_id); } + } + Ok(()) + } + + /// Returns the `sync_hash` header corresponding to the latest final block if it's already known. + fn latest_sync_header(&self) -> anyhow::Result> { + let head = self.chain.head().context("Failed getting chain head")?; + let header = self.get_block_header(&head.last_block_hash)?; + let final_hash = header.last_final_block(); + if final_hash == &CryptoHash::default() { + return Ok(None); + } + let Some(sync_hash) = self + .chain + .get_sync_hash(final_hash) + .with_context(|| format!("Failed getting sync hash for {}", &final_hash))? + else { + return Ok(None); }; + self.get_block_header(&sync_hash).map(Some) + } - // Record the next state of the state machine. - let has_progress = match next_state { - Some(next_state) => { - tracing::debug!(target: "state_sync_dump", ?shard_id, ?next_state); - match chain.chain_store().set_state_sync_dump_progress(shard_id, Some(next_state)) { - Ok(_) => true, - Err(err) => { - // This will be retried. - tracing::debug!(target: "state_sync_dump", ?shard_id, ?err, "Failed to set progress"); - false - } - } + /// Generates the state sync header for the shard and initializes the `ShardDump` struct which + /// will be used to keep track of what's been dumped so far for this shard. + fn get_shard_dump( + &self, + shard_id: ShardId, + sync_hash: &CryptoHash, + ) -> anyhow::Result<(ShardDump, oneshot::Sender>)> { + let state_header = + self.chain.get_state_response_header(shard_id, *sync_hash).with_context(|| { + format!("Failed getting state response header for {} {}", shard_id, sync_hash) + })?; + let state_root = state_header.chunk_prev_state_root(); + let num_parts = state_header.num_state_parts(); + metrics::STATE_SYNC_DUMP_NUM_PARTS_TOTAL + .with_label_values(&[&shard_id.to_string()]) + .set(num_parts.try_into().unwrap_or(i64::MAX)); + + let mut header_bytes: Vec = Vec::new(); + state_header.serialize(&mut header_bytes)?; + let (sender, receiver) = oneshot::channel(); + Ok(( + ShardDump { + state_root, + header_to_dump: Some(header_bytes), + num_parts, + parts_dumped: Arc::new(AtomicI64::new(0)), + parts_missing: Arc::new(RwLock::new((0..num_parts).collect())), + upload_parts: receiver, + }, + sender, + )) + } + + /// Initializes a `NewDump` struct, which is a helper return value that either returns `NoTrackedShards` + /// if we're not tracking anything, or a `DumpState` struct, which holds one `ShardDump` initialized by `get_shard_dump()` + /// for each shard that we track. This, and the associated oneshot::Senders will then hold all the state related to the + /// progress of dumping the current epoch's state. This is to be called at startup and also upon each new epoch. + fn get_dump_state(&mut self, sync_header: &BlockHeader) -> anyhow::Result { + let epoch_info = self + .epoch_manager + .get_epoch_info(sync_header.epoch_id()) + .with_context(|| format!("Failed getting epoch info {:?}", sync_header.epoch_id()))?; + let sync_prev_header = self.get_block_header(sync_header.prev_hash())?; + let sync_prev_prev_hash = *sync_prev_header.prev_hash(); + let shard_ids = self + .epoch_manager + .shard_ids(sync_header.epoch_id()) + .with_context(|| format!("Failed getting shard IDs {:?}", sync_header.epoch_id()))?; + + let v = self.validator.get(); + let account_id = v.as_ref().map(|v| v.validator_id()); + let mut dump_state = HashMap::new(); + let mut senders = HashMap::new(); + for shard_id in shard_ids { + if !self.shard_tracker.care_about_shard( + account_id, + sync_header.prev_hash(), + shard_id, + true, + ) { + tracing::debug!( + target: "state_sync_dump", epoch_height = %epoch_info.epoch_height(), epoch_id = ?sync_header.epoch_id(), %shard_id, + "Not dumping state for non-tracked shard." + ); + continue; } - None => { - // Nothing to do, will check again later. - tracing::debug!(target: "state_sync_dump", ?shard_id, "Idle"); - false + metrics::STATE_SYNC_DUMP_EPOCH_HEIGHT + .with_label_values(&[&shard_id.to_string()]) + .set(epoch_info.epoch_height().try_into().unwrap_or(i64::MAX)); + + let (shard_dump, sender) = self.get_shard_dump(shard_id, sync_header.hash())?; + dump_state.insert(shard_id, shard_dump); + senders.insert(shard_id, sender); + } + assert_eq!( + dump_state.keys().collect::>(), + senders.keys().collect::>() + ); + if dump_state.is_empty() { + tracing::warn!( + target: "state_sync_dump", epoch_height = %epoch_info.epoch_height(), epoch_id = ?sync_header.epoch_id(), + "Not doing anything for the current epoch. No shards tracked." + ); + return Ok(NewDump::NoTrackedShards); + } + Ok(NewDump::Dump( + DumpState { + epoch_id: *sync_header.epoch_id(), + epoch_height: epoch_info.epoch_height(), + sync_prev_prev_hash, + dump_state, + canceled: Arc::new(AtomicBool::new(false)), + }, + senders, + )) + } + + /// For each shard we're dumping state for, check whether the state sync header is already stored in the external storage, + /// and set `header_to_dump` to None if so, so we don't waste time uploading it again. + async fn check_stored_headers(&mut self, dump: &mut DumpState) -> anyhow::Result<()> { + let uploader = Arc::new(HeaderUploader { + clock: self.clock.clone(), + external: self.external.clone(), + chain_id: self.chain_id.clone(), + epoch_id: dump.epoch_id, + epoch_height: dump.epoch_height, + }); + let shards = dump + .dump_state + .iter() + .map(|(shard_id, _)| (uploader.clone(), *shard_id)) + .collect::>(); + let headers_stored = tokio_stream::iter(shards) + .filter_map(|(uploader, shard_id)| async move { + let stored = uploader.header_stored(shard_id).await; + if stored { + Some(futures::future::ready(shard_id)) + } else { + None + } + }) + .buffer_unordered(10) + .collect::>() + .await; + for shard_id in headers_stored { + tracing::info!( + target: "state_sync_dump", %shard_id, epoch_height = %dump.epoch_height, + "Header already saved to external storage." + ); + let s = dump.dump_state.get_mut(&shard_id).unwrap(); + s.header_to_dump = None; + } + Ok(()) + } + + /// try to upload the state sync header for each shard we're dumping state for + async fn store_headers(&mut self, dump: &mut DumpState) -> anyhow::Result<()> { + let uploader = Arc::new(HeaderUploader { + clock: self.clock.clone(), + external: self.external.clone(), + chain_id: self.chain_id.clone(), + epoch_id: dump.epoch_id, + epoch_height: dump.epoch_height, + }); + let headers = dump + .dump_state + .iter_mut() + .map(|(shard_id, shard_dump)| { + (uploader.clone(), *shard_id, shard_dump.header_to_dump.take()) + }) + .collect::>(); + + tokio_stream::iter(headers) + .map(|(uploader, shard_id, header)| async move { + uploader.upload_header(shard_id, header).await + }) + .buffer_unordered(10) + .collect::<()>() + .await; + + Ok(()) + } + + /// Start uploading state parts. For each shard we're dumping state for and each state part in that shard, this + /// starts one PartUploader::upload_state_part() future. It also starts one future that will examine the results + /// of those futures as they finish, and that will send on `senders` either the first error that occurs or Ok(()) + /// when all parts have been uploaded for the shard. + async fn start_upload_parts( + &mut self, + senders: HashMap>>, + dump: &DumpState, + ) { + let mut senders = senders + .into_iter() + .map(|(shard_id, sender)| { + let d = dump.dump_state.get(&shard_id).unwrap(); + (shard_id, (sender, d.num_parts)) + }) + .collect::>(); + let mut empty_shards = HashSet::new(); + let uploaders = dump + .dump_state + .iter() + .filter_map(|(shard_id, shard_dump)| { + metrics::STATE_SYNC_DUMP_NUM_PARTS_DUMPED + .with_label_values(&[&shard_id.to_string()]) + .set(0); + if shard_dump.num_parts > 0 { + Some(Arc::new(PartUploader { + clock: self.clock.clone(), + external: self.external.clone(), + runtime: self.runtime.clone(), + chain_id: self.chain_id.clone(), + epoch_id: dump.epoch_id, + epoch_height: dump.epoch_height, + sync_prev_prev_hash: dump.sync_prev_prev_hash, + shard_id: *shard_id, + state_root: shard_dump.state_root, + num_parts: shard_dump.num_parts, + parts_dumped: shard_dump.parts_dumped.clone(), + parts_missing: shard_dump.parts_missing.clone(), + obtain_parts: self.obtain_parts.clone(), + canceled: dump.canceled.clone(), + })) + } else { + empty_shards.insert(shard_id); + None + } + }) + .collect::>(); + for shard_id in empty_shards { + let (sender, _) = senders.remove(shard_id).unwrap(); + let _ = sender.send(Ok(())); + } + assert_eq!(senders.len(), uploaders.len()); + + let mut tasks = uploaders + .iter() + .map(|u| (0..u.num_parts).map(|part_id| (u.clone(), part_id))) + .flatten() + .collect::>(); + // We randomize so different nodes uploading parts don't try to upload in the same order + tasks.shuffle(&mut thread_rng()); + + let future_spawner = self.future_spawner.clone(); + let fut = async move { + let mut tasks = tokio_stream::iter(tasks) + .map(|(u, part_id)| { + let shard_id = u.shard_id; + let task = u.upload_state_part(part_id); + let task = respawn_for_parallelism(&*future_spawner, "upload part", task); + async move { (shard_id, task.await) } + }) + .buffer_unordered(10); + + while let Some((shard_id, result)) = tasks.next().await { + let std::collections::hash_map::Entry::Occupied(mut e) = senders.entry(shard_id) + else { + panic!("shard ID {} missing in state dump handles", shard_id); + }; + let (_, parts_left) = e.get_mut(); + if result.is_err() { + let (sender, _) = e.remove(); + let _ = sender.send(result); + return; + } + *parts_left -= 1; + if *parts_left == 0 { + let (sender, _) = e.remove(); + let _ = sender.send(result); + } } }; + self.future_spawner.spawn_boxed("upload_parts", fut.boxed()); + } - if !has_progress { - // Avoid a busy-loop when there is nothing to do. - clock.sleep(iteration_delay).await; + /// Sets the in-memory and on-disk state to reflect that we're currently dumping state for a new epoch, + /// with the info and progress represented in `dump`. + fn new_dump(&mut self, dump: DumpState, sync_hash: CryptoHash) -> anyhow::Result<()> { + for (shard_id, _) in dump.dump_state.iter() { + self.chain + .chain_store() + .set_state_sync_dump_progress( + *shard_id, + Some(StateSyncDumpProgress::InProgress { + epoch_id: dump.epoch_id, + epoch_height: dump.epoch_height, + sync_hash, + }), + ) + .context("failed setting state dump progress")?; } + self.current_dump = CurrentDump::InProgress(dump); + Ok(()) } - tracing::debug!(target: "state_sync_dump", ?shard_id, "Stopped state dump thread"); -} -// Extracts extra data needed for obtaining state parts. -fn get_in_progress_data( - shard_id: ShardId, - sync_hash: CryptoHash, - chain: &Chain, -) -> Result<(StateRoot, u64, CryptoHash), Error> { - let state_header = chain.get_state_response_header(shard_id, sync_hash)?; - let state_root = state_header.chunk_prev_state_root(); - let num_parts = state_header.num_state_parts(); - - let sync_block_header = chain.get_block_header(&sync_hash)?; - let sync_prev_block_header = chain.get_previous_header(&sync_block_header)?; - let sync_prev_prev_hash = sync_prev_block_header.prev_hash(); - Ok((state_root, num_parts, *sync_prev_prev_hash)) -} + // Checks the current epoch and initializes the types associated with dumping its state + // if it hasn't already been dumped. + async fn init(&mut self, iteration_delay: Duration) -> anyhow::Result<()> { + loop { + let Some(sync_header) = self.latest_sync_header()? else { + self.clock.sleep(iteration_delay).await; + continue; + }; + match self.get_dump_state(&sync_header)? { + NewDump::Dump(mut dump, mut senders) => { + self.check_old_progress(sync_header.epoch_id(), &mut dump, &mut senders)?; + if dump.dump_state.is_empty() { + self.current_dump = CurrentDump::Done(*sync_header.epoch_id()); + return Ok(()); + } -fn update_dumped_size_and_cnt_metrics( - shard_id: &ShardId, - epoch_height: EpochHeight, - part_len: Option, - parts_dumped: u64, - num_parts: u64, -) { - if let Some(part_len) = part_len { - metrics::STATE_SYNC_DUMP_SIZE_TOTAL - .with_label_values(&[&epoch_height.to_string(), &shard_id.to_string()]) - .inc_by(part_len as u64); + self.check_stored_headers(&mut dump).await?; + self.store_headers(&mut dump).await?; + + dump.set_missing_parts(&self.external, &self.chain_id).await; + self.start_upload_parts(senders, &dump).await; + self.new_dump(dump, *sync_header.hash())?; + } + NewDump::NoTrackedShards => { + self.current_dump = CurrentDump::Done(*sync_header.epoch_id()); + } + } + return Ok(()); + } } - metrics::STATE_SYNC_DUMP_EPOCH_HEIGHT - .with_label_values(&[&shard_id.to_string()]) - .set(epoch_height as i64); + // Returns when the part upload tasks are finished + async fn check_parts_upload(&mut self) -> anyhow::Result<()> { + let CurrentDump::InProgress(dump) = &mut self.current_dump else { + return std::future::pending().await; + }; + let ((shard_id, result), _, _still_going) = + futures::future::select_all(dump.dump_state.iter_mut().map(|(shard_id, s)| { + async { + let r = (&mut s.upload_parts).await.unwrap(); + (*shard_id, r) + } + .boxed() + })) + .await; + result?; + drop(_still_going); - metrics::STATE_SYNC_DUMP_NUM_PARTS_DUMPED - .with_label_values(&[&shard_id.to_string()]) - .set(parts_dumped as i64); + tracing::info!(target: "state_sync_dump", epoch_id = ?&dump.epoch_id, %shard_id, "Shard dump finished"); - metrics::STATE_SYNC_DUMP_NUM_PARTS_TOTAL - .with_label_values(&[&shard_id.to_string()]) - .set(num_parts as i64); -} + self.chain + .chain_store() + .set_state_sync_dump_progress( + shard_id, + Some(StateSyncDumpProgress::AllDumped { + epoch_id: dump.epoch_id, + epoch_height: dump.epoch_height, + }), + ) + .context("failed setting state dump progress")?; + dump.dump_state.remove(&shard_id); + if dump.dump_state.is_empty() { + self.current_dump = CurrentDump::Done(dump.epoch_id); + } + Ok(()) + } -fn cares_about_shard( - chain: &Chain, - shard_id: &ShardId, - sync_hash: &CryptoHash, - shard_tracker: &ShardTracker, - account_id: &Option, -) -> Result { - let sync_header = chain.get_block_header(&sync_hash)?; - let sync_prev_hash = sync_header.prev_hash(); - Ok(shard_tracker.care_about_shard(account_id.as_ref(), sync_prev_hash, *shard_id, true)) -} + // Checks which parts have already been uploaded possibly by other nodes + // We use &mut so the do_state_sync_dump() future will be Send, which it won't be if we use a normal + // reference because of the Chain field + async fn check_stored_parts(&mut self) { + let CurrentDump::InProgress(dump) = &self.current_dump else { + return; + }; + dump.set_missing_parts(&self.external, &self.chain_id).await; + } -struct LatestEpochInfo { - epoch_id: EpochId, - epoch_height: EpochHeight, - sync_hash: CryptoHash, + /// Check whether there's a new epoch to dump state for. In that case, we start dumping + /// state for the new epoch whether or not we've finished with the old one, since other nodes + /// will be interested in the latest state. + async fn check_head(&mut self) -> anyhow::Result<()> { + let Some(sync_header) = self.latest_sync_header()? else { + return Ok(()); + }; + match &self.current_dump { + CurrentDump::InProgress(dump) => { + if &dump.epoch_id == sync_header.epoch_id() { + return Ok(()); + } + dump.canceled.store(true, Ordering::Relaxed); + for (_shard_id, d) in dump.dump_state.iter() { + // Set it to -1 to tell the existing tasks not to set the metrics anymore + d.parts_dumped.store(-1, Ordering::SeqCst); + } + } + CurrentDump::Done(epoch_id) => { + if epoch_id == sync_header.epoch_id() { + return Ok(()); + } + } + CurrentDump::None => {} + }; + match self.get_dump_state(&sync_header)? { + NewDump::Dump(mut dump, sender) => { + self.store_headers(&mut dump).await?; + self.start_upload_parts(sender, &dump).await; + self.new_dump(dump, *sync_header.hash())?; + } + NewDump::NoTrackedShards => { + self.current_dump = CurrentDump::Done(*sync_header.epoch_id()); + } + }; + Ok(()) + } } -/// return epoch_id and sync_hash of the latest complete epoch available locally. -fn get_latest_epoch( - shard_id: &ShardId, - chain: &Chain, +const CHECK_STORED_PARTS_INTERVAL: Duration = Duration::seconds(20); + +/// Main entry point into the state dumper. Initializes the state dumper and starts a loop that periodically +/// checks whether there's a new epoch to dump state for. +async fn state_sync_dump( + clock: Clock, + chain: Chain, epoch_manager: Arc, -) -> Result, Error> { - let head = chain.head()?; - tracing::debug!(target: "state_sync_dump", ?shard_id, "Check if a new complete epoch is available"); - let hash = head.last_block_hash; - let header = chain.get_block_header(&hash)?; - let final_hash = header.last_final_block(); - if final_hash == &CryptoHash::default() { - return Ok(None); + shard_tracker: ShardTracker, + runtime: Arc, + chain_id: String, + external: ExternalConnection, + iteration_delay: Duration, + validator: MutableValidatorSigner, + keep_running: Arc, + future_spawner: Arc, +) -> anyhow::Result<()> { + tracing::info!(target: "state_sync_dump", "Running StateSyncDump loop"); + + let mut dumper = StateDumper::new( + clock.clone(), + chain_id, + validator, + shard_tracker, + chain, + epoch_manager, + runtime, + external, + future_spawner, + ); + dumper.init(iteration_delay).await?; + + let now = clock.now(); + // This is set to zero in some tests where the block production delay is very small (10 millis). + // In that case we'll actually just wait for 1 millisecond. The previous behavior was to call + // clock.sleep(ZERO), but setting it to 1 is probably fine, and works with the Instant below. + let min_iteration_delay = Duration::milliseconds(1); + let mut check_head = + Interval::new(now + iteration_delay, iteration_delay.max(min_iteration_delay)); + let mut check_stored_parts = + Interval::new(now + CHECK_STORED_PARTS_INTERVAL, CHECK_STORED_PARTS_INTERVAL); + + while keep_running.load(Ordering::Relaxed) { + tokio::select! { + _ = check_head.tick(&clock) => { + dumper.check_head().await?; + } + _ = check_stored_parts.tick(&clock) => { + dumper.check_stored_parts().await; + } + result = dumper.check_parts_upload() => { + result?; + } + } } - let Some(sync_hash) = chain.get_sync_hash(final_hash)? else { - return Ok(None); - }; - let final_block_header = chain.get_block_header(&final_hash)?; - let epoch_id = *final_block_header.epoch_id(); - let epoch_info = epoch_manager.get_epoch_info(&epoch_id)?; - let epoch_height = epoch_info.epoch_height(); - tracing::debug!(target: "state_sync_dump", ?final_hash, ?sync_hash, ?epoch_id, epoch_height, "get_latest_epoch"); + tracing::debug!(target: "state_sync_dump", "Stopped state dump thread"); + Ok(()) +} - Ok(Some(LatestEpochInfo { epoch_id, epoch_height, sync_hash })) +async fn do_state_sync_dump( + clock: Clock, + chain: Chain, + epoch_manager: Arc, + shard_tracker: ShardTracker, + runtime: Arc, + chain_id: String, + external: ExternalConnection, + iteration_delay: Duration, + validator: MutableValidatorSigner, + keep_running: Arc, + future_spawner: Arc, +) { + if let Err(error) = state_sync_dump( + clock, + chain, + epoch_manager, + shard_tracker, + runtime, + chain_id, + external, + iteration_delay, + validator, + keep_running, + future_spawner, + ) + .await + { + tracing::error!(target: "state_sync_dump", ?error, "State dumper failed"); + } }