diff --git a/zcash_client_backend/src/data_api/wallet/input_selection.rs b/zcash_client_backend/src/data_api/wallet/input_selection.rs
index 104c78a6c1..d129c5d95d 100644
--- a/zcash_client_backend/src/data_api/wallet/input_selection.rs
+++ b/zcash_client_backend/src/data_api/wallet/input_selection.rs
@@ -394,6 +394,58 @@ where
// of funds selected is strictly increasing. The loop will either return a successful
// result or the wallet will eventually run out of funds to select.
loop {
+ let sapling_input_total = shielded_inputs
+ .iter()
+ .filter(|i| matches!(i.note(), Note::Sapling(_)))
+ .map(|i| i.note().value())
+ .sum::>()
+ .ok_or(BalanceError::Overflow)?;
+
+ #[cfg(feature = "orchard")]
+ let orchard_input_total = shielded_inputs
+ .iter()
+ .filter(|i| matches!(i.note(), Note::Orchard(_)))
+ .map(|i| i.note().value())
+ .sum:: >()
+ .ok_or(BalanceError::Overflow)?;
+ #[cfg(not(feature = "orchard"))]
+ let orchard_input_total = NonNegativeAmount::ZERO;
+
+ let sapling_inputs =
+ if sapling_outputs.is_empty() && orchard_input_total >= amount_required {
+ // Avoid selecting Sapling inputs if we don't have Sapling outputs and the value is
+ // fully covered by Orchard inputs.
+ #[cfg(feature = "orchard")]
+ shielded_inputs.retain(|i| matches!(i.note(), Note::Orchard(_)));
+ vec![]
+ } else {
+ shielded_inputs
+ .iter()
+ .filter_map(|i| match i.note() {
+ Note::Sapling(n) => Some((*i.internal_note_id(), n.value())),
+ #[cfg(feature = "orchard")]
+ Note::Orchard(_) => None,
+ })
+ .collect::>()
+ };
+
+ #[cfg(feature = "orchard")]
+ let orchard_inputs =
+ if orchard_outputs.is_empty() && sapling_input_total >= amount_required {
+ // Avoid selecting Orchard inputs if we don't have Orchard outputs and the value is
+ // fully covered by Sapling inputs.
+ shielded_inputs.retain(|i| matches!(i.note(), Note::Sapling(_)));
+ vec![]
+ } else {
+ shielded_inputs
+ .iter()
+ .filter_map(|i| match i.note() {
+ Note::Sapling(_) => None,
+ Note::Orchard(n) => Some((*i.internal_note_id(), n.value())),
+ })
+ .collect::>()
+ };
+
let balance = self.change_strategy.compute_balance(
params,
target_height,
@@ -401,27 +453,13 @@ where
&transparent_outputs,
&(
::sapling::builder::BundleType::DEFAULT,
- &shielded_inputs
- .iter()
- .cloned()
- .filter_map(|i| match i.note() {
- Note::Sapling(n) => Some((*i.internal_note_id(), n.value())),
- #[cfg(feature = "orchard")]
- Note::Orchard(_) => None,
- })
- .collect::>()[..],
+ &sapling_inputs[..],
&sapling_outputs[..],
),
#[cfg(feature = "orchard")]
&(
::orchard::builder::BundleType::DEFAULT,
- &shielded_inputs
- .iter()
- .filter_map(|i| match i.note() {
- Note::Sapling(_) => None,
- Note::Orchard(n) => Some((*i.internal_note_id(), n.value())),
- })
- .collect::>()[..],
+ &orchard_inputs[..],
&orchard_outputs[..],
),
&self.dust_output_policy,
diff --git a/zcash_client_backend/src/wallet.rs b/zcash_client_backend/src/wallet.rs
index 5e65ddd0dc..f8d8f4ffb6 100644
--- a/zcash_client_backend/src/wallet.rs
+++ b/zcash_client_backend/src/wallet.rs
@@ -432,7 +432,7 @@ impl ReceivedNote {
}
}
-impl<'a, NoteRef> sapling_fees::InputView for (NoteRef, sapling::value::NoteValue) {
+impl sapling_fees::InputView for (NoteRef, sapling::value::NoteValue) {
fn note_id(&self) -> &NoteRef {
&self.0
}
@@ -460,7 +460,7 @@ impl sapling_fees::InputView for ReceivedNote orchard_fees::InputView for (NoteRef, orchard::value::NoteValue) {
+impl orchard_fees::InputView for (NoteRef, orchard::value::NoteValue) {
fn note_id(&self) -> &NoteRef {
&self.0
}
diff --git a/zcash_client_sqlite/src/testing/pool.rs b/zcash_client_sqlite/src/testing/pool.rs
index 637d4d5ba5..7a5aa38101 100644
--- a/zcash_client_sqlite/src/testing/pool.rs
+++ b/zcash_client_sqlite/src/testing/pool.rs
@@ -1406,7 +1406,7 @@ pub(crate) fn checkpoint_gaps() {
}
#[cfg(feature = "orchard")]
-pub(crate) fn cross_pool_exchange() {
+pub(crate) fn pool_crossing_required() {
let mut st = TestBuilder::new()
.with_block_cache()
.with_test_account(|params| AccountBirthday::from_activation(params, NetworkUpgrade::Nu5))
@@ -1419,12 +1419,11 @@ pub(crate) fn cross_pool_exchange() {
+ let mut st = TestBuilder::new()
+ .with_block_cache()
+ .with_test_account(|params| AccountBirthday::from_activation(params, NetworkUpgrade::Nu5))
+ .build();
+
+ let (account, usk, birthday) = st.test_account().unwrap();
+
+ let p0_fvk = P0::test_account_fvk(&st);
+
+ let p1_fvk = P1::test_account_fvk(&st);
+ let p1_to = P1::fvk_default_address(&p1_fvk);
+
+ let note_value = NonNegativeAmount::const_from_u64(350000);
+ st.generate_next_block(&p0_fvk, AddressType::DefaultExternal, note_value);
+ st.generate_next_block(&p1_fvk, AddressType::DefaultExternal, note_value);
+ st.scan_cached_blocks(birthday.height(), 2);
+
+ let initial_balance = (note_value * 2).unwrap();
+ assert_eq!(st.get_total_balance(account), initial_balance);
+ assert_eq!(st.get_spendable_balance(account, 1), initial_balance);
+
+ let transfer_amount = NonNegativeAmount::const_from_u64(200000);
+ let p0_to_p1 = zip321::TransactionRequest::new(vec![Payment {
+ recipient_address: p1_to,
+ amount: transfer_amount,
+ memo: None,
+ label: None,
+ message: None,
+ other_params: vec![],
+ }])
+ .unwrap();
+
+ let fee_rule = StandardFeeRule::Zip317;
+ let input_selector = GreedyInputSelector::new(
+ // We set the default change output pool to P0, because we want to verify later that
+ // change is actually sent to P1 (as the transaction is fully fundable from P1).
+ standard::SingleOutputChangeStrategy::new(fee_rule, None, P0::SHIELDED_PROTOCOL),
+ DustOutputPolicy::default(),
+ );
+ let proposal0 = st
+ .propose_transfer(
+ account,
+ &input_selector,
+ p0_to_p1,
+ NonZeroU32::new(1).unwrap(),
+ )
+ .unwrap();
+
+ let _min_target_height = proposal0.min_target_height();
+ assert_eq!(proposal0.steps().len(), 1);
+ let step0 = &proposal0.steps().head;
+
+ // We expect 2 logical actions, since either pool can pay the full balance required
+ // and note selection should choose the fully-private path.
+ let expected_fee = NonNegativeAmount::const_from_u64(10000);
+ assert_eq!(step0.balance().fee_required(), expected_fee);
+
+ let expected_change = (note_value - transfer_amount - expected_fee).unwrap();
+ let proposed_change = step0.balance().proposed_change();
+ assert_eq!(proposed_change.len(), 1);
+ let change_output = proposed_change.get(0).unwrap();
+ // Since this is a cross-pool transfer, change will be sent to the preferred pool.
+ assert_eq!(change_output.output_pool(), P1::SHIELDED_PROTOCOL);
+ assert_eq!(change_output.value(), expected_change);
+
+ let create_proposed_result =
+ st.create_proposed_transactions::(&usk, OvkPolicy::Sender, &proposal0);
+ assert_matches!(&create_proposed_result, Ok(txids) if txids.len() == 1);
+
+ let (h, _) = st.generate_next_block_including(create_proposed_result.unwrap()[0]);
+ st.scan_cached_blocks(h, 1);
+
+ assert_eq!(
+ st.get_total_balance(account),
+ (initial_balance - expected_fee).unwrap()
+ );
+ assert_eq!(
+ st.get_spendable_balance(account, 1),
+ (initial_balance - expected_fee).unwrap()
+ );
+}
+
pub(crate) fn valid_chain_states() {
let mut st = TestBuilder::new()
.with_block_cache()
diff --git a/zcash_client_sqlite/src/wallet/orchard.rs b/zcash_client_sqlite/src/wallet/orchard.rs
index 3b3859077c..7143869a28 100644
--- a/zcash_client_sqlite/src/wallet/orchard.rs
+++ b/zcash_client_sqlite/src/wallet/orchard.rs
@@ -604,9 +604,16 @@ pub(crate) mod tests {
}
#[test]
- fn cross_pool_exchange() {
+ fn pool_crossing_required() {
use crate::wallet::sapling::tests::SaplingPoolTester;
- testing::pool::cross_pool_exchange::()
+ testing::pool::pool_crossing_required::()
+ }
+
+ #[test]
+ fn fully_funded_fully_private() {
+ use crate::wallet::sapling::tests::SaplingPoolTester;
+
+ testing::pool::fully_funded_fully_private::()
}
}
diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs
index 38f076619f..38072f9132 100644
--- a/zcash_client_sqlite/src/wallet/sapling.rs
+++ b/zcash_client_sqlite/src/wallet/sapling.rs
@@ -607,9 +607,17 @@ pub(crate) mod tests {
#[test]
#[cfg(feature = "orchard")]
- fn cross_pool_exchange() {
+ fn pool_crossing_required() {
use crate::wallet::orchard::tests::OrchardPoolTester;
- testing::pool::cross_pool_exchange::()
+ testing::pool::pool_crossing_required::()
+ }
+
+ #[test]
+ #[cfg(feature = "orchard")]
+ fn fully_funded_fully_private() {
+ use crate::wallet::orchard::tests::OrchardPoolTester;
+
+ testing::pool::fully_funded_fully_private::()
}
}