diff --git a/Cargo.lock b/Cargo.lock index f174719..562485c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + [[package]] name = "approx" version = "0.5.1" @@ -562,6 +568,7 @@ dependencies = [ name = "oxli" version = "0.1.0" dependencies = [ + "anyhow", "pyo3", "sourmash", ] @@ -681,6 +688,7 @@ version = "0.19.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e681a6cfdc4adcc93b4d3cf993749a4552018ee0a9b65fc0ccfad74352c72a38" dependencies = [ + "anyhow", "cfg-if", "indoc", "libc", diff --git a/Cargo.toml b/Cargo.toml index a9bbc7c..159a6db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,5 +9,6 @@ name = "oxli" crate-type = ["cdylib"] [dependencies] -pyo3 = "0.19.0" +pyo3 = { version="0.19.0", features = ["extension-module", "anyhow"] } sourmash = "0.15.1" +anyhow = "1.0.86" diff --git a/src/lib.rs b/src/lib.rs index ea30045..c5ef467 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,60 +1,107 @@ use pyo3::prelude::*; +use pyo3::exceptions::PyValueError; +// use rayon::prelude::*; +use anyhow::{Result, Error, anyhow}; use std::collections::HashMap; -use sourmash::sketch::nodegraph::Nodegraph; +// use sourmash::sketch::nodegraph::Nodegraph; use sourmash::_hash_murmur; +use sourmash::signature::SeqToHashes; +use sourmash::encodings::HashFunctions; + #[pyclass] struct KmerCountTable { - counts: HashMap, + counts: HashMap, + pub ksize: u8, } #[pymethods] impl KmerCountTable { #[new] - pub fn new() -> Self { - Self { counts: HashMap::new() } + pub fn new(ksize: u8) -> Self { + Self { counts: HashMap::new(), ksize } } - pub fn count(&mut self, kmer: String) -> PyResult { - let hashval = _hash_murmur(kmer.as_bytes(), 42); + fn hash_kmer(&self, kmer: String) -> Result { + if kmer.len() as u8 != self.ksize { + 
Err(anyhow!("wrong ksize")) + } else { + // mut? + let mut hashes = SeqToHashes::new(kmer.as_bytes(), + self.ksize.into(), + false, + false, + HashFunctions::Murmur64Dna, + 42); + + let mut hashval = hashes.next().unwrap(); + Ok(hashval?) + } + } - let mut count: usize = 1; + + pub fn count_hash(&mut self, hashval: u64) -> u64 { + let mut count: u64 = 1; if self.counts.contains_key(&hashval) { count = *self.counts.get(&hashval).unwrap(); count = count + 1; } self.counts.insert(hashval, count); - Ok(count) + count } - pub fn get(&self, kmer: String) -> PyResult { - let hashval = _hash_murmur(kmer.as_bytes(), 42); + pub fn count(&mut self, kmer: String) -> PyResult { + if kmer.len() as u8 != self.ksize { + Err(PyValueError::new_err("kmer size does not match count table ksize")) + } else { + let hashval = _hash_murmur(kmer.as_bytes(), 42); + let count = self.count_hash(hashval); + Ok(count) + } + } - let count = match self.counts.get(&hashval) { - Some(count) => count, - None => &(0 as usize) - }; - Ok(*count) + pub fn get(&self, kmer: String) -> PyResult { + if kmer.len() as u8 != self.ksize { + Err(PyValueError::new_err("kmer size does not match count table ksize")) + } else { + let hashval = self.hash_kmer(kmer).unwrap(); + + let count = match self.counts.get(&hashval) { + Some(count) => count, + None => &0 + }; + Ok(*count) + } } -} -/// Formats the sum of two numbers as string. -#[pyfunction] -fn sum_as_string(a: String) -> PyResult { - let mut ng: Nodegraph = Nodegraph::with_tables(23, 6, 3); + // Consume this DNA string. Return number of k-mers consumed. 
+ pub fn consume(&mut self, seq: String) -> PyResult { + let hashes = SeqToHashes::new(seq.as_bytes(), + self.ksize.into(), + false, + false, + HashFunctions::Murmur64Dna, + 42); - let hashval = _hash_murmur(a.as_bytes(), 42); - ng.count(hashval); - Ok(ng.get(hashval)) + let mut n = 0; + for hash_value in hashes { + match hash_value { + Ok(0) => continue, + Ok(x) => { self.count_hash(x); () } + Err(err) => (), + } + n += 1; + } + + Ok(n) + } } -/// A Python module implemented in Rust. #[pymodule] fn oxli(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; m.add_class::()?; Ok(()) } diff --git a/src/python/tests/test_basic.py b/src/python/tests/test_basic.py index dc246dc..7c33f5b 100644 --- a/src/python/tests/test_basic.py +++ b/src/python/tests/test_basic.py @@ -1,9 +1,39 @@ +import pytest import oxli def test_simple(): - cg = oxli.KmerCountTable() + cg = oxli.KmerCountTable(4) kmer = "ATCG" assert cg.get(kmer) == 0 assert cg.count(kmer) == 1 assert cg.get(kmer) == 1 + + +def test_wrong_ksize(): + cg = oxli.KmerCountTable(3) + kmer = "ATCG" + + with pytest.raises(ValueError): + cg.count(kmer) + + with pytest.raises(ValueError): + cg.get(kmer) + + +def test_consume(): + cg = oxli.KmerCountTable(4) + kmer = "ATCG" + + assert cg.consume(kmer) == 1 + assert cg.get("ATCG") == 1 + + +def test_consume_2(): + cg = oxli.KmerCountTable(4) + seq = "ATCGG" + + assert cg.consume(seq) == 2 + assert cg.get("ATCG") == 1 + assert cg.get("TCGG") == 1 + assert cg.get("CCGA") == 1 # reverse complement!