Skip to content

Commit

Permalink
zcash_client_backend: Add WalletRead::get_account
Browse files Browse the repository at this point in the history
  • Loading branch information
str4d committed Mar 13, 2024
1 parent 6182071 commit 1e13785
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 13 deletions.
1 change: 1 addition & 0 deletions zcash_client_backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ and this library adheres to Rust's notion of
- Arguments to `ScannedBlock::from_parts` have changed.
- Changes to the `WalletRead` trait:
- Added `Account` associated type.
- Added `get_account` method.
- Added `get_derived_account` method.
- `get_account_for_ufvk` now returns `Self::Account` instead of a bare
`AccountId`.
Expand Down
13 changes: 13 additions & 0 deletions zcash_client_backend/src/data_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,12 @@ pub trait WalletRead {
/// Returns a vector with the IDs of all accounts known to this wallet.
fn get_account_ids(&self) -> Result<Vec<Self::AccountId>, Self::Error>;

/// Returns the account corresponding to the given ID, if any.
fn get_account(
&self,
account_id: Self::AccountId,
) -> Result<Option<Self::Account>, Self::Error>;

/// Returns the account corresponding to a given [`HdSeedFingerprint`] and
/// [`zip32::AccountId`], if any.
fn get_derived_account(
Expand Down Expand Up @@ -1553,6 +1559,13 @@ pub mod testing {
Ok(Vec::new())
}

fn get_account(

Check warning on line 1562 in zcash_client_backend/src/data_api.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_backend/src/data_api.rs#L1562

Added line #L1562 was not covered by tests
&self,
_account_id: Self::AccountId,
) -> Result<Option<Self::Account>, Self::Error> {
Ok(None)

Check warning on line 1566 in zcash_client_backend/src/data_api.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_backend/src/data_api.rs#L1566

Added line #L1566 was not covered by tests
}

fn get_derived_account(

Check warning on line 1569 in zcash_client_backend/src/data_api.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_backend/src/data_api.rs#L1569

Added line #L1569 was not covered by tests
&self,
_seed: &HdSeedFingerprint,
Expand Down
9 changes: 8 additions & 1 deletion zcash_client_sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,13 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
wallet::get_account_ids(self.conn.borrow())
}

fn get_account(
&self,
account_id: Self::AccountId,
) -> Result<Option<Self::Account>, Self::Error> {
wallet::get_account(self.conn.borrow(), &self.params, account_id)
}

fn get_derived_account(

Check warning on line 303 in zcash_client_sqlite/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/lib.rs#L303

Added line #L303 was not covered by tests
&self,
seed: &HdSeedFingerprint,
Expand All @@ -306,7 +313,7 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
account_id: Self::AccountId,
seed: &SecretVec<u8>,
) -> Result<bool, Self::Error> {
if let Some(account) = wallet::get_account(self, account_id)? {
if let Some(account) = self.get_account(account_id)? {
if let AccountKind::Derived {

Check warning on line 317 in zcash_client_sqlite/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/lib.rs#L317

Added line #L317 was not covered by tests
seed_fingerprint,
account_index,
Expand Down
16 changes: 8 additions & 8 deletions zcash_client_sqlite/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
use incrementalmerkletree::Retention;
use rusqlite::{self, named_params, params, OptionalExtension};
use shardtree::{error::ShardTreeError, store::ShardStore, ShardTree};
use std::borrow::Borrow;

use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::{self, Cursor};
Expand Down Expand Up @@ -1494,11 +1494,12 @@ pub(crate) fn block_height_extrema(
})
}

pub(crate) fn get_account<C: Borrow<rusqlite::Connection>, P: Parameters>(
db: &WalletDb<C, P>,
pub(crate) fn get_account<P: Parameters>(
conn: &rusqlite::Connection,
params: &P,
account_id: AccountId,
) -> Result<Option<Account>, SqliteClientError> {
let mut sql = db.conn.borrow().prepare_cached(
let mut sql = conn.prepare_cached(
r#"
SELECT account_type, hd_seed_fingerprint, hd_account_index, ufvk, uivk
FROM accounts
Expand All @@ -1519,15 +1520,15 @@ pub(crate) fn get_account<C: Borrow<rusqlite::Connection>, P: Parameters>(
let ufvk_str: Option<String> = row.get("ufvk")?;
let viewing_key = if let Some(ufvk_str) = ufvk_str {
ViewingKey::Full(Box::new(

Check warning on line 1522 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1522

Added line #L1522 was not covered by tests
UnifiedFullViewingKey::decode(&db.params, &ufvk_str[..])
UnifiedFullViewingKey::decode(params, &ufvk_str[..])
.map_err(SqliteClientError::BadAccountData)?,
))
} else {
let uivk_str: String = row.get("uivk")?;
let (network, uivk) = Uivk::decode(&uivk_str).map_err(|e| {
SqliteClientError::CorruptedData(format!("Failure to decode UIVK: {e}"))

Check warning on line 1529 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1527-L1529

Added lines #L1527 - L1529 were not covered by tests
})?;
if network != db.params.network_type() {
if network != params.network_type() {
return Err(SqliteClientError::CorruptedData(
"UIVK network type does not match wallet network type".to_string(),

Check warning on line 1533 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1531-L1533

Added lines #L1531 - L1533 were not covered by tests
));
Expand Down Expand Up @@ -2704,7 +2705,6 @@ mod tests {

use crate::{
testing::{AddressType, BlockCache, TestBuilder, TestState},
wallet::get_account,
AccountId,
};

Expand Down Expand Up @@ -2852,7 +2852,7 @@ mod tests {
.with_test_account(AccountBirthday::from_sapling_activation)
.build();
let account_id = st.test_account().unwrap().0;
let account_parameters = get_account(st.wallet(), account_id).unwrap().unwrap();
let account_parameters = st.wallet().get_account(account_id).unwrap().unwrap();

let expected_account_index = zip32::AccountId::try_from(0).unwrap();
assert_matches!(
Expand Down
6 changes: 2 additions & 4 deletions zcash_client_sqlite/src/wallet/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1282,9 +1282,7 @@ mod tests {
#[test]
#[cfg(feature = "transparent-inputs")]
fn account_produces_expected_ua_sequence() {
use zcash_client_backend::data_api::{AccountBirthday, AccountKind};

use crate::wallet::get_account;
use zcash_client_backend::data_api::{AccountBirthday, AccountKind, WalletRead};

let network = Network::MainNetwork;
let data_file = NamedTempFile::new().unwrap();
Expand All @@ -1300,7 +1298,7 @@ mod tests {
.create_account(&Secret::new(seed.to_vec()), birthday)
.unwrap();
assert_matches!(
get_account(&db_data, account_id),
db_data.get_account(account_id),
Ok(Some(account)) if matches!(
account.kind,
AccountKind::Derived{account_index, ..} if account_index == zip32::AccountId::ZERO,
Expand Down

0 comments on commit 1e13785

Please sign in to comment.