Skip to content

Commit

Permalink
Add serialisation support
Browse files Browse the repository at this point in the history
  • Loading branch information
Adamtaranto committed Sep 16, 2024
1 parent 06b56fa commit 3c1faed
Showing 1 changed file with 62 additions and 4 deletions.
66 changes: 62 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
// Standard library imports
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;

// External crate imports
use anyhow::{anyhow, Result};
use log::debug;
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use sourmash::encodings::HashFunctions;
use sourmash::signature::SeqToHashes;

use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use serde::{Deserialize, Serialize};
use std::io;

// Set version variable
const VERSION: &str = env!("CARGO_PKG_VERSION");

#[pyclass]
#[derive(Serialize, Deserialize, Debug)]
struct KmerCountTable {
counts: HashMap<u64, u64>,
pub ksize: u8,
Expand Down Expand Up @@ -117,11 +127,59 @@ impl KmerCountTable {

// TODO: Add "maxcut". Remove counts above a maximum cutoff.

// TODO: Serialize the KmerCountTable instance to a JSON string.
// Render this KmerCountTable as a JSON string using serde_json.
// On failure the serde_json error text is wrapped in a PyValueError
// prefixed with "Serialization failed: ".
pub fn serialize_json(&self) -> PyResult<String> {
    match serde_json::to_string(&self) {
        Ok(json) => Ok(json),
        Err(err) => Err(PyValueError::new_err(format!(
            "Serialization failed: {}",
            err
        ))),
    }
}

// Write the KmerCountTable to `filepath` as Gzip-compressed JSON.
// Every failure (serialization, file creation, write, finishing the
// gzip stream) is reported to the caller as a PyValueError.
pub fn save(&self, filepath: String) -> PyResult<()> {
    // Produce the JSON text before touching the filesystem, so a
    // serialization error returns before any file is created.
    let json = self.serialize_json()?;

    let out_file = match File::create(Path::new(&filepath)) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(PyValueError::new_err(format!(
                "File creation failed: {}",
                e
            )))
        }
    };

    // Stream the JSON bytes through a gzip encoder at the default level.
    let mut gz = GzEncoder::new(out_file, Compression::default());
    if let Err(e) = gz.write_all(json.as_bytes()) {
        return Err(PyValueError::new_err(format!(
            "Failed to write to file: {}",
            e
        )));
    }

    // finish() consumes the encoder; its error, if any, is surfaced
    // rather than being lost when the encoder is dropped.
    if let Err(e) = gz.finish() {
        return Err(PyValueError::new_err(format!(
            "Failed to finish compression: {}",
            e
        )));
    }

    Ok(())
}

// Load a KmerCountTable from a Gzip-compressed JSON file at `filepath`
// and return it as a new object.
//
// Errors:
//   - PyValueError if the file cannot be opened, decompressed, or
//     deserialized into a KmerCountTable.
//   - PyIOError if flushing stderr fails after printing the version
//     mismatch warning.
//
// A version mismatch between the stored table and the running crate is
// not an error: a warning line is written to stderr and the table is
// still returned.
#[staticmethod]
pub fn load(filepath: String) -> PyResult<Self> {
    let path = Path::new(&filepath);
    let file = File::open(path)
        .map_err(|e| PyValueError::new_err(format!("File open failed: {}", e)))?;
    let mut decoder = GzDecoder::new(file);
    let mut decompressed_data = String::new();

    // Decompress the whole payload into memory as UTF-8 text.
    decoder
        .read_to_string(&mut decompressed_data)
        .map_err(|e| PyValueError::new_err(format!("Decompression failed: {}", e)))?;

    let loaded_table: KmerCountTable = serde_json::from_str(&decompressed_data)
        .map_err(|e| PyValueError::new_err(format!("Deserialization failed: {}", e)))?;

    // Check for version mismatch
    if loaded_table.version != VERSION {
        // eprintln! (not eprint!) so the warning ends with a newline and
        // does not run into whatever is written to stderr next.
        eprintln!(
            "Warning: Version mismatch: loaded version is {}, but current version is {}",
            loaded_table.version, VERSION
        );
        io::stderr()
            .flush()
            .map_err(|e| PyIOError::new_err(e.to_string()))?; // Flush stderr
    }

    Ok(loaded_table)
}

// TODO: Add method "dump"
// Output tab delimited kmer:count pairs
Expand Down

0 comments on commit 3c1faed

Please sign in to comment.