Skip to content

Commit

Permalink
Add Serialisation Support (#31)
Browse files Browse the repository at this point in the history
* Add toml package as testing dependency.

* Add version attribute to new KmerCountTable instances.

* Test version attribute.

* Track total bases consumed with attr consumed

* Add tests for consumed attr

* Add sum_counts attr to get total counts in table.

* Add tests for sum_counts attr

* Add __len__ method to get total unique kmer count.

* Add tests for __len__ method.

* Add __getitem__ and __setitem__ methods

* Tests for __getitem__ and __setitem__ dunder methods.

* Make KmerCountTable iterable

* Add tests for __iter__

* Add serde and flate2

* Add serialisation support

* Add serialisation tests

* Add niffler

* sort imports

* Use niffler for gzip read/write

* Style fixes by Ruff

* bump serde

* MRG: update serialization code  (#49)

* Do format updates on top of current commit rather than head. (#46)

* update serde_json version

* switch to using built-in temp_path

* write two additional tests

* Style fixes by Ruff

* clean up docstrings viz cargo doc --document-private-items

---------

Co-authored-by: Adam Taranto <[email protected]>
Co-authored-by: ctb <[email protected]>

---------

Co-authored-by: Adamtaranto <[email protected]>
Co-authored-by: C. Titus Brown <[email protected]>
Co-authored-by: ctb <[email protected]>
  • Loading branch information
4 people authored Sep 23, 2024
1 parent bef0bbd commit 7b49c89
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 21 deletions.
115 changes: 109 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ name = "oxli"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version="0.22.3", features = ["extension-module", "anyhow"] }
sourmash = "0.15.1"
anyhow = "1.0.89"
log = "0.4.22"
env_logger = "0.11.5"
log = "0.4.22"
niffler = "2.6.0"
pyo3 = { version="0.22.3", features = ["extension-module", "anyhow"] }
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0.128"
sourmash = "0.15.1"
82 changes: 70 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
// Standard library imports
use std::collections::hash_map::IntoIter;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
//use std::path::Path;

// External crate imports
use anyhow::{anyhow, Result};
use log::debug;
use pyo3::exceptions::PyValueError;
use niffler::compression::Format;
use niffler::get_writer;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use pyo3::PyResult;
use serde::{Deserialize, Serialize};
use sourmash::encodings::HashFunctions;
use sourmash::signature::SeqToHashes;

use pyo3::PyResult;
use std::fs::File;
use std::io::{BufWriter, Write};

// Set version variable
const VERSION: &str = env!("CARGO_PKG_VERSION");

#[pyclass]
#[derive(Serialize, Deserialize, Debug)]
/// Basic KmerCountTable struct, mapping hashes to counts.
struct KmerCountTable {
counts: HashMap<u64, u64>,
pub ksize: u8,
Expand All @@ -26,6 +31,7 @@ struct KmerCountTable {
}

#[pymethods]
/// Methods on KmerCountTable.
impl KmerCountTable {
#[new]
#[pyo3(signature = (ksize))]
Expand All @@ -44,11 +50,11 @@ impl KmerCountTable {

// TODO: Add function to get canonical kmer using hash key

/// Turn a k-mer into a hashval.
fn hash_kmer(&self, kmer: String) -> Result<u64> {
if kmer.len() as u8 != self.ksize {
Err(anyhow!("wrong ksize"))
} else {
// mut?
let mut hashes = SeqToHashes::new(
kmer.as_bytes(),
self.ksize.into(),
Expand All @@ -63,12 +69,14 @@ impl KmerCountTable {
}
}

/// Increment the count of a hashval by 1.
pub fn count_hash(&mut self, hashval: u64) -> u64 {
let count = self.counts.entry(hashval).or_insert(0);
*count += 1;
*count
}

/// Increment the count of a k-mer by 1.
pub fn count(&mut self, kmer: String) -> PyResult<u64> {
if kmer.len() as u8 != self.ksize {
Err(PyValueError::new_err(
Expand All @@ -82,6 +90,7 @@ impl KmerCountTable {
}
}

/// Retrieve the count of a k-mer.
pub fn get(&self, kmer: String) -> PyResult<u64> {
if kmer.len() as u8 != self.ksize {
Err(PyValueError::new_err(
Expand All @@ -99,13 +108,13 @@ impl KmerCountTable {
}
}

// Get the count for a specific hash value directly
/// Get the count for a specific hash value directly
pub fn get_hash(&self, hashval: u64) -> u64 {
// Return the count for the hash value, or 0 if it does not exist
*self.counts.get(&hashval).unwrap_or(&0)
}

// Get counts for a list of hash keys and return an list of counts
/// Get counts for a list of hashvals and return a list of counts
pub fn get_hash_array(&self, hash_keys: Vec<u64>) -> Vec<u64> {
// Map each hash key to its count, defaulting to 0 if the key is not present
hash_keys.iter().map(|&key| self.get_hash(key)).collect()
Expand Down Expand Up @@ -183,11 +192,60 @@ impl KmerCountTable {
Ok(to_remove.len() as u64)
}

// TODO: Serialize the KmerCountTable instance to a JSON string.
/// Serialize the KmerCountTable as a JSON string
pub fn serialize_json(&self) -> Result<String> {
serde_json::to_string(&self).map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
}

/// Save the KmerCountTable to a compressed file using Niffler.
pub fn save(&self, filepath: &str) -> PyResult<()> {
// Open the file for writing
let file = File::create(filepath).map_err(|e| PyIOError::new_err(e.to_string()))?;

// Create a Gzipped writer with niffler, using the default compression level
let writer = BufWriter::new(file);
let mut writer = get_writer(Box::new(writer), Format::Gzip, niffler::level::Level::One)
.map_err(|e| PyIOError::new_err(e.to_string()))?;

// Serialize the KmerCountTable to JSON
let json_data = self.serialize_json()?;

// Write the serialized JSON to the compressed file
writer
.write_all(json_data.as_bytes())
.map_err(|e| PyIOError::new_err(e.to_string()))?;

Ok(())
}

#[staticmethod]
/// Load a KmerCountTable from a compressed file using Niffler.
pub fn load(filepath: &str) -> Result<KmerCountTable> {
// Open the file for reading
let file = File::open(filepath)?;

// Use Niffler to get a reader that detects the compression format
let reader = BufReader::new(file);
let (mut reader, _format) = niffler::get_reader(Box::new(reader))?;

// TODO: Compress JSON string with gzip and save to file
// Read the decompressed data into a string
let mut decompressed_data = String::new();
reader.read_to_string(&mut decompressed_data)?;

// TODO: Static method to load KmerCountTable from serialized JSON. Yield new object.
// Deserialize the JSON string to a KmerCountTable
let loaded_table: KmerCountTable = serde_json::from_str(&decompressed_data)
.map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?;

// Check version compatibility and issue a warning if necessary
if loaded_table.version != VERSION {
eprintln!(
"Version mismatch: loaded version is {}, but current version is {}",
loaded_table.version, VERSION
);
}

Ok(loaded_table)
}

/// Dump (hash,count) pairs, optional sorted by count or hash key.
///
Expand Down Expand Up @@ -442,8 +500,8 @@ impl KmerCountTable {
}
}

// Iterator implementation for KmerCountTable
#[pyclass]
/// Iterator implementation for KmerCountTable
pub struct KmerCountTableIterator {
inner: IntoIter<u64, u64>, // Now we own the iterator
}
Expand Down
Loading

0 comments on commit 7b49c89

Please sign in to comment.