From 7b49c89db038be8cb552f60696d8f7507674a1f7 Mon Sep 17 00:00:00 2001 From: Adam Taranto Date: Mon, 23 Sep 2024 12:54:11 +1000 Subject: [PATCH] Add Serialisation Support (#31) * Add toml package as testing dependency. * Add version attribute to new KmerCountTable instances. * Test version attribute. * Track total bases consumed with attr consumed * Add tests for consumed attr * Add sum_counts attr to get total counts in table. * Add tests for sum_counts attr * Add __len__ method to get total unique kmer count. * Add tests for __len__ method. * Add __getitem__ and __setitem__ methods * Tests for __getitem__ and __setitem__ dunder methods. * Make KmerCountTable iterable * Add tests for __iter__ * Add serde and flate2 * Add serialisation support * Add serialisation tests * Add niffler * sort imports * Use niffler for gzip read/write * Style fixes by Ruff * bump serde * MRG: update serialization code (#49) * Do format updates on top of current commit rather than head. (#46) * update serde_json version * switch to using built-in temp_path * write two additional tests * Style fixes by Ruff * clean up docstrings viz cargo doc --document-private-items --------- Co-authored-by: Adam Taranto Co-authored-by: ctb --------- Co-authored-by: Adamtaranto Co-authored-by: C. 
Titus Brown Co-authored-by: ctb --- Cargo.lock | 115 ++++++++++++++++++++++-- Cargo.toml | 9 +- src/lib.rs | 82 ++++++++++++++--- src/python/tests/test_serialization.py | 116 +++++++++++++++++++++++++ 4 files changed, 301 insertions(+), 21 deletions(-) create mode 100644 src/python/tests/test_serialization.py diff --git a/Cargo.lock b/Cargo.lock index 78852b9..394a7bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,16 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b7e4c2464d97fe331d41de9d5db0def0a96f4d823b8b32a2efd503578988973" +[[package]] +name = "bgzip" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b64fd8980fb64af5951bc05de7772b598150a6f7eac42ec17f73e8489915f99b" +dependencies = [ + "flate2", + "thiserror", +] + [[package]] name = "binary-merge" version = "0.1.2" @@ -153,6 +163,27 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "camino" version = "1.1.9" @@ -168,6 +199,8 @@ version = "1.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -428,6 +461,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.70" @@ -449,6 +491,26 @@ version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" +[[package]] +name = "liblzma" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7c45fc6fcf5b527d3cf89c1dee8c327943984b0dc8bfcf6e100473b00969e63" +dependencies = [ + "liblzma-sys", +] + +[[package]] +name = "liblzma-sys" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63117d31458acdb7b406f6c60090aa8e1e7cd6e283f8ee02ce585ed68c53fe39" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "libm" version = "0.2.8" @@ -562,9 +624,13 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd625dd485c2d20bdb98d7ec364f798b256ac09997ef18b4274be2168f53a647" dependencies = [ + "bgzip", + "bzip2", "cfg-if", "flate2", + "liblzma", "thiserror", + "zstd", ] [[package]] @@ -660,7 +726,10 @@ dependencies = [ "anyhow", "env_logger", "log", + "niffler", "pyo3", + "serde", + "serde_json", "sourmash", ] @@ -686,6 +755,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "portable-atomic" version = "1.7.0" @@ -937,18 +1012,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.209" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", @@ -957,9 +1032,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.127" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", "memchr", @@ -1384,3 +1459,31 @@ dependencies = [ "quote", "syn 2.0.77", ] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 8d60611..aff50d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,8 +9,11 @@ name = "oxli" crate-type = ["cdylib"] [dependencies] -pyo3 = { version="0.22.3", features = ["extension-module", "anyhow"] } -sourmash = "0.15.1" anyhow = "1.0.89" -log = "0.4.22" env_logger = "0.11.5" 
+log = "0.4.22" +niffler = "2.6.0" +pyo3 = { version="0.22.3", features = ["extension-module", "anyhow"] } +serde = { version = "1.0.210", features = ["derive"] } +serde_json = "1.0.128" +sourmash = "0.15.1" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 823724a..66e1b9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,28 @@ // Standard library imports use std::collections::hash_map::IntoIter; use std::collections::{HashMap, HashSet}; +use std::fs::File; +use std::io::{BufReader, BufWriter, Write}; +//use std::path::Path; // External crate imports use anyhow::{anyhow, Result}; use log::debug; -use pyo3::exceptions::PyValueError; +use niffler::compression::Format; +use niffler::get_writer; +use pyo3::exceptions::{PyIOError, PyValueError}; use pyo3::prelude::*; +use pyo3::PyResult; +use serde::{Deserialize, Serialize}; use sourmash::encodings::HashFunctions; use sourmash::signature::SeqToHashes; -use pyo3::PyResult; -use std::fs::File; -use std::io::{BufWriter, Write}; - // Set version variable const VERSION: &str = env!("CARGO_PKG_VERSION"); #[pyclass] +#[derive(Serialize, Deserialize, Debug)] +/// Basic KmerCountTable struct, mapping hashes to counts. struct KmerCountTable { counts: HashMap, pub ksize: u8, @@ -26,6 +31,7 @@ struct KmerCountTable { } #[pymethods] +/// Methods on KmerCountTable. impl KmerCountTable { #[new] #[pyo3(signature = (ksize))] @@ -44,11 +50,11 @@ impl KmerCountTable { // TODO: Add function to get canonical kmer using hash key + /// Turn a k-mer into a hashval. fn hash_kmer(&self, kmer: String) -> Result { if kmer.len() as u8 != self.ksize { Err(anyhow!("wrong ksize")) } else { - // mut? let mut hashes = SeqToHashes::new( kmer.as_bytes(), self.ksize.into(), @@ -63,12 +69,14 @@ impl KmerCountTable { } } + /// Increment the count of a hashval by 1. pub fn count_hash(&mut self, hashval: u64) -> u64 { let count = self.counts.entry(hashval).or_insert(0); *count += 1; *count } + /// Increment the count of a k-mer by 1. 
pub fn count(&mut self, kmer: String) -> PyResult { if kmer.len() as u8 != self.ksize { Err(PyValueError::new_err( @@ -82,6 +90,7 @@ impl KmerCountTable { } } + /// Retrieve the count of a k-mer. pub fn get(&self, kmer: String) -> PyResult { if kmer.len() as u8 != self.ksize { Err(PyValueError::new_err( @@ -99,13 +108,13 @@ impl KmerCountTable { } } - // Get the count for a specific hash value directly + /// Get the count for a specific hash value directly pub fn get_hash(&self, hashval: u64) -> u64 { // Return the count for the hash value, or 0 if it does not exist *self.counts.get(&hashval).unwrap_or(&0) } - // Get counts for a list of hash keys and return an list of counts + /// Get counts for a list of hashvals and return a list of counts pub fn get_hash_array(&self, hash_keys: Vec) -> Vec { // Map each hash key to its count, defaulting to 0 if the key is not present hash_keys.iter().map(|&key| self.get_hash(key)).collect() @@ -183,11 +192,60 @@ impl KmerCountTable { Ok(to_remove.len() as u64) } - // TODO: Serialize the KmerCountTable instance to a JSON string. + /// Serialize the KmerCountTable as a JSON string + pub fn serialize_json(&self) -> Result { + serde_json::to_string(&self).map_err(|e| anyhow::anyhow!("Serialization error: {}", e)) + } + + /// Save the KmerCountTable to a compressed file using Niffler. 
+ pub fn save(&self, filepath: &str) -> PyResult<()> { + // Open the file for writing + let file = File::create(filepath).map_err(|e| PyIOError::new_err(e.to_string()))?; + + // Create a Gzipped writer with niffler, using compression level 1 (Level::One) + let writer = BufWriter::new(file); + let mut writer = get_writer(Box::new(writer), Format::Gzip, niffler::level::Level::One) + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + // Serialize the KmerCountTable to JSON + let json_data = self.serialize_json()?; + + // Write the serialized JSON to the compressed file + writer + .write_all(json_data.as_bytes()) + .map_err(|e| PyIOError::new_err(e.to_string()))?; + + Ok(()) + } + + #[staticmethod] + /// Load a KmerCountTable from a compressed file using Niffler. + pub fn load(filepath: &str) -> Result { + // Open the file for reading + let file = File::open(filepath)?; + + // Use Niffler to get a reader that detects the compression format + let reader = BufReader::new(file); + let (mut reader, _format) = niffler::get_reader(Box::new(reader))?; - // TODO: Compress JSON string with gzip and save to file + // Read the decompressed data into a string + let mut decompressed_data = String::new(); + reader.read_to_string(&mut decompressed_data)?; - // TODO: Static method to load KmerCountTable from serialized JSON. Yield new object. + // Deserialize the JSON string to a KmerCountTable + let loaded_table: KmerCountTable = serde_json::from_str(&decompressed_data) + .map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))?; + + // Check version compatibility and issue a warning if necessary + if loaded_table.version != VERSION { + eprintln!( + "Version mismatch: loaded version is {}, but current version is {}", + loaded_table.version, VERSION + ); + } + + Ok(loaded_table) + } /// Dump (hash,count) pairs, optional sorted by count or hash key. 
/// @@ -442,8 +500,8 @@ impl KmerCountTable { } } -// Iterator implementation for KmerCountTable #[pyclass] +/// Iterator implementation for KmerCountTable pub struct KmerCountTableIterator { inner: IntoIter, // Now we own the iterator } diff --git a/src/python/tests/test_serialization.py b/src/python/tests/test_serialization.py new file mode 100644 index 0000000..6000365 --- /dev/null +++ b/src/python/tests/test_serialization.py @@ -0,0 +1,116 @@ +import gzip +import json +import pytest + +from oxli import KmerCountTable +from test_attr import get_version_from_cargo_toml + +CURRENT_VERSION = get_version_from_cargo_toml() + + +@pytest.fixture +def sample_kmer_table(): + """Fixture that provides a sample KmerCountTable object.""" + table = KmerCountTable(ksize=4) + table.count("AAAA") + table.count("TTTT") + return table + + +def test_serialize_json(sample_kmer_table): + """ + Test case for the `serialize_json` function. + + This test verifies that the `serialize_json` function correctly serializes a + KmerCountTable object into a JSON string. + """ + # Serialize the KmerCountTable object to JSON + json_data = sample_kmer_table.serialize_json() + + # Convert back to dict to verify correctness + json_dict = json.loads(json_data) + + # Check that essential attributes exist + assert "counts" in json_dict, "Counts should be serialized." + assert json_dict["ksize"] == 4, "Ksize should be correctly serialized." + assert ( + sample_kmer_table.version == json_dict["version"] + ), "Version should be serialized." + + +def test_save_load_roundtrip(sample_kmer_table, tmp_path): + """ + Test the save and load functionality. + + This test saves a KmerCountTable object to a file, then loads it back and + verifies that the data in the loaded object matches the original. 
+ """ + temp_file = str(tmp_path / "save.json") + + # Save the sample KmerCountTable to a Gzip file + sample_kmer_table.save(temp_file) + + # Load the KmerCountTable from the file + loaded_table = KmerCountTable.load(temp_file) + + # Verify that the loaded data matches the original + assert loaded_table.get("AAAA") == sample_kmer_table.get( + "AAAA" + ), "Counts should be preserved after loading." + assert loaded_table.get("TTTT") == sample_kmer_table.get( + "TTTT" + ), "Counts for reverse complement should be preserved." + assert list(loaded_table) == list(sample_kmer_table), "All records in same order." + + +def test_version_warning_on_load_stderr(sample_kmer_table, tmp_path, capfd): + """ + Test that a warning is issued if the loaded object's version is different from the current Oxli version. + + Uses pytest's capfd fixture to capture stderr output. + """ + temp_file = str(tmp_path / "save.json") + + # Save the table to a file + sample_kmer_table.save(temp_file) + + # Mock the current version to simulate a version mismatch + mock_json = sample_kmer_table.serialize_json().replace(CURRENT_VERSION, "0.0.1") + with gzip.open(temp_file, "wt") as f: + json.dump(json.loads(mock_json), f) + + # Capture stderr output + loaded_table = KmerCountTable.load(temp_file) + captured = capfd.readouterr() + + # Check stderr for the version mismatch warning + assert "Version mismatch" in captured.err + assert ( + f"loaded version is 0.0.1, but current version is {CURRENT_VERSION}" + in captured.err + ) + + +def test_load_bad_json(tmp_path, capfd): + """ + Test that failure happens appropriately when trying to load a bad + JSON file. 
+ """ + temp_file = str(tmp_path / "bad.json") + + with open(temp_file, "wt") as fp: + fp.write("hello, world") + + with pytest.raises(RuntimeError, match="Deserialization error:"): + tb = KmerCountTable.load(temp_file) + + +def test_save_bad_path(sample_kmer_table, tmp_path, capfd): + """ + Test that failure happens appropriately when trying to save to a bad + location. + """ + temp_file = str(tmp_path / "noexist" / "save.json") + + with pytest.raises(OSError, match="No such file or directory"): + sample_kmer_table.save(temp_file)