Skip to content

Commit

Permalink
Merge pull request #17 from tcharding/02-01-clean-up
Browse files Browse the repository at this point in the history
Do a bunch of clean ups
  • Loading branch information
tcharding authored Feb 1, 2024
2 parents f7c9cc8 + 66abe2c commit 4ffdadc
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 121 deletions.
9 changes: 4 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ version = "0.1.0"
authors = ["Tobin C. Harding <[email protected]>"]
license = "CC0-1.0"
repository = "https://github.com/tcharding/rust-psbt/"
description = "Partially signed Bitcoin Transaction, v0 and v1"
description = "Partially Signed Bitcoin Transaction, v0 and v2"
categories = ["cryptography::cryptocurrencies"]
keywords = [ "crypto", "bitcoin" ]
readme = "../README.md"
readme = "README.md"
edition = "2021"
rust-version = "1.56.1"
exclude = ["tests", "contrib"]
Expand All @@ -27,11 +27,10 @@ miniscript-std = ["std", "miniscript/std"]
miniscript-no-std = ["no-std", "miniscript/no-std"]

[dependencies]
bitcoin = { version = "0.31.0", default-features = false, features = [] }
bitcoin = { version = "0.31.0", default-features = false }

# Currenty miniscript only works in with "std" enabled.
# Do not use this feature, use "miniscript-std" or "miniscript-no-std" instead.
miniscript = { version = "11.0.0", default-features = false, optional = true }

# Do NOT use this as a feature! Use the `serde` feature instead.
actual-serde = { package = "serde", version = "1.0.103", default-features = false, features = [ "derive", "alloc" ], optional = true }
# There is no reason to use this dependency directly, it is activated by the "no-std" feature.
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ use std::io;
#[cfg(not(feature = "std"))]
use core2::io;

use crate::version::Version;

#[rustfmt::skip] // Keep pubic re-exports separate
#[doc(inline)]
pub use crate::{
sighash_type::PsbtSighashType,
sighash_type::{PsbtSighashType, InvalidSighashTypeError},
version::Version,
};

/// PSBT version 0 - the original PSBT version.
Expand Down
64 changes: 62 additions & 2 deletions src/sighash_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,15 @@ impl std::error::Error for SighashTypeParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None }
}

// TODO: Remove this error after issue resolves.
// https://github.com/rust-bitcoin/rust-bitcoin/issues/2423
/// Integer is not a consensus valid sighash type.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum InvalidSighashTypeError {
/// TODO:
/// The real invalid sighash type error.
Bitcoin(sighash::InvalidSighashTypeError),
/// TODO:
/// Hack required because of non_exhaustive on the real error.
Invalid(u32),
}

Expand Down Expand Up @@ -149,3 +151,61 @@ impl std::error::Error for InvalidSighashTypeError {
impl From<sighash::InvalidSighashTypeError> for InvalidSighashTypeError {
fn from(e: sighash::InvalidSighashTypeError) -> Self { Self::Bitcoin(e) }
}

#[cfg(test)]
mod tests {
use core::str::FromStr;

use super::*;
use crate::sighash_type::InvalidSighashTypeError;

#[test]
fn psbt_sighash_type_ecdsa() {
for ecdsa in &[
EcdsaSighashType::All,
EcdsaSighashType::None,
EcdsaSighashType::Single,
EcdsaSighashType::AllPlusAnyoneCanPay,
EcdsaSighashType::NonePlusAnyoneCanPay,
EcdsaSighashType::SinglePlusAnyoneCanPay,
] {
let sighash = PsbtSighashType::from(*ecdsa);
let s = format!("{}", sighash);
let back = PsbtSighashType::from_str(&s).unwrap();
assert_eq!(back, sighash);
assert_eq!(back.ecdsa_hash_ty().unwrap(), *ecdsa);
}
}

#[test]
fn psbt_sighash_type_taproot() {
for tap in &[
TapSighashType::Default,
TapSighashType::All,
TapSighashType::None,
TapSighashType::Single,
TapSighashType::AllPlusAnyoneCanPay,
TapSighashType::NonePlusAnyoneCanPay,
TapSighashType::SinglePlusAnyoneCanPay,
] {
let sighash = PsbtSighashType::from(*tap);
let s = format!("{}", sighash);
let back = PsbtSighashType::from_str(&s).unwrap();
assert_eq!(back, sighash);
assert_eq!(back.taproot_hash_ty().unwrap(), *tap);
}
}

#[test]
fn psbt_sighash_type_notstd() {
let nonstd = 0xdddddddd;
let sighash = PsbtSighashType { inner: nonstd };
let s = format!("{}", sighash);
let back = PsbtSighashType::from_str(&s).unwrap();

assert_eq!(back, sighash);
// TODO: Add this assertion once we remove InvalidSighashTypeError
// assert_eq!(back.ecdsa_hash_ty(), Err(NonStandardSighashTypeError(nonstd)));
assert_eq!(back.taproot_hash_ty(), Err(InvalidSighashTypeError::Invalid(nonstd)));
}
}
15 changes: 7 additions & 8 deletions src/v0/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ use crate::v0::Psbt;
///
/// This error is returned when deserializing a complete PSBT, not for deserializing parts
/// of it or individual data types.
// TODO: This can change to `serialize::Error` if we rename `serialize::Error` to `serialize::Error`.
#[derive(Debug)]
#[non_exhaustive]
pub enum DeserializePsbtError {
pub enum DeserializeError {
/// Invalid magic bytes, expected the ASCII for "psbt" serialized in most significant byte order.
// TODO: Consider adding the invalid bytes.
InvalidMagic,
Expand All @@ -36,28 +35,28 @@ pub enum DeserializePsbtError {
UnsignedTxChecks(UnsignedTxChecksError),
}

impl fmt::Display for DeserializePsbtError {
impl fmt::Display for DeserializeError {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { todo!() }
}

#[cfg(feature = "std")]
impl std::error::Error for DeserializePsbtError {
impl std::error::Error for DeserializeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { todo!() }
}

impl From<global::DecodeError> for DeserializePsbtError {
impl From<global::DecodeError> for DeserializeError {
fn from(e: global::DecodeError) -> Self { Self::DecodeGlobal(e) }
}

impl From<input::DecodeError> for DeserializePsbtError {
impl From<input::DecodeError> for DeserializeError {
fn from(e: input::DecodeError) -> Self { Self::DecodeInput(e) }
}

impl From<output::DecodeError> for DeserializePsbtError {
impl From<output::DecodeError> for DeserializeError {
fn from(e: output::DecodeError) -> Self { Self::DecodeOutput(e) }
}

impl From<UnsignedTxChecksError> for DeserializePsbtError {
impl From<UnsignedTxChecksError> for DeserializeError {
fn from(e: UnsignedTxChecksError) -> Self { Self::UnsignedTxChecks(e) }
}

Expand Down
1 change: 1 addition & 0 deletions src/v0/miniscript/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use core::fmt;

use crate::bitcoin::{sighash, ScriptBuf};
use crate::miniscript::{self, descriptor, interpreter};
use crate::prelude::*;
#[cfg(doc)]
use crate::v0::Psbt;

Expand Down
15 changes: 7 additions & 8 deletions src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::v0::map::{global, Map};

#[rustfmt::skip] // Keep pubic re-exports separate
pub use self::{
error::{IndexOutOfBoundsError, SignerChecksError, SignError, UnsignedTxChecksError, DeserializePsbtError},
error::{IndexOutOfBoundsError, SignerChecksError, SignError, UnsignedTxChecksError, DeserializeError},
map::{Input, Output, Global},
};

Expand Down Expand Up @@ -84,20 +84,19 @@ impl Psbt {
buf
}

// TODO: Change this to use DeserializePsbtError (although that name is shit) same as v2.
/// Deserialize a value from raw binary data.
pub fn deserialize(bytes: &[u8]) -> Result<Self, DeserializePsbtError> {
pub fn deserialize(bytes: &[u8]) -> Result<Self, DeserializeError> {
const MAGIC_BYTES: &[u8] = b"psbt";
if bytes.get(0..MAGIC_BYTES.len()) != Some(MAGIC_BYTES) {
return Err(DeserializePsbtError::InvalidMagic);
return Err(DeserializeError::InvalidMagic);
}

const PSBT_SERPARATOR: u8 = 0xff_u8;
if bytes.get(MAGIC_BYTES.len()) != Some(&PSBT_SERPARATOR) {
return Err(DeserializePsbtError::InvalidSeparator);
return Err(DeserializeError::InvalidSeparator);
}

let mut d = bytes.get(5..).ok_or(DeserializePsbtError::NoMorePairs)?;
let mut d = bytes.get(5..).ok_or(DeserializeError::NoMorePairs)?;

let global = Global::decode(&mut d)?;
global.unsigned_tx_checks()?;
Expand Down Expand Up @@ -565,7 +564,7 @@ mod display_from_str {
#[non_exhaustive]
pub enum PsbtParseError {
/// Error in internal PSBT data structure.
PsbtEncoding(DeserializePsbtError),
PsbtEncoding(DeserializeError),
/// Error in PSBT Base64 encoding.
Base64Encoding(bitcoin::base64::DecodeError),
}
Expand Down Expand Up @@ -808,7 +807,7 @@ mod tests {
use crate::{io, raw, V0};

#[track_caller]
pub fn hex_psbt(s: &str) -> Result<Psbt, DeserializePsbtError> {
pub fn hex_psbt(s: &str) -> Result<Psbt, DeserializeError> {
let r: Result<Vec<u8>, bitcoin::hex::HexToBytesError> = Vec::from_hex(s);
match r {
Err(_e) => panic!("unable to parse hex string {}", s),
Expand Down
13 changes: 6 additions & 7 deletions src/v2/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ use crate::v2::map::{global, input, output};
///
/// This error is returned when deserializing a complete PSBT, not for deserializing parts
/// of it or individual data types.
// TODO: This can change to `serialize::Error` if we rename `serialize::Error` to `serialize::Error`.
#[derive(Debug)]
#[non_exhaustive]
pub enum DeserializePsbtError {
pub enum DeserializeError {
/// Invalid magic bytes, expected the ASCII for "psbt" serialized in most significant byte order.
// TODO: Consider adding the invalid bytes.
InvalidMagic,
Expand All @@ -34,24 +33,24 @@ pub enum DeserializePsbtError {
DecodeOutput(output::DecodeError),
}

impl fmt::Display for DeserializePsbtError {
impl fmt::Display for DeserializeError {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { todo!() }
}

#[cfg(feature = "std")]
impl std::error::Error for DeserializePsbtError {
impl std::error::Error for DeserializeError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { todo!() }
}

impl From<global::DecodeError> for DeserializePsbtError {
impl From<global::DecodeError> for DeserializeError {
fn from(e: global::DecodeError) -> Self { Self::DecodeGlobal(e) }
}

impl From<input::DecodeError> for DeserializePsbtError {
impl From<input::DecodeError> for DeserializeError {
fn from(e: input::DecodeError) -> Self { Self::DecodeInput(e) }
}

impl From<output::DecodeError> for DeserializePsbtError {
impl From<output::DecodeError> for DeserializeError {
fn from(e: output::DecodeError) -> Self { Self::DecodeOutput(e) }
}

Expand Down
51 changes: 40 additions & 11 deletions src/v2/extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
//!
//! It is only possible to extract a transaction from a PSBT _after_ it has been finalized. However
//! the Extractor role may be fulfilled by a separate entity to the Finalizer hence this is a
//! separate module and does not require `rust-miniscript`.
//! separate module and does not require the "miniscript" feature be enabled.
//!
//! [BIP-174]: <https://github.com/bitcoin/bips/blob/master/bip-0174.mediawiki>
use core::fmt;

use bitcoin::{FeeRate, Transaction};
use bitcoin::{FeeRate, Transaction, Txid};

use crate::error::{write_err, FeeError};
use crate::v2::{DetermineLockTimeError, Psbt};
Expand All @@ -30,13 +30,19 @@ impl Extractor {
/// Creates an `Extractor`.
///
/// An extractor can only accept a PSBT that has been finalized.
pub fn new(psbt: Psbt) -> Result<Self, PsbtNotFinalizedError> {
pub fn new(psbt: Psbt) -> Result<Self, Error> {
if psbt.inputs.iter().any(|input| !input.is_finalized()) {
return Err(PsbtNotFinalizedError);
return Err(Error::PsbtNotFinalized);
}
let _ = psbt.determine_lock_time()?;

Ok(Self(psbt))
}

/// Returns this PSBT's unique identification.
pub fn id(&self) -> Txid {
self.0.id().expect("Extractor guarantees lock time can be determined")
}
}

impl Extractor {
Expand Down Expand Up @@ -118,19 +124,42 @@ impl Extractor {
}
}

/// Attempted to extract tx from an unfinalized PSBT.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct PsbtNotFinalizedError;
/// Error constructing a [`Finalizer`].
#[derive(Debug)]
pub enum Error {
/// Attempted to extract tx from an unfinalized PSBT.
PsbtNotFinalized,
/// Finalizer must be able to determine the lock time.
DetermineLockTime(DetermineLockTimeError),
}

impl fmt::Display for PsbtNotFinalizedError {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "attempted to extract tx from an unfinalized PSBT")
use Error::*;

match *self {
PsbtNotFinalized => write!(f, "attempted to extract tx from an unfinalized PSBT"),
DetermineLockTime(ref e) =>
write_err!(f, "extractor must be able to determine the lock time"; e),
}
}
}

#[cfg(feature = "std")]
impl std::error::Error for PsbtNotFinalizedError {}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
use Error::*;

match *self {
DetermineLockTime(ref e) => Some(e),
PsbtNotFinalized => None,
}
}
}

impl From<DetermineLockTimeError> for Error {
fn from(e: DetermineLockTimeError) -> Self { Self::DetermineLockTime(e) }
}

/// Error caused by fee calculation when extracting a [`Transaction`] from a PSBT.
#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down
10 changes: 4 additions & 6 deletions src/v2/map/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use core::convert::TryFrom;
use core::fmt;

use bitcoin::bip32::{ChildNumber, DerivationPath, Fingerprint, KeySource, Xpub};
use bitcoin::consensus::encode::MAX_VEC_SIZE;
use bitcoin::consensus::{encode as consensus, Decodable};
use bitcoin::locktime::absolute;
use bitcoin::{bip32, transaction, Transaction, VarInt};
Expand Down Expand Up @@ -127,16 +126,15 @@ impl Global {
self.tx_modifiable_flags & OUTPUTS_MODIFIABLE > 0
}

// TODO: Use this function?
// TODO: Investigate if we should be using this function?
#[allow(dead_code)]
pub(crate) fn has_sighash_single(&self) -> bool {
self.tx_modifiable_flags & SIGHASH_SINGLE > 0
}

pub(crate) fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, DecodeError> {
// TODO(tobin): Work out why do we do this take, its not done in input or output modules.
let mut r = r.take(MAX_VEC_SIZE as u64);

// TODO: Consider adding protection against memory exhaustion here by defining a maximum
// PBST size and using `take` as we do in rust-bitcoin consensus decoding.
let mut version: Option<Version> = None;
let mut tx_version: Option<transaction::Version> = None;
let mut fallback_lock_time: Option<absolute::LockTime> = None;
Expand Down Expand Up @@ -312,7 +310,7 @@ impl Global {
};

loop {
match raw::Pair::decode(&mut r) {
match raw::Pair::decode(r) {
Ok(pair) => insert_pair(pair)?,
Err(serialize::Error::NoMorePairs) => break,
Err(e) => return Err(DecodeError::DeserPair(e)),
Expand Down
Loading

0 comments on commit 4ffdadc

Please sign in to comment.