Skip to content

Commit

Permalink
Use niffler for gzip read/write
Browse files Browse the repository at this point in the history
  • Loading branch information
Adamtaranto committed Sep 20, 2024
1 parent b167d16 commit 6cf462d
Showing 1 changed file with 38 additions and 37 deletions.
75 changes: 38 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
// Standard library imports
use std::collections::{HashMap, HashSet};
use std::collections::hash_map::IntoIter;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use std::io::{BufReader, BufWriter, Write};
//use std::path::Path;

// External crate imports
use anyhow::{anyhow, Result};
use log::debug;
use niffler::{self, Compression, Format};
use niffler::compression::Format;
use niffler::get_writer;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use pyo3::PyResult;
Expand Down Expand Up @@ -188,54 +188,55 @@ impl KmerCountTable {
}

// Serialize the KmerCountTable as a JSON string
pub fn serialize_json(&self) -> PyResult<String> {
serde_json::to_string(&self)
.map_err(|e| PyValueError::new_err(format!("Serialization failed: {}", e)))
pub fn serialize_json(&self) -> Result<String> {
serde_json::to_string(&self).map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
}

// Save the KmerCountTable to a Gzip-compressed JSON file
pub fn save(&self, filepath: String) -> PyResult<()> {
let serialized = self.serialize_json()?;
/// Save the KmerCountTable to a compressed file using Niffler.
pub fn save(&self, filepath: &str) -> PyResult<()> {
// Open the file for writing
let file = File::create(filepath).map_err(|e| PyIOError::new_err(e.to_string()))?;

// Create a gzip writer with niffler at compression level 1 (`Level::One`)
let writer = BufWriter::new(file);
let mut writer = get_writer(Box::new(writer), Format::Gzip, niffler::level::Level::One)
.map_err(|e| PyIOError::new_err(e.to_string()))?;

let path = Path::new(&filepath);
let file = File::create(path)
.map_err(|e| PyValueError::new_err(format!("File creation failed: {}", e)))?;
let mut encoder = GzEncoder::new(file, Compression::default());
// Serialize the KmerCountTable to JSON
let json_data = self.serialize_json()?;

encoder
.write_all(serialized.as_bytes())
.map_err(|e| PyValueError::new_err(format!("Failed to write to file: {}", e)))?;
encoder
.finish()
.map_err(|e| PyValueError::new_err(format!("Failed to finish compression: {}", e)))?;
// Write the serialized JSON to the compressed file
writer
.write_all(json_data.as_bytes())
.map_err(|e| PyIOError::new_err(e.to_string()))?;

Ok(())
}

#[staticmethod]
pub fn load(filepath: String) -> PyResult<Self> {
let path = Path::new(&filepath);
let file = File::open(path)
.map_err(|e| PyValueError::new_err(format!("File open failed: {}", e)))?;
let mut decoder = GzDecoder::new(file);
let mut decompressed_data = String::new();
/// Load a KmerCountTable from a compressed file using Niffler.
pub fn load(filepath: &str) -> Result<KmerCountTable> {
// Open the file for reading
let file = File::open(filepath)?;

// Use Niffler to get a reader that detects the compression format
let reader = BufReader::new(file);
let (mut reader, _format) = niffler::get_reader(Box::new(reader))?;

decoder
.read_to_string(&mut decompressed_data)
.map_err(|e| PyValueError::new_err(format!("Decompression failed: {}", e)))?;
// Read the decompressed data into a string
let mut decompressed_data = String::new();
reader.read_to_string(&mut decompressed_data)?;

// Deserialize the JSON string to a KmerCountTable
let loaded_table: KmerCountTable = serde_json::from_str(&decompressed_data)
.map_err(|e| PyValueError::new_err(format!("Deserialization failed: {}", e)))?;
.map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;

// Check for version mismatch
// Check version compatibility and issue a warning if necessary
if loaded_table.version != VERSION {
eprint!(
"Warning: Version mismatch: loaded version is {}, but current version is {}",
eprintln!(
"Version mismatch: loaded version is {}, but current version is {}",
loaded_table.version, VERSION
);
io::stderr()
.flush()
.map_err(|e| PyIOError::new_err(e.to_string()))?; // Flush stderr
}

Ok(loaded_table)
Expand Down

0 comments on commit 6cf462d

Please sign in to comment.