Skip to content

Commit

Permalink
Add reasonable tests for metadata DB
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 29, 2024
1 parent 44a594d commit d0c3a74
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 20 deletions.
5 changes: 0 additions & 5 deletions sc2ts/alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,6 @@ def append(self, alignments, show_progress=False):
self._flush(chunk)
bar.close()

def __contains__(self, key):
with self.env.begin() as txn:
val = txn.get(key.encode())
return val is not None

def __getitem__(self, key):
with self.env.begin() as txn:
val = txn.get(key.encode())
Expand Down
44 changes: 29 additions & 15 deletions sc2ts/metadata.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import logging
import sqlite3
import pathlib
Expand All @@ -13,14 +14,27 @@ def dict_factory(cursor, row):
return {key: value for key, value in zip(col_names, row)}


class MetadataDb:
class MetadataDb(collections.abc.Mapping):
def __init__(self, path):
uri = f"file:{path}"
uri += "?mode=ro"
self.uri = uri
self.conn = sqlite3.connect(uri, uri=True)
self.conn.row_factory = dict_factory

@staticmethod
def import_csv(csv_path, db_path):
df = pd.read_csv(csv_path, sep="\t")
db_path = pathlib.Path(db_path)
if db_path.exists():
db_path.unlink()
with sqlite3.connect(db_path) as conn:
df.to_sql("samples", conn, index=False)
conn.execute(
"CREATE UNIQUE INDEX [ix_samples_strain] on 'samples' ([strain]);"
)
conn.execute("CREATE INDEX [ix_samples_date] on 'samples' ([date]);")

def __enter__(self):
return self

Expand All @@ -36,23 +50,23 @@ def __len__(self):
row = self.conn.execute(sql).fetchone()
return row["COUNT(*)"]

def __getitem__(self, key):
sql = "SELECT * FROM samples WHERE strain==?"
with self.conn:
result = self.conn.execute(sql, [key]).fetchone()
if result is None:
raise KeyError(f"strain {key} not in DB")
return result

def __iter__(self):
sql = "SELECT strain FROM samples"
with self.conn:
for result in self.conn.execute(sql):
yield result["strain"]

def close(self):
self.conn.close()

@staticmethod
def import_csv(csv_path, db_path):
df = pd.read_csv(
csv_path,
sep="\t",
)
db_path = pathlib.Path(db_path)
if db_path.exists():
db_path.unlink()
with sqlite3.connect(db_path) as conn:
df.to_sql("samples", conn, index=False)
conn.execute("CREATE INDEX [ix_samples_strain] on 'samples' ([strain]);")
conn.execute("CREATE INDEX [ix_samples_date] on 'samples' ([date]);")

def get(self, date):
sql = "SELECT * FROM samples WHERE date==?"
with self.conn:
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ def alignments_store(data_cache, alignments_fasta):
fasta = sc2ts.core.FastaReader(alignments_fasta)
a.append(fasta, show_progress=False)
return sc2ts.AlignmentStore(cache_path)

@pytest.fixture
def metadata_db(data_cache):
cache_path = data_cache / "metadata.db"
tsv_path = "tests/data/metadata.tsv"
if not cache_path.exists():
sc2ts.MetadataDb.import_csv(tsv_path, cache_path)
return sc2ts.MetadataDb(cache_path)
75 changes: 75 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
import pandas as pd


class TestMetadataDb:
def test_known(self, metadata_db):
record = metadata_db["SRR11772659"]
assert record["strain"] == "SRR11772659"
assert record["date"] == "2020-01-19"
assert record["Viridian_pangolin"] == "A"

def test_missing_sequence(self, metadata_db):
# We include sequence that's not in the alignments DB
assert "ERR_MISSING" in metadata_db

def test_keys(self, metadata_db):
keys = list(metadata_db.keys())
assert "SRR11772659" in keys
assert len(set(keys)) == len(keys)
df = pd.read_csv("tests/data/metadata.tsv", sep="\t")
assert set(keys) == set(df["strain"])

def test_in(self, metadata_db):
assert "SRR11772659" in metadata_db
assert "DEFO_NOT_IN_DB" not in metadata_db

def test_get_all_days(self, metadata_db):
results = metadata_db.get_days()
assert results == [
"2020-01-01",
"2020-01-19",
"2020-01-24",
"2020-01-25",
"2020-01-28",
"2020-01-29",
"2020-01-30",
"2020-01-31",
"2020-02-01",
"2020-02-02",
"2020-02-03",
"2020-02-04",
"2020-02-05",
"2020-02-06",
"2020-02-07",
"2020-02-08",
"2020-02-09",
"2020-02-10",
"2020-02-11",
"2020-02-13",
]

def test_get_days_greater(self, metadata_db):
results = metadata_db.get_days("2020-02-06")
assert results == [
"2020-02-07",
"2020-02-08",
"2020-02-09",
"2020-02-10",
"2020-02-11",
"2020-02-13",
]

def test_get_days_none(self, metadata_db):
assert metadata_db.get_days("2022-02-06") == []

def test_get_first(self, metadata_db):
results = list(metadata_db.get("2020-01-01"))
assert len(results) == 1
assert results[0] == metadata_db["SRR14631544"]

def test_get_multi(self, metadata_db):
results = list(metadata_db.get("2020-02-11"))
assert len(results) == 2
for result in results:
assert result["date"] == "2020-02-11"

0 comments on commit d0c3a74

Please sign in to comment.