From 72345f6e250556e7ee1acae308fbbb33918b3019 Mon Sep 17 00:00:00 2001 From: "Tobin C. Harding" Date: Thu, 11 Jan 2024 11:58:54 +1100 Subject: [PATCH] v2: Correctly implement Output::combine As we did for `v2::Input`. The current version of `Output::combine` is copied from the `v0` module and is not correct for `v2` - fix it. --- src/v2/map/output.rs | 76 ++++++++++++++++++++++++++++++++++++++------ src/v2/mod.rs | 10 +++++- 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/src/v2/map/output.rs b/src/v2/map/output.rs index db6f22b..af1af9d 100644 --- a/src/v2/map/output.rs +++ b/src/v2/map/output.rs @@ -185,16 +185,28 @@ impl Output { } /// Combines this [`Output`] with `other` `Output` (as described by BIP 174). - pub fn combine(&mut self, other: Self) { - self.bip32_derivations.extend(other.bip32_derivations); - self.proprietaries.extend(other.proprietaries); - self.unknowns.extend(other.unknowns); - self.tap_key_origins.extend(other.tap_key_origins); - - combine!(redeem_script, self, other); - combine!(witness_script, self, other); - combine!(tap_internal_key, self, other); - combine!(tap_tree, self, other); + pub fn combine(&mut self, other: Self) -> Result<(), CombineError> { + if self.amount != other.amount { + return Err(CombineError::AmountMismatch { this: self.amount, that: other.amount }); + } + + if self.script_pubkey != other.script_pubkey { + return Err(CombineError::ScriptPubkeyMismatch { + this: self.script_pubkey.clone(), + that: other.script_pubkey, + }); + } + + combine_option!(redeem_script, self, other); + combine_option!(witness_script, self, other); + combine_map!(bip32_derivations, self, other); + combine_option!(tap_internal_key, self, other); + combine_option!(tap_tree, self, other); + combine_map!(tap_key_origins, self, other); + combine_map!(proprietaries, self, other); + combine_map!(unknowns, self, other); + + Ok(()) } } @@ -345,6 +357,50 @@ impl From for InsertPairError { fn from(e: serialize::Error) -> Self { Self::Deser(e) } } +/// Error combining two output maps. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum CombineError { + /// The amounts are not the same. + AmountMismatch { + /// Attempted to combine a PBST with `this` previous txid. + this: Amount, + /// Into a PBST with `that` previous txid. + that: Amount, + }, + /// The script_pubkeys are not the same. + ScriptPubkeyMismatch { + /// Attempted to combine a PBST with `this` script_pubkey. + this: ScriptBuf, + /// Into a PBST with `that` script_pubkey. + that: ScriptBuf, + }, +} + +impl fmt::Display for CombineError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use CombineError::*; + + match *self { + AmountMismatch { ref this, ref that } => + write!(f, "combine two PSBTs with different amounts: {} {}", this, that), + ScriptPubkeyMismatch { ref this, ref that } => + write!(f, "combine two PSBTs with different script_pubkeys: {:x} {:x}", this, that), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for CombineError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + use CombineError::*; + + match *self { + AmountMismatch { .. } | ScriptPubkeyMismatch { .. } => None, + } + } +} + #[cfg(test)] #[cfg(feature = "std")] mod tests { diff --git a/src/v2/mod.rs b/src/v2/mod.rs index f1c4931..56d11fd 100644 --- a/src/v2/mod.rs +++ b/src/v2/mod.rs @@ -660,7 +660,7 @@ impl Psbt { } for (self_output, other_output) in self.outputs.iter_mut().zip(other.outputs.into_iter()) { - self_output.combine(other_output); + self_output.combine(other_output)?; } Ok(self) @@ -1262,6 +1262,8 @@ pub enum CombineError { Global(global::CombineError), /// Error while combining the input maps. Input(input::CombineError), + /// Error while combining the output maps. + Output(output::CombineError), } impl fmt::Display for CombineError { @@ -1271,6 +1273,7 @@ impl fmt::Display for CombineError { match *self { Global(ref e) => write_err!(f, "error while combining the global maps"; e), Input(ref e) => write_err!(f, "error while combining the input maps"; e), + Output(ref e) => write_err!(f, "error while combining the output maps"; e), } } } @@ -1283,6 +1286,7 @@ impl std::error::Error for CombineError { match *self { Global(ref e) => Some(e), Input(ref e) => Some(e), + Output(ref e) => Some(e), } } } @@ -1294,3 +1298,7 @@ impl From for CombineError { impl From for CombineError { fn from(e: input::CombineError) -> Self { Self::Input(e) } } + +impl From for CombineError { + fn from(e: output::CombineError) -> Self { Self::Output(e) } +}