Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: update serialization code #49

Merged
merged 7 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}
ref: ${{ github.sha }}
- uses: chartboost/ruff-action@v1
with:
src: './src/python'
Expand Down
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ log = "0.4.22"
niffler = "2.6.0"
pyo3 = { version="0.22.3", features = ["extension-module", "anyhow"] }
serde = { version = "1.0.210", features = ["derive"] }
serde_json = "1.0"
serde_json = "1.0.128"
sourmash = "0.15.1"
15 changes: 10 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const VERSION: &str = env!("CARGO_PKG_VERSION");

#[pyclass]
#[derive(Serialize, Deserialize, Debug)]
/// Basic KmerCountTable struct, mapping hashes to counts.
struct KmerCountTable {
counts: HashMap<u64, u64>,
pub ksize: u8,
Expand All @@ -30,6 +31,7 @@ struct KmerCountTable {
}

#[pymethods]
/// Methods on KmerCountTable.
impl KmerCountTable {
#[new]
#[pyo3(signature = (ksize))]
Expand All @@ -48,11 +50,11 @@ impl KmerCountTable {

// TODO: Add function to get canonical kmer using hash key

/// Turn a k-mer into a hashval.
fn hash_kmer(&self, kmer: String) -> Result<u64> {
if kmer.len() as u8 != self.ksize {
Err(anyhow!("wrong ksize"))
} else {
// mut?
let mut hashes = SeqToHashes::new(
kmer.as_bytes(),
self.ksize.into(),
Expand All @@ -67,12 +69,14 @@ impl KmerCountTable {
}
}

/// Increment the count of a hashval by 1.
pub fn count_hash(&mut self, hashval: u64) -> u64 {
let count = self.counts.entry(hashval).or_insert(0);
*count += 1;
*count
}

/// Increment the count of a k-mer by 1.
pub fn count(&mut self, kmer: String) -> PyResult<u64> {
if kmer.len() as u8 != self.ksize {
Err(PyValueError::new_err(
Expand All @@ -86,6 +90,7 @@ impl KmerCountTable {
}
}

/// Retrieve the count of a k-mer.
pub fn get(&self, kmer: String) -> PyResult<u64> {
if kmer.len() as u8 != self.ksize {
Err(PyValueError::new_err(
Expand All @@ -103,13 +108,13 @@ impl KmerCountTable {
}
}

// Get the count for a specific hash value directly
/// Get the count for a specific hash value directly
pub fn get_hash(&self, hashval: u64) -> u64 {
// Return the count for the hash value, or 0 if it does not exist
*self.counts.get(&hashval).unwrap_or(&0)
}

// Get counts for a list of hash keys and return an list of counts
/// Get counts for a list of hashvals and return a list of counts
pub fn get_hash_array(&self, hash_keys: Vec<u64>) -> Vec<u64> {
// Map each hash key to its count, defaulting to 0 if the key is not present
hash_keys.iter().map(|&key| self.get_hash(key)).collect()
Expand Down Expand Up @@ -187,7 +192,7 @@ impl KmerCountTable {
Ok(to_remove.len() as u64)
}

// Serialize the KmerCountTable as a JSON string
/// Serialize the KmerCountTable as a JSON string
pub fn serialize_json(&self) -> Result<String> {
serde_json::to_string(&self).map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
}
Expand Down Expand Up @@ -495,8 +500,8 @@ impl KmerCountTable {
}
}

// Iterator implementation for KmerCountTable
#[pyclass]
/// Iterator implementation for KmerCountTable
pub struct KmerCountTableIterator {
inner: IntoIter<u64, u64>, // Now we own the iterator
}
Expand Down
44 changes: 31 additions & 13 deletions src/python/tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import gzip
import json
import pytest
import tempfile

from oxli import KmerCountTable
from os import remove
from test_attr import get_version_from_cargo_toml

CURRENT_VERSION = get_version_from_cargo_toml()
Expand All @@ -19,15 +17,6 @@ def sample_kmer_table():
return table


@pytest.fixture
def temp_file():
"""Fixture that provides a temporary file path for testing."""
with tempfile.NamedTemporaryFile(delete=False, suffix=".json.gz") as temp:
yield temp.name
# Remove the file after the test is done
remove(temp.name)


def test_serialize_json(sample_kmer_table):
"""
Test case for the `serialize_json` function.
Expand All @@ -49,13 +38,15 @@ def test_serialize_json(sample_kmer_table):
), "Version should be serialized."


def test_save_load_roundtrip(sample_kmer_table, temp_file):
def test_save_load_roundtrip(sample_kmer_table, tmp_path):
"""
Test the save and load functionality.

This test saves a KmerCountTable object to a file, then loads it back and
verifies that the data in the loaded object matches the original.
"""
temp_file = str(tmp_path / "save.json")

# Save the sample KmerCountTable to a Gzip file
sample_kmer_table.save(temp_file)

Expand All @@ -72,12 +63,14 @@ def test_save_load_roundtrip(sample_kmer_table, temp_file):
assert list(loaded_table) == list(sample_kmer_table), "All records in same order."


def test_version_warning_on_load_stderr(sample_kmer_table, temp_file, capfd):
def test_version_warning_on_load_stderr(sample_kmer_table, tmp_path, capfd):
"""
Test that a warning is issued if the loaded object's version is different from the current Oxli version.

Uses pytest's capsys fixture to capture stderr output.
"""
temp_file = str(tmp_path / "save.json")

# Save the table to a file
sample_kmer_table.save(temp_file)

Expand All @@ -96,3 +89,28 @@ def test_version_warning_on_load_stderr(sample_kmer_table, temp_file, capfd):
f"loaded version is 0.0.1, but current version is {CURRENT_VERSION}"
in captured.err
)


def test_load_bad_json(tmp_path, capfd):
"""
Test that failure happens appropriately when trying to load a bad
JSON file.
"""
temp_file = str(tmp_path / "bad.json")

with open(temp_file, "wt") as fp:
fp.write("hello, world")

with pytest.raises(RuntimeError, match="Deserialization error:"):
tb = KmerCountTable.load(temp_file)


def test_save_bad_path(sample_kmer_table, tmp_path, capfd):
"""
Test that failure happens appropriately when trying to save to a bad
location.
"""
temp_file = str(tmp_path / "noexist" / "save.json")

with pytest.raises(OSError, match="No such file or directory"):
sample_kmer_table.save(temp_file)
Loading