diff --git a/node/src/database/rocksdb.rs b/node/src/database/rocksdb.rs index bcf03bc761..fbf4a5208a 100644 --- a/node/src/database/rocksdb.rs +++ b/node/src/database/rocksdb.rs @@ -65,73 +65,13 @@ impl Backend { let inner = self.rocksdb.transaction_opt(&write_options, &tx_options); - // Borrow column families - let ledger_cf = self - .rocksdb - .cf_handle(CF_LEDGER_HEADER) - .expect("ledger_header column family must exist"); - - let ledger_txs_cf = self - .rocksdb - .cf_handle(CF_LEDGER_TXS) - .expect("CF_LEDGER_TXS column family must exist"); - - let ledger_faults_cf = self - .rocksdb - .cf_handle(CF_LEDGER_FAULTS) - .expect("CF_LEDGER_FAULTS column family must exist"); - - let candidates_cf = self - .rocksdb - .cf_handle(CF_CANDIDATES) - .expect("candidates column family must exist"); - - let candidates_height_cf = self - .rocksdb - .cf_handle(CF_CANDIDATES_HEIGHT) - .expect("candidates column family must exist"); - - let mempool_cf = self - .rocksdb - .cf_handle(CF_MEMPOOL) - .expect("mempool column family must exist"); - - let nullifiers_cf = self - .rocksdb - .cf_handle(CF_MEMPOOL_NULLIFIERS) - .expect("CF_MEMPOOL_NULLIFIERS column family must exist"); - - let fees_cf = self - .rocksdb - .cf_handle(CF_MEMPOOL_FEES) - .expect("CF_MEMPOOL_FEES column family must exist"); - - let ledger_height_cf = self - .rocksdb - .cf_handle(CF_LEDGER_HEIGHT) - .expect("CF_LEDGER_HEIGHT column family must exist"); - - let metadata_cf = self - .rocksdb - .cf_handle(CF_METADATA) - .expect("CF_METADATA column family must exist"); - let snapshot = self.rocksdb.snapshot(); DBTransaction::<'_, OptimisticTransactionDB> { inner, - candidates_cf, - candidates_height_cf, - ledger_cf, - ledger_txs_cf, - ledger_faults_cf, - mempool_cf, - nullifiers_cf, - fees_cf, - ledger_height_cf, - metadata_cf, snapshot, cumulative_inner_size: RefCell::new(0), + rocksdb: &self.rocksdb, } } } @@ -251,31 +191,21 @@ impl DB for Backend { } pub struct DBTransaction<'db, DB: DBAccess> { + rocksdb: &'db Arc, inner: rocksdb_lib::Transaction<'db, DB>, /// cumulative size of transaction footprint cumulative_inner_size: RefCell, - - // TODO: pack all column families into a single array - // Candidates column family - candidates_cf: &'db ColumnFamily, - candidates_height_cf: &'db ColumnFamily, - - // Ledger column families - ledger_cf: &'db ColumnFamily, - ledger_faults_cf: &'db ColumnFamily, - ledger_txs_cf: &'db ColumnFamily, - ledger_height_cf: &'db ColumnFamily, - - // Mempool column families - mempool_cf: &'db ColumnFamily, - nullifiers_cf: &'db ColumnFamily, - fees_cf: &'db ColumnFamily, - - metadata_cf: &'db ColumnFamily, - snapshot: SnapshotWithThreadMode<'db, DB>, } +impl<'db, DB: DBAccess> DBTransaction<'db, DB> { + fn cf(&self, cf_name: &str) -> &ColumnFamily { + self.rocksdb + .cf_handle(cf_name) + .unwrap_or_else(|| panic!("cf not found: {}", cf_name)) + } +} + impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { fn store_block( &self, @@ -288,7 +218,7 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { // It consists of one record per block - Header record // It also includes single record to store metadata - Register record { - let cf = self.ledger_cf; + let cf = self.cf(CF_LEDGER_HEADER); let mut buf = vec![]; LightBlock { @@ -311,7 +241,7 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { // COLUMN FAMILY: CF_LEDGER_TXS { - let cf = self.ledger_txs_cf; + let cf = self.cf(CF_LEDGER_TXS); // store all block transactions for tx in txs { @@ -323,7 +253,7 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { // COLUMN FAMILY: CF_LEDGER_FAULTS { - let cf = self.ledger_faults_cf; + let cf = self.cf(CF_LEDGER_FAULTS); // store all block faults for f in faults { @@ -375,39 +305,47 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { buf.write_all(hash)?; label.write(&mut buf)?; - self.put_cf(self.ledger_height_cf, height.to_le_bytes(), buf)?; + let ledger_height_cf = self.cf(CF_LEDGER_HEIGHT); + + self.put_cf(ledger_height_cf, height.to_le_bytes(), buf)?; Ok(()) } fn delete_block(&self, b: &ledger::Block) -> Result<()> { - self.inner.delete_cf( - self.ledger_height_cf, - b.header().height.to_le_bytes(), - )?; + let ledger_height_cf = self.cf(CF_LEDGER_HEIGHT); + self.inner + .delete_cf(ledger_height_cf, b.header().height.to_le_bytes())?; + + let ledger_txs_cf = self.cf(CF_LEDGER_TXS); for tx in b.txs() { - self.inner.delete_cf(self.ledger_txs_cf, tx.id())?; + self.inner.delete_cf(ledger_txs_cf, tx.id())?; } + + let ledger_faults_cf: &ColumnFamily = self.cf(CF_LEDGER_FAULTS); for f in b.faults() { - self.inner.delete_cf(self.ledger_faults_cf, f.hash())?; + self.inner.delete_cf(ledger_faults_cf, f.hash())?; } - self.inner.delete_cf(self.ledger_cf, b.header().hash)?; + let ledger_cf = self.cf(CF_LEDGER_HEADER); + self.inner.delete_cf(ledger_cf, b.header().hash)?; Ok(()) } fn get_block_exists(&self, hash: &[u8]) -> Result { - Ok(self.snapshot.get_cf(self.ledger_cf, hash)?.is_some()) + let ledger_cf = self.cf(CF_LEDGER_HEADER); + Ok(self.snapshot.get_cf(ledger_cf, hash)?.is_some()) } fn fetch_faults(&self, faults_ids: &[[u8; 32]]) -> Result> { if faults_ids.is_empty() { return Ok(vec![]); } + let ledger_faults_cf = self.cf(CF_LEDGER_FAULTS); let ids = faults_ids .iter() - .map(|id| (self.ledger_faults_cf, id)) + .map(|id| (ledger_faults_cf, id)) .collect::>(); // Retrieve all faults ID with single call @@ -424,7 +362,11 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { } fn fetch_block(&self, hash: &[u8]) -> Result> { - match self.snapshot.get_cf(self.ledger_cf, hash)? { + let ledger_cf = self.cf(CF_LEDGER_HEADER); + let ledger_txs_cf = self.cf(CF_LEDGER_TXS); + let ledger_faults_cf = self.cf(CF_LEDGER_FAULTS); + + match self.snapshot.get_cf(ledger_cf, hash)? { Some(blob) => { let record = LightBlock::read(&mut &blob[..])?; @@ -433,7 +375,7 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { record .transactions_ids .iter() - .map(|id| (self.ledger_txs_cf, id)) + .map(|id| (ledger_txs_cf, id)) .collect::>(), ); @@ -450,7 +392,7 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { record .faults_ids .iter() - .map(|id| (self.ledger_faults_cf, id)) + .map(|id| (ledger_faults_cf, id)) .collect::>(), ); let mut faults = vec![]; @@ -470,7 +412,9 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { } fn fetch_light_block(&self, hash: &[u8]) -> Result> { - match self.snapshot.get_cf(self.ledger_cf, hash)? { + let ledger_cf = self.cf(CF_LEDGER_HEADER); + + match self.snapshot.get_cf(ledger_cf, hash)? { Some(blob) => { let record = LightBlock::read(&mut &blob[..])?; Ok(Some(record)) @@ -480,7 +424,9 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { } fn fetch_block_header(&self, hash: &[u8]) -> Result> { - match self.snapshot.get_cf(self.ledger_cf, hash)? { + let ledger_cf = self.cf(CF_LEDGER_HEADER); + + match self.snapshot.get_cf(ledger_cf, hash)? { Some(blob) => { let record = Header::read(&mut &blob[..])?; Ok(Some(record)) @@ -493,9 +439,11 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { &self, height: u64, ) -> Result> { + let ledger_height_cf = self.cf(CF_LEDGER_HEIGHT); + Ok(self .snapshot - .get_cf(self.ledger_height_cf, height.to_le_bytes())? + .get_cf(ledger_height_cf, height.to_le_bytes())? .map(|h| { const LEN: usize = 32; let mut hash = [0u8; LEN]; @@ -508,9 +456,10 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { &self, tx_id: &[u8], ) -> Result> { + let ledger_txs_cf = self.cf(CF_LEDGER_TXS); let tx = self .snapshot - .get_cf(self.ledger_txs_cf, tx_id)? + .get_cf(ledger_txs_cf, tx_id)? .map(|blob| ledger::SpentTransaction::read(&mut &blob[..])) .transpose()?; @@ -523,7 +472,8 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { /// This is a convenience method that checks if a transaction exists in the /// ledger without unmarshalling the transaction fn get_ledger_tx_exists(&self, tx_id: &[u8]) -> Result { - Ok(self.snapshot.get_cf(self.ledger_txs_cf, tx_id)?.is_some()) + let ledger_txs_cf = self.cf(CF_LEDGER_TXS); + Ok(self.snapshot.get_cf(ledger_txs_cf, tx_id)?.is_some()) } fn fetch_block_by_height( @@ -542,10 +492,11 @@ impl<'db, DB: DBAccess> Ledger for DBTransaction<'db, DB> { &self, height: u64, ) -> Result> { + let ledger_height_cf = self.cf(CF_LEDGER_HEIGHT); const HASH_LEN: usize = 32; Ok(self .snapshot - .get_cf(self.ledger_height_cf, height.to_le_bytes())? + .get_cf(ledger_height_cf, height.to_le_bytes())? .map(|h| { let mut hash = [0u8; HASH_LEN]; hash.copy_from_slice(&h.as_slice()[0..HASH_LEN]); @@ -570,15 +521,18 @@ impl<'db, DB: DBAccess> Candidate for DBTransaction<'db, DB> { /// Returns `Ok(())` if the block is successfully stored, or an error if the /// operation fails. fn store_candidate_block(&self, b: ledger::Block) -> Result<()> { + let candidates_cf = self.cf(CF_CANDIDATES); + let candidates_height_cf = self.cf(CF_CANDIDATES_HEIGHT); + let mut serialized = vec![]; b.write(&mut serialized)?; self.inner - .put_cf(self.candidates_cf, b.header().hash, serialized)?; + .put_cf(candidates_cf, b.header().hash, serialized)?; let key = serialize_key(b.header().height, b.header().hash)?; self.inner - .put_cf(self.candidates_height_cf, key, b.header().hash)?; + .put_cf(candidates_height_cf, key, b.header().hash)?; Ok(()) } @@ -597,7 +551,9 @@ impl<'db, DB: DBAccess> Candidate for DBTransaction<'db, DB> { &self, hash: &[u8], ) -> Result> { - if let Some(blob) = self.snapshot.get_cf(self.candidates_cf, hash)? { + let candidates_cf = self.cf(CF_CANDIDATES); + + if let Some(blob) = self.snapshot.get_cf(candidates_cf, hash)? { let b = ledger::Block::read(&mut &blob[..])?; return Ok(Some(b)); } @@ -620,15 +576,18 @@ impl<'db, DB: DBAccess> Candidate for DBTransaction<'db, DB> { where F: FnOnce(u64) -> bool + std::marker::Copy, { + let candidates_cf = self.cf(CF_CANDIDATES); + let candidates_height_cf = self.cf(CF_CANDIDATES_HEIGHT); + let iter = self .inner - .iterator_cf(self.candidates_height_cf, IteratorMode::Start); + .iterator_cf(candidates_height_cf, IteratorMode::Start); for (key, hash) in iter.map(Result::unwrap) { let (height, _) = deserialize_key(&mut &key.to_vec()[..])?; if closure(height) { - self.inner.delete_cf(self.candidates_cf, hash)?; - self.inner.delete_cf(self.candidates_height_cf, key)?; + self.inner.delete_cf(candidates_cf, hash)?; + self.inner.delete_cf(candidates_height_cf, key)?; } } @@ -636,9 +595,10 @@ impl<'db, DB: DBAccess> Candidate for DBTransaction<'db, DB> { } fn count(&self) -> usize { + let candidates_height_cf = self.cf(CF_CANDIDATES_HEIGHT); let iter = self .inner - .iterator_cf(self.candidates_height_cf, IteratorMode::Start); + .iterator_cf(candidates_height_cf, IteratorMode::Start); iter.count() } @@ -657,12 +617,14 @@ impl<'db, DB: DBAccess> Candidate for DBTransaction<'db, DB> { impl<'db, DB: DBAccess> Persist for DBTransaction<'db, DB> { /// Deletes all items from both CF_LEDGER and CF_CANDIDATES column families fn clear_database(&self) -> Result<()> { + let ledger_cf = self.cf(CF_LEDGER_HEADER); + // Create an iterator over the column family CF_LEDGER - let iter = self.inner.iterator_cf(self.ledger_cf, IteratorMode::Start); + let iter = self.inner.iterator_cf(ledger_cf, IteratorMode::Start); // Iterate through the CF_LEDGER column family and delete all items for (key, _) in iter.map(Result::unwrap) { - self.inner.delete_cf(self.ledger_cf, key)?; + self.inner.delete_cf(ledger_cf, key)?; } self.clear_candidates()?; @@ -680,32 +642,33 @@ impl<'db, DB: DBAccess> Persist for DBTransaction<'db, DB> { impl<'db, DB: DBAccess> Mempool for DBTransaction<'db, DB> { fn add_tx(&self, tx: &ledger::Transaction) -> Result<()> { + let mempool_cf = self.cf(CF_MEMPOOL); + let nullifiers_cf = self.cf(CF_MEMPOOL_NULLIFIERS); + let fees_cf = self.cf(CF_MEMPOOL_FEES); + // Map Hash to serialized transaction let mut tx_data = vec![]; tx.write(&mut tx_data)?; let hash = tx.id(); - self.put_cf(self.mempool_cf, hash, tx_data)?; + self.put_cf(mempool_cf, hash, tx_data)?; // Add Secondary indexes // // Nullifiers for n in tx.inner.nullifiers() { let key = n.to_bytes(); - self.put_cf(self.nullifiers_cf, key, hash)?; + self.put_cf(nullifiers_cf, key, hash)?; } // Map Fee_Hash to Null to facilitate sort-by-fee - self.put_cf( - self.fees_cf, - serialize_key(tx.gas_price(), hash)?, - vec![0], - )?; + self.put_cf(fees_cf, serialize_key(tx.gas_price(), hash)?, vec![0])?; Ok(()) } fn get_tx(&self, hash: [u8; 32]) -> Result> { - let data = self.inner.get_cf(self.mempool_cf, hash)?; + let mempool_cf = self.cf(CF_MEMPOOL); + let data = self.inner.get_cf(mempool_cf, hash)?; match data { // None has a meaning key not found @@ -717,28 +680,31 @@ impl<'db, DB: DBAccess> Mempool for DBTransaction<'db, DB> { } fn get_tx_exists(&self, h: [u8; 32]) -> Result { - Ok(self.snapshot.get_cf(self.mempool_cf, h)?.is_some()) + let mempool_cf = self.cf(CF_MEMPOOL); + Ok(self.snapshot.get_cf(mempool_cf, h)?.is_some()) } fn delete_tx(&self, h: [u8; 32]) -> Result { + let mempool_cf = self.cf(CF_MEMPOOL); + let nullifiers_cf = self.cf(CF_MEMPOOL_NULLIFIERS); + let fees_cf = self.cf(CF_MEMPOOL_FEES); + let tx = self.get_tx(h)?; if let Some(tx) = tx { let hash = tx.id(); - self.inner.delete_cf(self.mempool_cf, hash)?; + self.inner.delete_cf(mempool_cf, hash)?; // Delete Secondary indexes // Delete Nullifiers for n in tx.inner.nullifiers() { let key = n.to_bytes(); - self.inner.delete_cf(self.nullifiers_cf, key)?; + self.inner.delete_cf(nullifiers_cf, key)?; } // Delete Fee_Hash - self.inner.delete_cf( - self.fees_cf, - serialize_key(tx.gas_price(), hash)?, - )?; + self.inner + .delete_cf(fees_cf, serialize_key(tx.gas_price(), hash)?)?; return Ok(true); } @@ -747,8 +713,9 @@ impl<'db, DB: DBAccess> Mempool for DBTransaction<'db, DB> { } fn get_txs_by_nullifiers(&self, n: &[[u8; 32]]) -> HashSet<[u8; 32]> { + let nullifiers_cf = self.cf(CF_MEMPOOL_NULLIFIERS); n.iter() - .filter_map(|n| match self.snapshot.get_cf(self.nullifiers_cf, n) { + .filter_map(|n| match self.snapshot.get_cf(nullifiers_cf, n) { Ok(Some(tx_id)) => tx_id.try_into().ok(), _ => None, }) @@ -758,7 +725,8 @@ impl<'db, DB: DBAccess> Mempool for DBTransaction<'db, DB> { fn get_txs_sorted_by_fee( &self, ) -> Result + '_>> { - let iter = MemPoolIterator::new(&self.inner, self.fees_cf, self); + let fees_cf = self.cf(CF_MEMPOOL_FEES); + let iter = MemPoolIterator::new(&self.inner, fees_cf, self); Ok(Box::new(iter)) } @@ -766,13 +734,15 @@ impl<'db, DB: DBAccess> Mempool for DBTransaction<'db, DB> { fn get_txs_ids_sorted_by_fee( &self, ) -> Result + '_>> { - let iter = MemPoolFeeIterator::new(&self.inner, self.fees_cf); + let fees_cf = self.cf(CF_MEMPOOL_FEES); + let iter = MemPoolFeeIterator::new(&self.inner, fees_cf); Ok(Box::new(iter)) } fn get_txs_ids(&self) -> Result> { - let mut iter = self.inner.raw_iterator_cf(self.fees_cf); + let fees_cf = self.cf(CF_MEMPOOL_FEES); + let mut iter = self.inner.raw_iterator_cf(fees_cf); iter.seek_to_last(); let mut txs_list = vec![]; @@ -792,8 +762,10 @@ impl<'db, DB: DBAccess> Mempool for DBTransaction<'db, DB> { } fn txs_count(&self) -> usize { + let mempool_cf = self.cf(CF_MEMPOOL); + self.inner - .iterator_cf(self.mempool_cf, IteratorMode::Start) + .iterator_cf(mempool_cf, IteratorMode::Start) .count() } } @@ -856,13 +828,12 @@ impl Iterator for MemPoolFeeIterator<'_, DB> { impl<'db, DB: DBAccess> std::fmt::Debug for DBTransaction<'db, DB> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let ledger_cf = self.cf(CF_LEDGER_HEADER); // Print ledger blocks - let iter = self.inner.iterator_cf(self.ledger_cf, IteratorMode::Start); + let iter = self.inner.iterator_cf(ledger_cf, IteratorMode::Start); iter.map(Result::unwrap).try_for_each(|(hash, _)| { - if let Ok(Some(blob)) = - self.snapshot.get_cf(self.ledger_cf, &hash[..]) - { + if let Ok(Some(blob)) = self.snapshot.get_cf(ledger_cf, &hash[..]) { let b = ledger::Block::read(&mut &blob[..]).unwrap_or_default(); writeln!(f, "ledger_block [{}]: {:#?}", b.header().height, b) } else { @@ -871,14 +842,13 @@ impl<'db, DB: DBAccess> std::fmt::Debug for DBTransaction<'db, DB> { })?; // Print candidate blocks - let iter = self - .inner - .iterator_cf(self.candidates_cf, IteratorMode::Start); + let candidates_cf = self.cf(CF_CANDIDATES); + let iter = self.inner.iterator_cf(candidates_cf, IteratorMode::Start); let results: std::fmt::Result = iter.map(Result::unwrap).try_for_each(|(hash, _)| { if let Ok(Some(blob)) = - self.snapshot.get_cf(self.candidates_cf, &hash[..]) + self.snapshot.get_cf(candidates_cf, &hash[..]) { let b = ledger::Block::read(&mut &blob[..]).unwrap_or_default(); @@ -899,12 +869,16 @@ impl<'db, DB: DBAccess> std::fmt::Debug for DBTransaction<'db, DB> { impl<'db, DB: DBAccess> Metadata for DBTransaction<'db, DB> { fn op_write>(&self, key: &[u8], value: T) -> Result<()> { - self.put_cf(self.metadata_cf, key, value)?; + let metadata_cf = self.cf(CF_METADATA); + + self.put_cf(metadata_cf, key, value)?; Ok(()) } fn op_read(&self, key: &[u8]) -> Result>> { - self.inner.get_cf(self.metadata_cf, key).map_err(Into::into) + let metadata_cf = self.cf(CF_METADATA); + + self.inner.get_cf(metadata_cf, key).map_err(Into::into) } }