From 1e1378596acb31b13d7a0de1e46b26d88a4d2f33 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Wed, 13 Mar 2024 20:21:33 +0000 Subject: [PATCH] zcash_client_backend: Add `WalletRead::get_account` --- zcash_client_backend/CHANGELOG.md | 1 + zcash_client_backend/src/data_api.rs | 13 +++++++++++++ zcash_client_sqlite/src/lib.rs | 9 ++++++++- zcash_client_sqlite/src/wallet.rs | 16 ++++++++-------- zcash_client_sqlite/src/wallet/init.rs | 6 ++---- 5 files changed, 32 insertions(+), 13 deletions(-) diff --git a/zcash_client_backend/CHANGELOG.md b/zcash_client_backend/CHANGELOG.md index 25c58a9296..be1ba5e8e2 100644 --- a/zcash_client_backend/CHANGELOG.md +++ b/zcash_client_backend/CHANGELOG.md @@ -54,6 +54,7 @@ and this library adheres to Rust's notion of - Arguments to `ScannedBlock::from_parts` have changed. - Changes to the `WalletRead` trait: - Added `Account` associated type. + - Added `get_account` method. - Added `get_derived_account` method. - `get_account_for_ufvk` now returns `Self::Account` instead of a bare `AccountId`. diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index d20ffe131d..1c0ff1745b 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -574,6 +574,12 @@ pub trait WalletRead { /// Returns a vector with the IDs of all accounts known to this wallet. fn get_account_ids(&self) -> Result, Self::Error>; + /// Returns the account corresponding to the given ID, if any. + fn get_account( + &self, + account_id: Self::AccountId, + ) -> Result, Self::Error>; + /// Returns the account corresponding to a given [`HdSeedFingerprint`] and /// [`zip32::AccountId`], if any. fn get_derived_account( @@ -1553,6 +1559,13 @@ pub mod testing { Ok(Vec::new()) } + fn get_account( + &self, + _account_id: Self::AccountId, + ) -> Result, Self::Error> { + Ok(None) + } + fn get_derived_account( &self, _seed: &HdSeedFingerprint, diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index 2286291b58..f4137d8be5 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -293,6 +293,13 @@ impl, P: consensus::Parameters> WalletRead for W wallet::get_account_ids(self.conn.borrow()) } + fn get_account( + &self, + account_id: Self::AccountId, + ) -> Result, Self::Error> { + wallet::get_account(self.conn.borrow(), &self.params, account_id) + } + fn get_derived_account( &self, seed: &HdSeedFingerprint, @@ -306,7 +313,7 @@ impl, P: consensus::Parameters> WalletRead for W account_id: Self::AccountId, seed: &SecretVec, ) -> Result { - if let Some(account) = wallet::get_account(self, account_id)? { + if let Some(account) = self.get_account(account_id)? { if let AccountKind::Derived { seed_fingerprint, account_index, diff --git a/zcash_client_sqlite/src/wallet.rs b/zcash_client_sqlite/src/wallet.rs index c7e569b110..c0850c414d 100644 --- a/zcash_client_sqlite/src/wallet.rs +++ b/zcash_client_sqlite/src/wallet.rs @@ -67,7 +67,7 @@ use incrementalmerkletree::Retention; use rusqlite::{self, named_params, params, OptionalExtension}; use shardtree::{error::ShardTreeError, store::ShardStore, ShardTree}; -use std::borrow::Borrow; + use std::collections::HashMap; use std::convert::TryFrom; use std::io::{self, Cursor}; @@ -1494,11 +1494,12 @@ pub(crate) fn block_height_extrema( }) } -pub(crate) fn get_account, P: Parameters>( - db: &WalletDb, +pub(crate) fn get_account( + conn: &rusqlite::Connection, + params: &P, account_id: AccountId, ) -> Result, SqliteClientError> { - let mut sql = db.conn.borrow().prepare_cached( + let mut sql = conn.prepare_cached( r#" SELECT account_type, hd_seed_fingerprint, hd_account_index, ufvk, uivk FROM accounts @@ -1519,7 +1520,7 @@ pub(crate) fn get_account, P: Parameters>( let ufvk_str: Option = row.get("ufvk")?; let viewing_key = if let Some(ufvk_str) = ufvk_str { ViewingKey::Full(Box::new( - UnifiedFullViewingKey::decode(&db.params, &ufvk_str[..]) + UnifiedFullViewingKey::decode(params, &ufvk_str[..]) .map_err(SqliteClientError::BadAccountData)?, )) } else { @@ -1527,7 +1528,7 @@ pub(crate) fn get_account, P: Parameters>( let (network, uivk) = Uivk::decode(&uivk_str).map_err(|e| { SqliteClientError::CorruptedData(format!("Failure to decode UIVK: {e}")) })?; - if network != db.params.network_type() { + if network != params.network_type() { return Err(SqliteClientError::CorruptedData( "UIVK network type does not match wallet network type".to_string(), )); @@ -2704,7 +2705,6 @@ mod tests { use crate::{ testing::{AddressType, BlockCache, TestBuilder, TestState}, - wallet::get_account, AccountId, }; @@ -2852,7 +2852,7 @@ mod tests { .with_test_account(AccountBirthday::from_sapling_activation) .build(); let account_id = st.test_account().unwrap().0; - let account_parameters = get_account(st.wallet(), account_id).unwrap().unwrap(); + let account_parameters = st.wallet().get_account(account_id).unwrap().unwrap(); let expected_account_index = zip32::AccountId::try_from(0).unwrap(); assert_matches!( diff --git a/zcash_client_sqlite/src/wallet/init.rs b/zcash_client_sqlite/src/wallet/init.rs index 3f93447c8f..e74259b562 100644 --- a/zcash_client_sqlite/src/wallet/init.rs +++ b/zcash_client_sqlite/src/wallet/init.rs @@ -1282,9 +1282,7 @@ mod tests { #[test] #[cfg(feature = "transparent-inputs")] fn account_produces_expected_ua_sequence() { - use zcash_client_backend::data_api::{AccountBirthday, AccountKind}; - - use crate::wallet::get_account; + use zcash_client_backend::data_api::{AccountBirthday, AccountKind, WalletRead}; let network = Network::MainNetwork; let data_file = NamedTempFile::new().unwrap(); @@ -1300,7 +1298,7 @@ mod tests { .create_account(&Secret::new(seed.to_vec()), birthday) .unwrap(); assert_matches!( - get_account(&db_data, account_id), + db_data.get_account(account_id), Ok(Some(account)) if matches!( account.kind, AccountKind::Derived{account_index, ..} if account_index == zip32::AccountId::ZERO,