diff --git a/zcash_client_backend/src/data_api/wallet/input_selection.rs b/zcash_client_backend/src/data_api/wallet/input_selection.rs index 104c78a6c1..d129c5d95d 100644 --- a/zcash_client_backend/src/data_api/wallet/input_selection.rs +++ b/zcash_client_backend/src/data_api/wallet/input_selection.rs @@ -394,6 +394,58 @@ where // of funds selected is strictly increasing. The loop will either return a successful // result or the wallet will eventually run out of funds to select. loop { + let sapling_input_total = shielded_inputs + .iter() + .filter(|i| matches!(i.note(), Note::Sapling(_))) + .map(|i| i.note().value()) + .sum::>() + .ok_or(BalanceError::Overflow)?; + + #[cfg(feature = "orchard")] + let orchard_input_total = shielded_inputs + .iter() + .filter(|i| matches!(i.note(), Note::Orchard(_))) + .map(|i| i.note().value()) + .sum::>() + .ok_or(BalanceError::Overflow)?; + #[cfg(not(feature = "orchard"))] + let orchard_input_total = NonNegativeAmount::ZERO; + + let sapling_inputs = + if sapling_outputs.is_empty() && orchard_input_total >= amount_required { + // Avoid selecting Sapling inputs if we don't have Sapling outputs and the value is + // fully covered by Orchard inputs. + #[cfg(feature = "orchard")] + shielded_inputs.retain(|i| matches!(i.note(), Note::Orchard(_))); + vec![] + } else { + shielded_inputs + .iter() + .filter_map(|i| match i.note() { + Note::Sapling(n) => Some((*i.internal_note_id(), n.value())), + #[cfg(feature = "orchard")] + Note::Orchard(_) => None, + }) + .collect::>() + }; + + #[cfg(feature = "orchard")] + let orchard_inputs = + if orchard_outputs.is_empty() && sapling_input_total >= amount_required { + // Avoid selecting Orchard inputs if we don't have Orchard outputs and the value is + // fully covered by Sapling inputs. + shielded_inputs.retain(|i| matches!(i.note(), Note::Sapling(_))); + vec![] + } else { + shielded_inputs + .iter() + .filter_map(|i| match i.note() { + Note::Sapling(_) => None, + Note::Orchard(n) => Some((*i.internal_note_id(), n.value())), + }) + .collect::>() + }; + let balance = self.change_strategy.compute_balance( params, target_height, @@ -401,27 +453,13 @@ where &transparent_outputs, &( ::sapling::builder::BundleType::DEFAULT, - &shielded_inputs - .iter() - .cloned() - .filter_map(|i| match i.note() { - Note::Sapling(n) => Some((*i.internal_note_id(), n.value())), - #[cfg(feature = "orchard")] - Note::Orchard(_) => None, - }) - .collect::>()[..], + &sapling_inputs[..], &sapling_outputs[..], ), #[cfg(feature = "orchard")] &( ::orchard::builder::BundleType::DEFAULT, - &shielded_inputs - .iter() - .filter_map(|i| match i.note() { - Note::Sapling(_) => None, - Note::Orchard(n) => Some((*i.internal_note_id(), n.value())), - }) - .collect::>()[..], + &orchard_inputs[..], &orchard_outputs[..], ), &self.dust_output_policy, diff --git a/zcash_client_backend/src/wallet.rs b/zcash_client_backend/src/wallet.rs index 5e65ddd0dc..f8d8f4ffb6 100644 --- a/zcash_client_backend/src/wallet.rs +++ b/zcash_client_backend/src/wallet.rs @@ -432,7 +432,7 @@ impl ReceivedNote { } } -impl<'a, NoteRef> sapling_fees::InputView for (NoteRef, sapling::value::NoteValue) { +impl sapling_fees::InputView for (NoteRef, sapling::value::NoteValue) { fn note_id(&self) -> &NoteRef { &self.0 } @@ -460,7 +460,7 @@ impl sapling_fees::InputView for ReceivedNote orchard_fees::InputView for (NoteRef, orchard::value::NoteValue) { +impl orchard_fees::InputView for (NoteRef, orchard::value::NoteValue) { fn note_id(&self) -> &NoteRef { &self.0 } diff --git a/zcash_client_sqlite/src/testing/pool.rs b/zcash_client_sqlite/src/testing/pool.rs index 637d4d5ba5..7a5aa38101 100644 --- a/zcash_client_sqlite/src/testing/pool.rs +++ b/zcash_client_sqlite/src/testing/pool.rs @@ -1406,7 +1406,7 @@ pub(crate) fn checkpoint_gaps() { } #[cfg(feature = "orchard")] -pub(crate) fn cross_pool_exchange() { +pub(crate) fn pool_crossing_required() { let mut st = TestBuilder::new() .with_block_cache() .with_test_account(|params| AccountBirthday::from_activation(params, NetworkUpgrade::Nu5)) @@ -1419,12 +1419,11 @@ pub(crate) fn cross_pool_exchange() { + let mut st = TestBuilder::new() + .with_block_cache() + .with_test_account(|params| AccountBirthday::from_activation(params, NetworkUpgrade::Nu5)) + .build(); + + let (account, usk, birthday) = st.test_account().unwrap(); + + let p0_fvk = P0::test_account_fvk(&st); + + let p1_fvk = P1::test_account_fvk(&st); + let p1_to = P1::fvk_default_address(&p1_fvk); + + let note_value = NonNegativeAmount::const_from_u64(350000); + st.generate_next_block(&p0_fvk, AddressType::DefaultExternal, note_value); + st.generate_next_block(&p1_fvk, AddressType::DefaultExternal, note_value); + st.scan_cached_blocks(birthday.height(), 2); + + let initial_balance = (note_value * 2).unwrap(); + assert_eq!(st.get_total_balance(account), initial_balance); + assert_eq!(st.get_spendable_balance(account, 1), initial_balance); + + let transfer_amount = NonNegativeAmount::const_from_u64(200000); + let p0_to_p1 = zip321::TransactionRequest::new(vec![Payment { + recipient_address: p1_to, + amount: transfer_amount, + memo: None, + label: None, + message: None, + other_params: vec![], + }]) + .unwrap(); + + let fee_rule = StandardFeeRule::Zip317; + let input_selector = GreedyInputSelector::new( + // We set the default change output pool to P0, because we want to verify later that + // change is actually sent to P1 (as the transaction is fully fundable from P1). + standard::SingleOutputChangeStrategy::new(fee_rule, None, P0::SHIELDED_PROTOCOL), + DustOutputPolicy::default(), + ); + let proposal0 = st + .propose_transfer( + account, + &input_selector, + p0_to_p1, + NonZeroU32::new(1).unwrap(), + ) + .unwrap(); + + let _min_target_height = proposal0.min_target_height(); + assert_eq!(proposal0.steps().len(), 1); + let step0 = &proposal0.steps().head; + + // We expect 2 logical actions, since either pool can pay the full balance required + // and note selection should choose the fully-private path. + let expected_fee = NonNegativeAmount::const_from_u64(10000); + assert_eq!(step0.balance().fee_required(), expected_fee); + + let expected_change = (note_value - transfer_amount - expected_fee).unwrap(); + let proposed_change = step0.balance().proposed_change(); + assert_eq!(proposed_change.len(), 1); + let change_output = proposed_change.get(0).unwrap(); + // Since this is a cross-pool transfer, change will be sent to the preferred pool. + assert_eq!(change_output.output_pool(), P1::SHIELDED_PROTOCOL); + assert_eq!(change_output.value(), expected_change); + + let create_proposed_result = + st.create_proposed_transactions::(&usk, OvkPolicy::Sender, &proposal0); + assert_matches!(&create_proposed_result, Ok(txids) if txids.len() == 1); + + let (h, _) = st.generate_next_block_including(create_proposed_result.unwrap()[0]); + st.scan_cached_blocks(h, 1); + + assert_eq!( + st.get_total_balance(account), + (initial_balance - expected_fee).unwrap() + ); + assert_eq!( + st.get_spendable_balance(account, 1), + (initial_balance - expected_fee).unwrap() + ); +} + pub(crate) fn valid_chain_states() { let mut st = TestBuilder::new() .with_block_cache() diff --git a/zcash_client_sqlite/src/wallet/orchard.rs b/zcash_client_sqlite/src/wallet/orchard.rs index 3b3859077c..7143869a28 100644 --- a/zcash_client_sqlite/src/wallet/orchard.rs +++ b/zcash_client_sqlite/src/wallet/orchard.rs @@ -604,9 +604,16 @@ pub(crate) mod tests { } #[test] - fn cross_pool_exchange() { + fn pool_crossing_required() { use crate::wallet::sapling::tests::SaplingPoolTester; - testing::pool::cross_pool_exchange::() + testing::pool::pool_crossing_required::() + } + + #[test] + fn fully_funded_fully_private() { + use crate::wallet::sapling::tests::SaplingPoolTester; + + testing::pool::fully_funded_fully_private::() } } diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs index 38f076619f..38072f9132 100644 --- a/zcash_client_sqlite/src/wallet/sapling.rs +++ b/zcash_client_sqlite/src/wallet/sapling.rs @@ -607,9 +607,17 @@ pub(crate) mod tests { #[test] #[cfg(feature = "orchard")] - fn cross_pool_exchange() { + fn pool_crossing_required() { use crate::wallet::orchard::tests::OrchardPoolTester; - testing::pool::cross_pool_exchange::() + testing::pool::pool_crossing_required::() + } + + #[test] + #[cfg(feature = "orchard")] + fn fully_funded_fully_private() { + use crate::wallet::orchard::tests::OrchardPoolTester; + + testing::pool::fully_funded_fully_private::() } }