Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Serialisation Support #31

Merged
merged 25 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9d6c130
Add toml package as testing dependency.
Adamtaranto Sep 14, 2024
50f1ba9
Add version attribute to new KmerCountTable instances.
Adamtaranto Sep 14, 2024
efea9a8
Test version attribute.
Adamtaranto Sep 14, 2024
94c9126
Track total bases consumed with attr consumed
Adamtaranto Sep 14, 2024
9b943df
Add tests for consumed attr
Adamtaranto Sep 14, 2024
a7c4bf0
Add sum_counts attr to get total counts in table.
Adamtaranto Sep 14, 2024
824a5b9
Add tests for sum_counts attr
Adamtaranto Sep 14, 2024
8ba5166
Add __len__ method to get total unique kmer count.
Adamtaranto Sep 14, 2024
06e148b
Add tests for __len__ method.
Adamtaranto Sep 14, 2024
c697fa0
Add __getitem__ and __setitem__ methods
Adamtaranto Sep 14, 2024
5163e5f
Tests for __getitem__ and __setitem__ dunder methods.
Adamtaranto Sep 14, 2024
aeef6e7
Make KmerCountTable iterable
Adamtaranto Sep 14, 2024
9d95e78
Add tests for __iter__
Adamtaranto Sep 14, 2024
06b56fa
Add serde and flate2
Adamtaranto Sep 16, 2024
3c1faed
Add serialisation support
Adamtaranto Sep 16, 2024
91dba9c
Add serialisation tests
Adamtaranto Sep 16, 2024
1c3cfec
Merge branch 'main' into dev_serialisation
Adamtaranto Sep 18, 2024
626d83f
Merge branch 'main' into dev_serialisation
Adamtaranto Sep 20, 2024
496a0cb
Add niffler
Adamtaranto Sep 20, 2024
b167d16
sort imports
Adamtaranto Sep 20, 2024
6cf462d
Use niffler for gzip read/write
Adamtaranto Sep 20, 2024
9642592
Style fixes by Ruff
Adamtaranto Sep 20, 2024
fd3b202
bump serde
Adamtaranto Sep 20, 2024
4e97c7e
MRG: update serialization code (#49)
ctb Sep 23, 2024
e094186
Merge branch 'main' into dev_serialisation
Adamtaranto Sep 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@ sourmash = "0.15.1"
anyhow = "1.0.89"
log = "0.4.22"
env_logger = "0.11.5"

# For JSON serialization/deserialization
serde = { version = "1.0", features = ["derive"] }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why version 1? latest is 1.0.210. (presumably dependabot will upgrade, just curious.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No good reason, just inherited from example code.

serde_json = "1.0"

# For Gzip compression/decompression
flate2 = "1.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggest niffler, which allows sniffing/auto-determination of file formats. but we can backport that in if we need.

(I'm using it over in #10)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will look into niffler 👍

Only ever expecting gzipped JSON as input.

66 changes: 62 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
// Standard library imports
use std::collections::hash_map::IntoIter;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;

// External crate imports
use anyhow::{anyhow, Result};
use log::debug;
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use sourmash::encodings::HashFunctions;
use sourmash::signature::SeqToHashes;

use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use flate2::Compression;
use serde::{Deserialize, Serialize};
use std::io;

// Set version variable
const VERSION: &str = env!("CARGO_PKG_VERSION");

#[pyclass]
#[derive(Serialize, Deserialize, Debug)]
struct KmerCountTable {
counts: HashMap<u64, u64>,
pub ksize: u8,
Expand Down Expand Up @@ -179,11 +189,59 @@ impl KmerCountTable {
Ok(to_remove.len() as u64)
}

// TODO: Serialize the KmerCountTable instance to a JSON string.
// Serialize the KmerCountTable as a JSON string
pub fn serialize_json(&self) -> PyResult<String> {
serde_json::to_string(&self)
.map_err(|e| PyValueError::new_err(format!("Serialization failed: {}", e)))
}

// Save the KmerCountTable to a Gzip-compressed JSON file
pub fn save(&self, filepath: String) -> PyResult<()> {
let serialized = self.serialize_json()?;

let path = Path::new(&filepath);
let file = File::create(path)
.map_err(|e| PyValueError::new_err(format!("File creation failed: {}", e)))?;
let mut encoder = GzEncoder::new(file, Compression::default());

encoder
.write_all(serialized.as_bytes())
.map_err(|e| PyValueError::new_err(format!("Failed to write to file: {}", e)))?;
encoder
.finish()
.map_err(|e| PyValueError::new_err(format!("Failed to finish compression: {}", e)))?;

Ok(())
}

#[staticmethod]
pub fn load(filepath: String) -> PyResult<Self> {
let path = Path::new(&filepath);
let file = File::open(path)
.map_err(|e| PyValueError::new_err(format!("File open failed: {}", e)))?;
let mut decoder = GzDecoder::new(file);
let mut decompressed_data = String::new();

// TODO: Compress JSON string with gzip and save to file
decoder
.read_to_string(&mut decompressed_data)
.map_err(|e| PyValueError::new_err(format!("Decompression failed: {}", e)))?;

// TODO: Static method to load KmerCountTable from serialized JSON. Yield new object.
let loaded_table: KmerCountTable = serde_json::from_str(&decompressed_data)
.map_err(|e| PyValueError::new_err(format!("Deserialization failed: {}", e)))?;

// Check for version mismatch
if loaded_table.version != VERSION {
eprint!(
"Warning: Version mismatch: loaded version is {}, but current version is {}",
loaded_table.version, VERSION
);
io::stderr()
.flush()
.map_err(|e| PyIOError::new_err(e.to_string()))?; // Flush stderr
}

Ok(loaded_table)
}

// TODO: Add method "dump"
// Output tab delimited kmer:count pairs
Expand Down
84 changes: 84 additions & 0 deletions src/python/tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import gzip
import json
import tempfile
import pytest

from oxli import KmerCountTable
from os import remove
from test_attr import get_version_from_cargo_toml

CURRENT_VERSION = get_version_from_cargo_toml()

@pytest.fixture
def sample_kmer_table():
"""Fixture that provides a sample KmerCountTable object."""
table = KmerCountTable(ksize=4)
table.count("AAAA")
table.count("TTTT")
return table

@pytest.fixture
def temp_file():
"""Fixture that provides a temporary file path for testing."""
with tempfile.NamedTemporaryFile(delete=False, suffix='.json.gz') as temp:
yield temp.name
# Remove the file after the test is done
remove(temp.name)

def test_serialize_json(sample_kmer_table):
"""
Test case for the `serialize_json` function.

This test verifies that the `serialize_json` function correctly serializes a
KmerCountTable object into a JSON string.
"""
# Serialize the KmerCountTable object to JSON
json_data = sample_kmer_table.serialize_json()

# Convert back to dict to verify correctness
json_dict = json.loads(json_data)

# Check that essential attributes exist
assert "counts" in json_dict, "Counts should be serialized."
assert json_dict["ksize"] == 4, "Ksize should be correctly serialized."
assert sample_kmer_table.version == json_dict["version"], "Version should be serialized."

def test_save_load_roundtrip(sample_kmer_table, temp_file):
"""
Test the save and load functionality.

This test saves a KmerCountTable object to a file, then loads it back and
verifies that the data in the loaded object matches the original.
"""
# Save the sample KmerCountTable to a Gzip file
sample_kmer_table.save(temp_file)

# Load the KmerCountTable from the file
loaded_table = KmerCountTable.load(temp_file)

# Verify that the loaded data matches the original
assert loaded_table.get("AAAA") == sample_kmer_table.get("AAAA"), "Counts should be preserved after loading."
assert loaded_table.get("TTTT") == sample_kmer_table.get("TTTT"), "Counts for reverse complement should be preserved."
assert list(loaded_table) == list(sample_kmer_table), "All records in same order."

def test_version_warning_on_load_stderr(sample_kmer_table, temp_file, capfd):
"""
Test that a warning is issued if the loaded object's version is different from the current Oxli version.

Uses pytest's capsys fixture to capture stderr output.
"""
# Save the table to a file
sample_kmer_table.save(temp_file)

# Mock the current version to simulate a version mismatch
mock_json = sample_kmer_table.serialize_json().replace(CURRENT_VERSION, "0.0.1")
with gzip.open(temp_file, 'wt') as f:
json.dump(json.loads(mock_json), f)

# Capture stderr output
loaded_table = KmerCountTable.load(temp_file)
captured = capfd.readouterr()

# Check stderr for the version mismatch warning
assert "Version mismatch" in captured.err
assert f"loaded version is 0.0.1, but current version is {CURRENT_VERSION}" in captured.err
Loading