diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 1e6feff..36cdaf0 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -9,7 +9,7 @@ jobs: steps: - uses: actions/checkout@v4 with: - ref: ${{ github.head_ref }} + ref: ${{ github.sha }} - uses: chartboost/ruff-action@v1 with: src: './src/python' diff --git a/Cargo.lock b/Cargo.lock index 02c7be0..394a7bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1032,9 +1032,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.127" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8043c06d9f82bd7271361ed64f415fe5e12a77fdb52e573e7f06a516dea329ad" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 69d191f..aff50d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,5 +15,5 @@ log = "0.4.22" niffler = "2.6.0" pyo3 = { version="0.22.3", features = ["extension-module", "anyhow"] } serde = { version = "1.0.210", features = ["derive"] } -serde_json = "1.0" +serde_json = "1.0.128" sourmash = "0.15.1" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 1f74784..20692ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); #[pyclass] #[derive(Serialize, Deserialize, Debug)] +/// Basic KmerCountTable struct, mapping hashes to counts. struct KmerCountTable { counts: HashMap, pub ksize: u8, @@ -30,6 +31,7 @@ struct KmerCountTable { } #[pymethods] +/// Methods on KmerCountTable. impl KmerCountTable { #[new] #[pyo3(signature = (ksize))] @@ -48,11 +50,11 @@ impl KmerCountTable { // TODO: Add function to get canonical kmer using hash key + /// Turn a k-mer into a hashval. fn hash_kmer(&self, kmer: String) -> Result { if kmer.len() as u8 != self.ksize { Err(anyhow!("wrong ksize")) } else { - // mut? let mut hashes = SeqToHashes::new( kmer.as_bytes(), self.ksize.into(), @@ -67,12 +69,14 @@ impl KmerCountTable { } } + /// Increment the count of a hashval by 1. pub fn count_hash(&mut self, hashval: u64) -> u64 { let count = self.counts.entry(hashval).or_insert(0); *count += 1; *count } + /// Increment the count of a k-mer by 1. pub fn count(&mut self, kmer: String) -> PyResult { if kmer.len() as u8 != self.ksize { Err(PyValueError::new_err( @@ -86,6 +90,7 @@ impl KmerCountTable { } } + /// Retrieve the count of a k-mer. pub fn get(&self, kmer: String) -> PyResult { if kmer.len() as u8 != self.ksize { Err(PyValueError::new_err( @@ -103,13 +108,13 @@ impl KmerCountTable { } } - // Get the count for a specific hash value directly + /// Get the count for a specific hash value directly pub fn get_hash(&self, hashval: u64) -> u64 { // Return the count for the hash value, or 0 if it does not exist *self.counts.get(&hashval).unwrap_or(&0) } - // Get counts for a list of hash keys and return an list of counts + /// Get counts for a list of hashvals and return a list of counts pub fn get_hash_array(&self, hash_keys: Vec) -> Vec { // Map each hash key to its count, defaulting to 0 if the key is not present hash_keys.iter().map(|&key| self.get_hash(key)).collect() @@ -187,7 +192,7 @@ impl KmerCountTable { Ok(to_remove.len() as u64) } - // Serialize the KmerCountTable as a JSON string + /// Serialize the KmerCountTable as a JSON string pub fn serialize_json(&self) -> Result { serde_json::to_string(&self).map_err(|e| anyhow::anyhow!("Serialization error: {}", e)) } @@ -495,8 +500,8 @@ impl KmerCountTable { } } -// Iterator implementation for KmerCountTable #[pyclass] +/// Iterator implementation for KmerCountTable pub struct KmerCountTableIterator { inner: IntoIter, // Now we own the iterator } diff --git a/src/python/tests/test_serialization.py b/src/python/tests/test_serialization.py index 9772917..6000365 100644 --- a/src/python/tests/test_serialization.py +++ b/src/python/tests/test_serialization.py @@ -1,10 +1,8 @@ import gzip import json import pytest -import tempfile from oxli import KmerCountTable -from os import remove from test_attr import get_version_from_cargo_toml CURRENT_VERSION = get_version_from_cargo_toml() @@ -19,15 +17,6 @@ def sample_kmer_table(): return table -@pytest.fixture -def temp_file(): - """Fixture that provides a temporary file path for testing.""" - with tempfile.NamedTemporaryFile(delete=False, suffix=".json.gz") as temp: - yield temp.name - # Remove the file after the test is done - remove(temp.name) - - def test_serialize_json(sample_kmer_table): """ Test case for the `serialize_json` function. @@ -49,13 +38,15 @@ def test_serialize_json(sample_kmer_table): ), "Version should be serialized." -def test_save_load_roundtrip(sample_kmer_table, temp_file): +def test_save_load_roundtrip(sample_kmer_table, tmp_path): """ Test the save and load functionality. This test saves a KmerCountTable object to a file, then loads it back and verifies that the data in the loaded object matches the original. """ + temp_file = str(tmp_path / "save.json") + # Save the sample KmerCountTable to a Gzip file sample_kmer_table.save(temp_file) @@ -72,12 +63,14 @@ def test_save_load_roundtrip(sample_kmer_table, temp_file): assert list(loaded_table) == list(sample_kmer_table), "All records in same order." -def test_version_warning_on_load_stderr(sample_kmer_table, temp_file, capfd): +def test_version_warning_on_load_stderr(sample_kmer_table, tmp_path, capfd): """ Test that a warning is issued if the loaded object's version is different from the current Oxli version. Uses pytest's capsys fixture to capture stderr output. """ + temp_file = str(tmp_path / "save.json") + # Save the table to a file sample_kmer_table.save(temp_file) @@ -96,3 +89,28 @@ def test_version_warning_on_load_stderr(sample_kmer_table, temp_file, capfd): f"loaded version is 0.0.1, but current version is {CURRENT_VERSION}" in captured.err ) + + +def test_load_bad_json(tmp_path, capfd): + """ + Test that failure happens appropriately when trying to load a bad + JSON file. + """ + temp_file = str(tmp_path / "bad.json") + + with open(temp_file, "wt") as fp: + fp.write("hello, world") + + with pytest.raises(RuntimeError, match="Deserialization error:"): + tb = KmerCountTable.load(temp_file) + + +def test_save_bad_path(sample_kmer_table, tmp_path, capfd): + """ + Test that failure happens appropriately when trying to save to a bad + location. + """ + temp_file = str(tmp_path / "noexist" / "save.json") + + with pytest.raises(OSError, match="No such file or directory"): + sample_kmer_table.save(temp_file)