From 6cf462d5d0300a48a85b8d3cfd954be2a54cdb88 Mon Sep 17 00:00:00 2001 From: Adam Taranto Date: Fri, 20 Sep 2024 23:47:34 +1000 Subject: [PATCH] Use niffler for gzip read/write --- src/lib.rs | 75 +++++++++++++++++++++++++++--------------------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ad15c63..1f74784 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,15 @@ // Standard library imports -use std::collections::{HashMap, HashSet}; use std::collections::hash_map::IntoIter; +use std::collections::{HashMap, HashSet}; use std::fs::File; -use std::io; -use std::io::{BufReader, BufWriter, Read, Write}; -use std::path::Path; +use std::io::{BufReader, BufWriter, Write}; +//use std::path::Path; // External crate imports use anyhow::{anyhow, Result}; use log::debug; -use niffler::{self, Compression, Format}; +use niffler::compression::Format; +use niffler::get_writer; use pyo3::exceptions::{PyIOError, PyValueError}; use pyo3::prelude::*; use pyo3::PyResult; @@ -188,54 +188,55 @@ impl KmerCountTable { } // Serialize the KmerCountTable as a JSON string - pub fn serialize_json(&self) -> PyResult { - serde_json::to_string(&self) - .map_err(|e| PyValueError::new_err(format!("Serialization failed: {}", e))) + pub fn serialize_json(&self) -> Result { + serde_json::to_string(&self).map_err(|e| anyhow::anyhow!("Serialization error: {}", e)) } - // Save the KmerCountTable to a Gzip-compressed JSON file - pub fn save(&self, filepath: String) -> PyResult<()> { - let serialized = self.serialize_json()?; + /// Save the KmerCountTable to a compressed file using Niffler. 
+ pub fn save(&self, filepath: &str) -> PyResult<()> { + // Open the file for writing + let file = File::create(filepath).map_err(|e| PyIOError::new_err(e.to_string()))?; + + // Create a Gzipped writer with niffler, using compression level 1 (`Level::One`) + let writer = BufWriter::new(file); + let mut writer = get_writer(Box::new(writer), Format::Gzip, niffler::level::Level::One) + .map_err(|e| PyIOError::new_err(e.to_string()))?; - let path = Path::new(&filepath); - let file = File::create(path) - .map_err(|e| PyValueError::new_err(format!("File creation failed: {}", e)))?; - let mut encoder = GzEncoder::new(file, Compression::default()); + // Serialize the KmerCountTable to JSON + let json_data = self.serialize_json()?; - encoder - .write_all(serialized.as_bytes()) - .map_err(|e| PyValueError::new_err(format!("Failed to write to file: {}", e)))?; - encoder - .finish() - .map_err(|e| PyValueError::new_err(format!("Failed to finish compression: {}", e)))?; + // Write the serialized JSON to the compressed file + writer + .write_all(json_data.as_bytes()) + .map_err(|e| PyIOError::new_err(e.to_string()))?; Ok(()) } #[staticmethod] - pub fn load(filepath: String) -> PyResult { - let path = Path::new(&filepath); - let file = File::open(path) - .map_err(|e| PyValueError::new_err(format!("File open failed: {}", e)))?; - let mut decoder = GzDecoder::new(file); - let mut decompressed_data = String::new(); + /// Load a KmerCountTable from a compressed file using Niffler. 
+ pub fn load(filepath: &str) -> Result { + // Open the file for reading + let file = File::open(filepath)?; + + // Use Niffler to get a reader that detects the compression format + let reader = BufReader::new(file); + let (mut reader, _format) = niffler::get_reader(Box::new(reader))?; - decoder - .read_to_string(&mut decompressed_data) - .map_err(|e| PyValueError::new_err(format!("Decompression failed: {}", e)))?; + // Read the decompressed data into a string + let mut decompressed_data = String::new(); + reader.read_to_string(&mut decompressed_data)?; + // Deserialize the JSON string to a KmerCountTable let loaded_table: KmerCountTable = serde_json::from_str(&decompressed_data) - .map_err(|e| PyValueError::new_err(format!("Deserialization failed: {}", e)))?; + .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?; - // Check for version mismatch + // Check version compatibility and issue a warning if necessary if loaded_table.version != VERSION { - eprint!( - "Warning: Version mismatch: loaded version is {}, but current version is {}", + eprintln!( + "Version mismatch: loaded version is {}, but current version is {}", loaded_table.version, VERSION ); - io::stderr() - .flush() - .map_err(|e| PyIOError::new_err(e.to_string()))?; // Flush stderr } Ok(loaded_table)