Skip to content

Commit

Permalink
Merge pull request #249 from jeromekelleher/test_metadata
Browse files Browse the repository at this point in the history
Test metadata
  • Loading branch information
jeromekelleher authored Aug 29, 2024
2 parents 62458c5 + d0c3a74 commit 0056a19
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 53 deletions.
5 changes: 0 additions & 5 deletions sc2ts/alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ def append(self, alignments, show_progress=False):
self._flush(chunk)
bar.close()

    def __contains__(self, key):
        """Return True if *key* (a strain identifier) is stored in the DB.

        The key is encoded to bytes before the LMDB lookup, mirroring
        ``__getitem__`` below.
        """
        with self.env.begin() as txn:
            val = txn.get(key.encode())
        return val is not None

def __getitem__(self, key):
with self.env.begin() as txn:
val = txn.get(key.encode())
Expand Down
44 changes: 29 additions & 15 deletions sc2ts/metadata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import logging
import sqlite3
import pathlib
Expand All @@ -13,14 +14,27 @@ def dict_factory(cursor, row):
return {key: value for key, value in zip(col_names, row)}


class MetadataDb:
class MetadataDb(collections.abc.Mapping):
def __init__(self, path):
uri = f"file:{path}"
uri += "?mode=ro"
self.uri = uri
self.conn = sqlite3.connect(uri, uri=True)
self.conn.row_factory = dict_factory

@staticmethod
def import_csv(csv_path, db_path):
df = pd.read_csv(csv_path, sep="\t")
db_path = pathlib.Path(db_path)
if db_path.exists():
db_path.unlink()
with sqlite3.connect(db_path) as conn:
df.to_sql("samples", conn, index=False)
conn.execute(
"CREATE UNIQUE INDEX [ix_samples_strain] on 'samples' ([strain]);"
)
conn.execute("CREATE INDEX [ix_samples_date] on 'samples' ([date]);")

def __enter__(self):
return self

Expand All @@ -36,23 +50,23 @@ def __len__(self):
row = self.conn.execute(sql).fetchone()
return row["COUNT(*)"]

    def __getitem__(self, key):
        """Return the metadata row for strain *key* as a dict.

        Rows come back as dicts via the connection's ``dict_factory`` row
        factory. Raises ``KeyError`` when the strain is absent, which also
        gives the ``Mapping`` mixin a correct ``in`` operator.
        """
        sql = "SELECT * FROM samples WHERE strain==?"
        with self.conn:
            result = self.conn.execute(sql, [key]).fetchone()
            if result is None:
                raise KeyError(f"strain {key} not in DB")
            return result

    def __iter__(self):
        """Yield every strain identifier in the DB (the Mapping keys)."""
        sql = "SELECT strain FROM samples"
        with self.conn:
            for result in self.conn.execute(sql):
                yield result["strain"]

    def close(self):
        """Close the underlying read-only SQLite connection."""
        self.conn.close()

@staticmethod
def import_csv(csv_path, db_path):
df = pd.read_csv(
csv_path,
sep="\t",
)
db_path = pathlib.Path(db_path)
if db_path.exists():
db_path.unlink()
with sqlite3.connect(db_path) as conn:
df.to_sql("samples", conn, index=False)
conn.execute("CREATE INDEX [ix_samples_strain] on 'samples' ([strain]);")
conn.execute("CREATE INDEX [ix_samples_date] on 'samples' ([date]);")

def get(self, date):
sql = "SELECT * FROM samples WHERE date==?"
with self.conn:
Expand Down
43 changes: 43 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pathlib
import shutil
import gzip

import pytest

import sc2ts


@pytest.fixture
def data_cache():
    """Return the shared tests/data/cache directory, creating it if needed."""
    cache_path = pathlib.Path("tests/data/cache")
    # mkdir with exist_ok=True avoids the race between an exists() check and
    # the creation (e.g. under pytest-xdist), and parents=True makes the
    # fixture robust if the parent directory is ever missing.
    cache_path.mkdir(parents=True, exist_ok=True)
    return cache_path


@pytest.fixture
def alignments_fasta(data_cache):
    """Decompress the gzipped test alignments into the cache, at most once."""
    fasta_path = data_cache / "alignments.fasta"
    if not fasta_path.exists():
        with gzip.open("tests/data/alignments.fasta.gz") as gz, open(
            fasta_path, "wb"
        ) as out:
            shutil.copyfileobj(gz, out)
    return fasta_path


@pytest.fixture
def alignments_store(data_cache, alignments_fasta):
    """Build (once) and open the alignment store over the test FASTA data."""
    db_path = data_cache / "alignments.db"
    if not db_path.exists():
        with sc2ts.AlignmentStore(db_path, "a") as store:
            reader = sc2ts.core.FastaReader(alignments_fasta)
            store.append(reader, show_progress=False)
    return sc2ts.AlignmentStore(db_path)

@pytest.fixture
def metadata_db(data_cache):
    """Build (once) and open the sample metadata DB from the test TSV file."""
    db_path = data_cache / "metadata.db"
    if not db_path.exists():
        sc2ts.MetadataDb.import_csv("tests/data/metadata.tsv", db_path)
    return sc2ts.MetadataDb(db_path)
34 changes: 1 addition & 33 deletions tests/test_alignments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import pathlib
import shutil
import gzip

import numpy as np
import pytest
from numpy.testing import assert_array_equal
Expand All @@ -10,34 +6,6 @@
from sc2ts import core


@pytest.fixture
def data_cache():
cache_path = pathlib.Path("tests/data/cache")
if not cache_path.exists():
cache_path.mkdir()
return cache_path


@pytest.fixture
def alignments_fasta(data_cache):
cache_path = data_cache / "alignments.fasta"
if not cache_path.exists():
with gzip.open("tests/data/alignments.fasta.gz") as src:
with open(cache_path, "wb") as dest:
shutil.copyfileobj(src, dest)
return cache_path


@pytest.fixture
def alignments_store(data_cache, alignments_fasta):
cache_path = data_cache / "alignments.db"
if not cache_path.exists():
with sa.AlignmentStore(cache_path, "a") as a:
fasta = core.FastaReader(alignments_fasta)
a.append(fasta, show_progress=False)
return sa.AlignmentStore(cache_path)


class TestAlignmentsStore:
def test_info(self, alignments_store):
assert "contains" in str(alignments_store)
Expand Down Expand Up @@ -117,7 +85,7 @@ def test_lowercase_nucleotide_missing(self, hap):
[0, -2],
],
)
def test_examples(self, a):
def test_error__examples(self, a):
with pytest.raises(ValueError):
sa.decode_alignment(np.array(a))

Expand Down
75 changes: 75 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
import pandas as pd


class TestMetadataDb:
    """End-to-end checks of the MetadataDb mapping API and date queries."""

    def test_known(self, metadata_db):
        row = metadata_db["SRR11772659"]
        for field, value in [
            ("strain", "SRR11772659"),
            ("date", "2020-01-19"),
            ("Viridian_pangolin", "A"),
        ]:
            assert row[field] == value

    def test_missing_sequence(self, metadata_db):
        # The metadata deliberately contains a strain with no alignment.
        assert "ERR_MISSING" in metadata_db

    def test_keys(self, metadata_db):
        strains = list(metadata_db.keys())
        assert "SRR11772659" in strains
        # No duplicate keys.
        assert len(set(strains)) == len(strains)
        expected = set(pd.read_csv("tests/data/metadata.tsv", sep="\t")["strain"])
        assert set(strains) == expected

    def test_in(self, metadata_db):
        assert "SRR11772659" in metadata_db
        assert "DEFO_NOT_IN_DB" not in metadata_db

    def test_get_all_days(self, metadata_db):
        expected = ["2020-01-01", "2020-01-19", "2020-01-24", "2020-01-25"]
        expected += [f"2020-01-{day}" for day in range(28, 32)]
        # February runs 01..11 and then skips to the 13th.
        expected += [f"2020-02-{day:02d}" for day in range(1, 12)]
        expected.append("2020-02-13")
        assert metadata_db.get_days() == expected

    def test_get_days_greater(self, metadata_db):
        expected = [f"2020-02-{day:02d}" for day in range(7, 12)]
        expected.append("2020-02-13")
        assert metadata_db.get_days("2020-02-06") == expected

    def test_get_days_none(self, metadata_db):
        assert metadata_db.get_days("2022-02-06") == []

    def test_get_first(self, metadata_db):
        rows = list(metadata_db.get("2020-01-01"))
        assert rows == [metadata_db["SRR14631544"]]

    def test_get_multi(self, metadata_db):
        rows = list(metadata_db.get("2020-02-11"))
        assert len(rows) == 2
        assert all(row["date"] == "2020-02-11" for row in rows)

0 comments on commit 0056a19

Please sign in to comment.