Skip to content

Commit

Permalink
chore: Write test cases for invalid and bad headers (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
oblique authored Oct 2, 2023
1 parent 8aa8bd7 commit 4620849
Show file tree
Hide file tree
Showing 7 changed files with 312 additions and 45 deletions.
103 changes: 102 additions & 1 deletion node/src/exchange/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ mod tests {
use celestia_proto::p2p::pb::StatusCode;
use celestia_types::consts::HASH_SIZE;
use celestia_types::hash::Hash;
use celestia_types::test_utils::ExtendedHeaderGenerator;
use celestia_types::test_utils::{invalidate, unverify, ExtendedHeaderGenerator};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};

Expand Down Expand Up @@ -612,6 +612,58 @@ mod tests {
));
}

#[tokio::test]
async fn respond_with_invalid_header() {
let peer_tracker = peer_tracker_with_n_peers(15);
let mut mock_req = MockReq::new();
let mut handler = ExchangeClientHandler::<MockReq>::new(peer_tracker);

let (tx, rx) = oneshot::channel();

handler.on_send_request(&mut mock_req, HeaderRequest::with_origin(5, 1), tx);

// Exchange client must return a validated header.
let mut gen = ExtendedHeaderGenerator::new_from_height(5);
let mut invalid_header5 = gen.next();
invalidate(&mut invalid_header5);

mock_req.send_n_responses(&mut handler, 1, vec![invalid_header5.to_header_response()]);

assert!(matches!(
rx.await,
Ok(Err(P2pError::Exchange(ExchangeError::InvalidResponse)))
));
}

#[tokio::test]
async fn respond_with_allowed_bad_header() {
let peer_tracker = peer_tracker_with_n_peers(15);
let mut mock_req = MockReq::new();
let mut handler = ExchangeClientHandler::<MockReq>::new(peer_tracker);

let (tx, rx) = oneshot::channel();

handler.on_send_request(&mut mock_req, HeaderRequest::with_origin(5, 2), tx);

let mut gen = ExtendedHeaderGenerator::new_from_height(5);

// Exchange client must not verify the headers, this is done only
// in `get_verified_headers_range` which is used later on in `Syncer`.
let mut expected_headers = gen.next_many(2);
unverify(&mut expected_headers[1]);

let expected = expected_headers
.iter()
.map(|header| header.to_header_response())
.collect::<Vec<_>>();

mock_req.send_n_responses(&mut handler, 1, expected);

let result = rx.await.unwrap().unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result, expected_headers);
}

#[tokio::test]
async fn invalid_requests() {
let peer_tracker = peer_tracker_with_n_peers(15);
Expand Down Expand Up @@ -830,6 +882,55 @@ mod tests {
assert_eq!(result[0], expected_header);
}

#[tokio::test]
async fn head_request_responds_with_invalid_headers() {
let peer_tracker = peer_tracker_with_n_peers(15);
let mut mock_req = MockReq::new();
let mut handler = ExchangeClientHandler::<MockReq>::new(peer_tracker);

let (tx, rx) = oneshot::channel();

handler.on_send_request(&mut mock_req, HeaderRequest::with_origin(0, 1), tx);

let mut gen = ExtendedHeaderGenerator::new_from_height(5);
let header5 = gen.next();

let mut invalid_header5 = gen.another_of(&header5);
invalidate(&mut invalid_header5);

let expected_header = header5;
let expected = expected_header.to_header_response();

mock_req.send_n_responses(&mut handler, 9, vec![invalid_header5.to_header_response()]);
mock_req.send_n_responses(&mut handler, 1, vec![expected]);

let result = rx.await.unwrap().unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0], expected_header);
}

#[tokio::test]
async fn head_request_responds_only_with_invalid_headers() {
let peer_tracker = peer_tracker_with_n_peers(15);
let mut mock_req = MockReq::new();
let mut handler = ExchangeClientHandler::<MockReq>::new(peer_tracker);

let (tx, rx) = oneshot::channel();

handler.on_send_request(&mut mock_req, HeaderRequest::with_origin(0, 1), tx);

let mut gen = ExtendedHeaderGenerator::new_from_height(5);
let mut invalid_header5 = gen.next();
invalidate(&mut invalid_header5);

mock_req.send_n_responses(&mut handler, 10, vec![invalid_header5.to_header_response()]);

assert!(matches!(
rx.await,
Ok(Err(P2pError::Exchange(ExchangeError::HeaderNotFound)))
));
}

#[tokio::test]
async fn head_request_responds_with_only_failures() {
let peer_tracker = peer_tracker_with_n_peers(15);
Expand Down
12 changes: 2 additions & 10 deletions node/src/exchange/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,8 @@ impl HeaderResponseExt for HeaderResponse {
match self.status_code() {
StatusCode::Invalid => Err(ExchangeError::InvalidResponse),
StatusCode::NotFound => Err(ExchangeError::HeaderNotFound),
StatusCode::Ok => {
let header = ExtendedHeader::decode(&self.body[..])
.map_err(|_| ExchangeError::InvalidResponse)?;

header
.validate()
.map_err(|_| ExchangeError::InvalidResponse)?;

Ok(header)
}
StatusCode::Ok => ExtendedHeader::decode_and_validate(&self.body[..])
.map_err(|_| ExchangeError::InvalidResponse),
}
}
}
Expand Down
10 changes: 2 additions & 8 deletions node/src/p2p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use libp2p::{
},
Multiaddr, PeerId, TransportError,
};
use tendermint_proto::Protobuf;
use tokio::select;
use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info, instrument, trace, warn};
Expand Down Expand Up @@ -549,16 +548,11 @@ where

#[instrument(skip_all)]
fn on_header_sub_message(&mut self, data: &[u8]) {
let Ok(header) = ExtendedHeader::decode(data) else {
trace!("Malformed header from header-sub");
let Ok(header) = ExtendedHeader::decode_and_validate(data) else {
trace!("Malformed or invalid header from header-sub");
return;
};

if let Err(e) = header.validate() {
trace!("Invalid header from header-sub ({e})");
return;
}

debug!("New header from header-sub ({header})");
// TODO: inform syncer about it
//
Expand Down
2 changes: 2 additions & 0 deletions types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ tendermint-proto = { workspace = true }
thiserror = "1.0.40"

[dev-dependencies]
ed25519-consensus = "2.1.0"
proptest = "1.2.0"
rand = "0.8.5"
serde_json = "1.0.97"

[features]
Expand Down
72 changes: 65 additions & 7 deletions types/src/extended_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ impl Display for ExtendedHeader {
}

impl ExtendedHeader {
/// Decode and then validate header.
pub fn decode_and_validate(bytes: &[u8]) -> Result<Self> {
let header = ExtendedHeader::decode(bytes)?;
header.validate()?;
Ok(header)
}

pub fn chain_id(&self) -> &Id {
&self.header.chain_id
}
Expand All @@ -59,6 +66,7 @@ impl ExtendedHeader {
.unwrap_or_default()
}

/// Validate header.
pub fn validate(&self) -> Result<()> {
self.header.validate_basic()?;
self.commit.validate_basic()?;
Expand Down Expand Up @@ -110,6 +118,7 @@ impl ExtendedHeader {
Ok(())
}

/// Verify an untrusted header.
pub fn verify(&self, untrusted: &ExtendedHeader) -> Result<()> {
if untrusted.height() <= self.height() {
bail_verification!(
Expand Down Expand Up @@ -177,6 +186,13 @@ impl ExtendedHeader {
Ok(())
}

/// Verify a chain of untrusted headers.
///
/// # Note
///
/// This method does not do validation for optimization purposes.
/// Validation should be done from before and ideally with
/// [`ExtendedHeader::decode_and_validate`].
pub fn verify_range(&self, untrusted: &[ExtendedHeader]) -> Result<()> {
let mut trusted = self;

Expand All @@ -192,14 +208,20 @@ impl ExtendedHeader {
);
}

untrusted.validate()?;
trusted.verify(untrusted)?;
trusted = untrusted;
}

Ok(())
}

/// Verify a chain of untrusted headers make sure that are adjacent to `self`.
///
/// # Note
///
/// This method does not do validation for optimization purposes.
/// Validation should be done from before and ideally with
/// [`ExtendedHeader::decode_and_validate`].
pub fn verify_adjacent_range(&self, untrusted: &[ExtendedHeader]) -> Result<()> {
if untrusted.is_empty() {
return Ok(());
Expand Down Expand Up @@ -255,9 +277,19 @@ impl From<ExtendedHeader> for RawExtendedHeader {
}
}

/// Convenient utility for validating multiple headers.
pub fn validate_headers(headers: &[ExtendedHeader]) -> Result<()> {
for header in headers {
header.validate()?;
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::{invalidate, unverify};

fn sample_eh_chain_1_block_1() -> ExtendedHeader {
let s = include_str!("../test_data/chain1/extended_header_block_1.json");
Expand Down Expand Up @@ -427,6 +459,25 @@ mod tests {
eh_block_1.verify(&eh_block_27).unwrap_err();
}

#[test]
fn validate_multiple_headers() {
let mut eh_chain = sample_eh_chain_3_block_1_to_256();

validate_headers(&eh_chain).unwrap();

// Non-continuous headers are allowed
eh_chain.remove(2);
validate_headers(&eh_chain).unwrap();

// Bad headers are allowed
unverify(&mut eh_chain[2]);
validate_headers(&eh_chain).unwrap();

// Invalid header are not allowed
invalidate(&mut eh_chain[3]);
validate_headers(&eh_chain).unwrap_err();
}

#[test]
fn verify_range() {
let eh_chain = sample_eh_chain_3_block_1_to_256();
Expand Down Expand Up @@ -460,20 +511,27 @@ mod tests {
}

#[test]
fn verify_range_invalid_header_in_middle() {
fn verify_range_bad_header_in_middle() {
let eh_chain = sample_eh_chain_3_block_1_to_256();

let mut headers = eh_chain[10..15].to_vec();

headers[2].header.time = headers[2]
.header
.time
.checked_add(Duration::from_millis(1))
.unwrap();
unverify(&mut headers[2]);

eh_chain[0].verify_range(&headers).unwrap_err();
}

#[test]
fn verify_range_allow_invalid_header_in_middle() {
let eh_chain = sample_eh_chain_3_block_1_to_256();

let mut headers = eh_chain[10..15].to_vec();

invalidate(&mut headers[2]);

eh_chain[0].verify_range(&headers).unwrap();
}

#[test]
fn verify_adjacent_range() {
let eh_chain = sample_eh_chain_3_block_1_to_256();
Expand Down
2 changes: 1 addition & 1 deletion types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub(crate) mod serializers;
mod share;
pub mod state;
mod sync;
#[cfg(feature = "test-utils")]
#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils;
pub mod trust_level;
mod validate;
Expand Down
Loading

0 comments on commit 4620849

Please sign in to comment.