refactor: remove Option<bool> decompression args
suchapalaver committed Oct 21, 2024
1 parent 8ea5efa commit 2acb88e
Showing 10 changed files with 83 additions and 55 deletions.
4 changes: 2 additions & 2 deletions crates/flat-files-decoder/benches/decoder.rs
@@ -1,7 +1,7 @@
extern crate rand;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use decoder::handle_file;
use decoder::{handle_file, Decompression};
use std::fs;

const ITERS_PER_FILE: usize = 10;
@@ -23,7 +23,7 @@ fn bench(c: &mut Criterion) {
}
}

b.iter(|| handle_file(black_box(&path), None, None, None));
b.iter(|| handle_file(black_box(&path), None, None, Decompression::None));
}
});

80 changes: 50 additions & 30 deletions crates/flat-files-decoder/src/lib.rs
@@ -30,6 +30,32 @@ use zstd::stream::decode_all;

const MERGE_BLOCK: usize = 15537393;

#[derive(Clone, Copy, Debug)]
pub enum Decompression {
Zstd,
None,
}

impl From<bool> for Decompression {
fn from(value: bool) -> Self {
if value {
Decompression::Zstd
} else {
Decompression::None
}
}
}

impl From<&str> for Decompression {
fn from(value: &str) -> Self {
match value {
"true" => Decompression::Zstd,
"false" => Decompression::None,
_ => Decompression::None,
}
}
}

pub enum DecodeInput {
Path(String),
Reader(Box<dyn Read>),
@@ -49,12 +75,12 @@ pub enum DecodeInput {
/// If `None`, decoded blocks are not written to disk.
/// * `headers_dir`: An [`Option<&str>`] specifying the directory containing header files for verification.
/// Must be a directory if provided.
/// * `decompress`: An [`Option<bool>`] specifying if it is necessary to decompress from zstd.
/// * `decompress`: A [`Decompression`] enum specifying if it is necessary to decompress from zstd.
pub fn decode_flat_files(
input: String,
output: Option<&str>,
headers_dir: Option<&str>,
decompress: Option<bool>,
decompress: Decompression,
) -> Result<Vec<Block>, DecodeError> {
let metadata = fs::metadata(&input).map_err(DecodeError::IoError)?;

@@ -75,7 +101,7 @@ fn decode_flat_files_dir(
input: &str,
output: Option<&str>,
headers_dir: Option<&str>,
decompress: Option<bool>,
decompress: Decompression,
) -> Result<Vec<Block>, DecodeError> {
let paths = fs::read_dir(input).map_err(DecodeError::IoError)?;

@@ -120,22 +146,24 @@ fn decode_flat_files_dir(
/// If `None`, decoded blocks are not written to disk.
/// * `headers_dir`: An [`Option<&str>`] specifying the directory containing header files for verification.
/// Must be a directory if provided.
/// * `decompress`: An [`Option<bool>`] indicating whether decompression from `zstd` format is necessary.
/// * `decompress`: A [`Decompression`] enum indicating whether decompression from `zstd` format is necessary.
///
pub fn handle_file(
path: &PathBuf,
output: Option<&str>,
headers_dir: Option<&str>,
decompress: Option<bool>,
decompress: Decompression,
) -> Result<Vec<Block>, DecodeError> {
let input_file = BufReader::new(File::open(path).map_err(DecodeError::IoError)?);
// Check if decompression is required and read the file accordingly.
let mut file_contents: Box<dyn Read> = if decompress == Some(true) {
let decompressed_data = decode_all(input_file)
.map_err(|e| DecodeError::IoError(std::io::Error::new(std::io::ErrorKind::Other, e)))?;
Box::new(Cursor::new(decompressed_data))
} else {
Box::new(input_file)
let mut file_contents: Box<dyn Read> = match decompress {
Decompression::Zstd => {
let decompressed_data = decode_all(input_file).map_err(|e| {
DecodeError::IoError(std::io::Error::new(std::io::ErrorKind::Other, e))
})?;
Box::new(Cursor::new(decompressed_data))
}
Decompression::None => Box::new(input_file),
};

let dbin_file = DbinFile::try_from_read(&mut file_contents)?;
@@ -166,11 +194,10 @@ pub fn handle_file(
/// * `buf`: A byte slice referencing the in-memory content of the flat file to be decoded.
/// * `decompress`: A boolean indicating whether the input buffer should be decompressed.
///
pub fn handle_buf(buf: &[u8], decompress: Option<bool>) -> Result<Vec<Block>, DecodeError> {
let buf = if decompress.unwrap_or(false) {
zstd::decode_all(buf).map_err(|_| DecodeError::DecompressError)?
} else {
buf.to_vec()
pub fn handle_buf(buf: &[u8], decompress: Decompression) -> Result<Vec<Block>, DecodeError> {
let buf = match decompress {
Decompression::Zstd => zstd::decode_all(buf).map_err(|_| DecodeError::DecompressError)?,
Decompression::None => buf.to_vec(),
};

let dbin_file = DbinFile::try_from_read(&mut Cursor::new(buf))?;
@@ -318,23 +345,16 @@ where

#[cfg(test)]
mod tests {
use prost::Message;

use crate::dbin::DbinFile;
use crate::receipts::check_receipt_root;
use crate::{handle_buf, handle_file, receipts, stream_blocks};
use sf_protos::bstream::v1::Block as BstreamBlock;
use sf_protos::ethereum::r#type::v2::Block;
use std::fs::File;
use std::io::{self, Cursor, Read, Write};
use std::io::{BufReader, BufWriter};
use std::path::PathBuf;
use std::io::{self, BufReader, BufWriter, Cursor, Read, Write};

use super::*;

#[test]
fn test_handle_file() {
let path = PathBuf::from("example0017686312.dbin");

let result = handle_file(&path, None, None, None);
let result = handle_file(&path, None, None, Decompression::None);

assert!(result.is_ok());
}
@@ -343,7 +363,7 @@ mod tests {
fn test_handle_file_zstd() {
let path = PathBuf::from("./tests/0000000000.dbin.zst");

let result = handle_file(&path, None, None, Some(true));
let result = handle_file(&path, None, None, Decompression::Zstd);

assert!(result.is_ok());
let blocks: Vec<Block> = result.unwrap();
@@ -410,7 +430,7 @@ mod tests {
.read_to_end(&mut buffer)
.expect("Failed to read file");

let result = handle_buf(&buffer, Some(false));
let result = handle_buf(&buffer, Decompression::None);
assert!(result.is_ok(), "handle_buf should complete successfully");
}

@@ -426,7 +446,7 @@
.read_to_end(&mut buffer)
.expect("Failed to read file");

let result = handle_buf(&buffer, Some(true));
let result = handle_buf(&buffer, Decompression::Zstd);
assert!(
result.is_ok(),
"handle_buf should complete successfully with decompression"
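For downstream callers the migration is mechanical: every `Option<bool>` argument becomes an explicit `Decompression` variant. A minimal sketch of the new call shapes against the public API above (the wrapper function and its inputs are hypothetical, not part of this commit):

    use decoder::{handle_buf, Decompression};

    // Hypothetical caller illustrating the post-refactor call shapes.
    fn decode_examples(raw: &[u8], zstd_bytes: &[u8]) {
        // Previously handle_buf(raw, Some(false)) or handle_buf(raw, None):
        let _plain = handle_buf(raw, Decompression::None);

        // Previously handle_buf(zstd_bytes, Some(true)):
        let _decoded = handle_buf(zstd_bytes, Decompression::Zstd);

        // Callers still holding a bool can convert via the new From<bool> impl:
        let compressed = true;
        let _from_flag = handle_buf(zstd_bytes, Decompression::from(compressed));
    }

Since `Decompression` derives `Copy`, it can be passed by value repeatedly without cloning.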
4 changes: 2 additions & 2 deletions crates/flat-files-decoder/src/main.rs
@@ -1,5 +1,5 @@
use clap::{Parser, Subcommand};
use decoder::{decode_flat_files, stream_blocks};
use decoder::{decode_flat_files, stream_blocks, Decompression};
use std::io::{self, BufReader, BufWriter};

#[derive(Parser, Debug)]
@@ -34,7 +34,7 @@ enum Commands {
output: Option<String>,
#[clap(short, long)]
/// optionally decompress zstd compressed flat files
decompress: Option<bool>,
decompress: Decompression,
},
}
#[tokio::main]
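One caveat: clap's derive resolves an argument's parser through `FromStr` (or `ValueEnum`), not `From<&str>`, so the `&str` conversion added in lib.rs is not picked up by `decompress: Decompression` on its own. A sketch of the kind of shim this would need in the decoder crate, assuming the `"true"`/`"false"` spelling is kept (this impl is not part of the commit):

    use std::convert::Infallible;
    use std::str::FromStr;

    impl FromStr for Decompression {
        type Err = Infallible;

        // Delegate to the From<&str> impl above: "true" maps to Zstd,
        // anything else (including "false") falls back to None.
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            Ok(Decompression::from(s))
        }
    }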
16 changes: 10 additions & 6 deletions crates/flat-head/src/era_verifier.rs
@@ -1,3 +1,4 @@
use decoder::Decompression;
use futures::stream::{FuturesOrdered, StreamExt};
use tokio::task;

@@ -19,13 +20,13 @@ pub async fn verify_eras(
compatible: Option<String>,
start_epoch: usize,
end_epoch: Option<usize>,
decompress: Option<bool>,
decompress: Decompression,
) -> Result<Vec<usize>, anyhow::Error> {
let mut validated_epochs = Vec::new();
let (tx, mut rx) = mpsc::channel(5);

let blocks_store: store::Store = store::new(store_url, decompress.unwrap_or(false), compatible)
.expect("failed to create blocks store");
let blocks_store: store::Store =
store::new(store_url, decompress, compatible).expect("failed to create blocks store");

for epoch in start_epoch..=end_epoch.unwrap_or(start_epoch + 1) {
let tx = tx.clone();
@@ -75,7 +76,7 @@ pub async fn verify_eras(
async fn get_blocks_from_store(
epoch: usize,
store: &Store,
decompress: Option<bool>,
decompress: Decompression,
) -> Result<Vec<Block>, anyhow::Error> {
let start_100_block = epoch * MAX_EPOCH_SIZE;
let end_100_block = (epoch + 1) * MAX_EPOCH_SIZE;
@@ -95,14 +96,17 @@ async fn extract_100s_blocks(
store: &Store,
start_block: usize,
end_block: usize,
decompress: Option<bool>,
decompress: Decompression,
) -> Result<Vec<Block>, anyhow::Error> {
// Flat files are stored in 100 block files
// So we need to find the 100 block file that contains the start block and the 100 block file that contains the end block
let start_100_block = (start_block / 100) * 100;
let end_100_block = (((end_block as f32) / 100.0).ceil() as usize) * 100;

let zst_extension = if decompress.unwrap() { ".zst" } else { "" };
let zst_extension = match decompress {
Decompression::Zstd => ".zst",
Decompression::None => "",
};

let mut futs = FuturesOrdered::new();

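The chosen extension feeds into the flat-file names that `extract_100s_blocks` requests from the store. Going by the fixtures elsewhere in this commit (e.g. `0000000000.dbin.zst`), the names look roughly like this; `flat_file_name` is a hypothetical helper for illustration, not code from the commit:

    // Hypothetical sketch of how the extension combines with the
    // zero-padded, 100-block file naming used by the test fixtures.
    fn flat_file_name(start_100_block: usize, decompress: Decompression) -> String {
        let zst_extension = match decompress {
            Decompression::Zstd => ".zst",
            Decompression::None => "",
        };
        format!("{:010}.dbin{}", start_100_block, zst_extension)
    }

    // flat_file_name(8200, Decompression::Zstd) == "0000008200.dbin.zst"
    // flat_file_name(8200, Decompression::None) == "0000008200.dbin"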
3 changes: 2 additions & 1 deletion crates/flat-head/src/main.rs
@@ -2,6 +2,7 @@ use std::env;

use clap::{Parser, Subcommand};

use decoder::Decompression;
use flat_head::era_verifier::verify_eras;
use trin_validation::accumulator::PreMergeAccumulator;

@@ -38,7 +39,7 @@ enum Commands {

#[clap(short = 'c', long, default_value = "true")]
// Whether to decompress files from zstd or not.
decompress: Option<bool>,
decompress: Decompression,

#[clap(short = 'p', long)]
// indicates if the store_url is compatible with some API. E.g., if `--compatible s3` is used,
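Note that with the `From<&str>` mapping in the decoder crate, the `default_value = "true"` above would resolve to `Decompression::Zstd` once a string-based parser is wired up (see the `FromStr` sketch earlier), and any other string silently falls back to `Decompression::None` rather than erroring.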
4 changes: 2 additions & 2 deletions crates/flat-head/src/s3.rs
@@ -5,7 +5,7 @@ use header_accumulator::{
use std::env;
use trin_validation::accumulator::PreMergeAccumulator;

use decoder::handle_buf;
use decoder::{handle_buf, Decompression};

use object_store::{aws::AmazonS3Builder, path::Path, ObjectStore};

@@ -58,7 +58,7 @@ pub async fn s3_fetch(
let bytes = result.bytes().await.unwrap();

// Use `as_ref` to get a &[u8] from `bytes` and pass it to `handle_buf`
match handle_buf(bytes.as_ref(), Some(false)) {
match handle_buf(bytes.as_ref(), Decompression::None) {
Ok(blocks) => {
let (successful_headers, _): (Vec<_>, Vec<_>) = blocks
.iter()
13 changes: 8 additions & 5 deletions crates/flat-head/src/store.rs
@@ -1,6 +1,6 @@
use anyhow::Context;
use bytes::Bytes;
use decoder::handle_buf;
use decoder::{handle_buf, Decompression};
use object_store::{
aws::AmazonS3Builder, gcp::GoogleCloudStorageBuilder, http::HttpBuilder,
local::LocalFileSystem, path::Path, ClientOptions, ObjectStore,
@@ -13,7 +13,7 @@ use sf_protos::ethereum::r#type::v2::Block;

pub fn new<S: AsRef<str>>(
store_url: S,
decompress: bool,
decompress: Decompression,
compatible: Option<String>,
) -> Result<Store, anyhow::Error> {
let store_url = store_url.as_ref();
@@ -123,7 +123,7 @@ pub fn new<S: AsRef<str>>(
pub struct Store {
store: Arc<dyn ObjectStore>,
base: String,
decompress: bool,
decompress: Decompression,
}

impl Store {
@@ -158,8 +158,11 @@ impl ReadOptions {
}
}

async fn handle_from_bytes(bytes: Bytes, decompress: bool) -> Result<Vec<Block>, ReadError> {
handle_buf(bytes.as_ref(), Some(decompress)).map_err(|e| ReadError::DecodeError(e.to_string()))
async fn handle_from_bytes(
bytes: Bytes,
decompress: Decompression,
) -> Result<Vec<Block>, ReadError> {
handle_buf(bytes.as_ref(), decompress).map_err(|e| ReadError::DecodeError(e.to_string()))
}

// async fn fake_handle_from_stream(
6 changes: 3 additions & 3 deletions crates/header-accumulator/tests/era_validator.rs
@@ -1,6 +1,6 @@
use std::fs;

use decoder::decode_flat_files;
use decoder::{decode_flat_files, Decompression};
use header_accumulator::{
era_validator::EraValidator, errors::HeaderAccumulatorError, types::ExtHeaderRecord,
};
@@ -16,7 +16,7 @@ fn test_era_validate() -> Result<(), HeaderAccumulatorError> {
let mut headers: Vec<ExtHeaderRecord> = Vec::new();
for number in (0..=8200).step_by(100) {
let file_name = format!("tests/ethereum_firehose_first_8200/{:010}.dbin", number);
match decode_flat_files(file_name, None, None, Some(false)) {
match decode_flat_files(file_name, None, None, Decompression::None) {
Ok(blocks) => {
let (successful_headers, _): (Vec<_>, Vec<_>) = blocks
.iter()
@@ -82,7 +82,7 @@ fn test_era_validate_compressed() -> Result<(), HeaderAccumulatorError> {
let mut headers: Vec<ExtHeaderRecord> = Vec::new();
for number in (0..=8200).step_by(100) {
let file_name = format!("tests/compressed/{:010}.dbin.zst", number);
match decode_flat_files(file_name, None, None, Some(true)) {
match decode_flat_files(file_name, None, None, Decompression::Zstd) {
Ok(blocks) => {
let (successful_headers, _): (Vec<_>, Vec<_>) = blocks
.iter()
4 changes: 2 additions & 2 deletions crates/header-accumulator/tests/inclusion_proof.rs
@@ -1,4 +1,4 @@
use decoder::decode_flat_files;
use decoder::{decode_flat_files, Decompression};
use header_accumulator::{
self,
errors::EraValidateError,
@@ -17,7 +17,7 @@ fn test_inclusion_proof() -> Result<(), EraValidateError> {
"tests/ethereum_firehose_first_8200/{:010}.dbin",
flat_file_number
);
match decode_flat_files(file_name, None, None, Some(false)) {
match decode_flat_files(file_name, None, None, Decompression::None) {
Ok(blocks) => {
headers.extend(
blocks
4 changes: 2 additions & 2 deletions crates/header-accumulator/tests/utils.rs
@@ -1,4 +1,4 @@
use decoder::decode_flat_files;
use decoder::{decode_flat_files, Decompression};
use ethportal_api::Header;

#[test]
@@ -7,7 +7,7 @@ fn test_header_from_block() {
"tests/ethereum_firehose_first_8200/0000008200.dbin".to_string(),
None,
None,
Some(false),
Decompression::None,
)
.unwrap();

