Merge pull request #255 from jeromekelleher/add-some-matching-tests
Add some matching tests
jeromekelleher authored Aug 30, 2024
2 parents 4a35bec + bad0019 commit e996bfd
Showing 6 changed files with 261 additions and 156 deletions.
123 changes: 50 additions & 73 deletions sc2ts/inference.py
@@ -433,73 +433,6 @@ def asdict(self):
# last_ts = ts


def preprocess(
date,
*,
base_ts,
metadata_db,
alignment_store,
max_daily_samples=None,
show_progress=False,
):
samples = []
metadata_matches = list(metadata_db.get(date))

if len(metadata_matches) == 0:
logger.warn(f"Zero metadata matches for {date}")
return []

if date.endswith("12-31"):
logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
return []

# TODO implement this.
assert max_daily_samples is None

keep_sites = base_ts.sites_position.astype(int)
problematic_sites = core.get_problematic_sites()
samples = []

with tqdm.tqdm(
metadata_matches,
desc=f"Preprocess:{date}",
disable=not show_progress,
) as bar:
for md in bar:
strain = md["strain"]
try:
alignment = alignment_store[strain]
except KeyError:
logger.debug(f"No alignment stored for {strain}")
continue

sample = Sample(strain, date, metadata=md)
ma = alignments.encode_and_mask(alignment)
# Always mask the problematic_sites as well. We need to do this
# for follow-up matching to inspect recombinants, as tsinfer
# needs us to keep all sites in the table when doing mirrored
# coordinates.
ma.alignment[problematic_sites] = -1
sample.alignment_qc = ma.qc_summary()
sample.masked_sites = ma.masked_sites
sample.alignment = ma.alignment[keep_sites]
samples.append(sample)
num_Ns = ma.original_base_composition.get("N", 0)
non_nuc_counts = dict(ma.original_base_composition)
for nuc in "ACGT":
del non_nuc_counts[nuc]
counts = ",".join(
f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
)
num_masked = len(ma.masked_sites)
logger.debug(f"Mask {strain}: masked={num_masked} {counts}")

logger.info(
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
)
return samples


def match_samples(
date,
samples,
@@ -563,6 +496,47 @@ def check_base_ts(ts):
assert len(sc2ts_md["samples_strain"]) == ts.num_samples


def preprocess(samples_md, base_ts, date, alignment_store, show_progress=False):
keep_sites = base_ts.sites_position.astype(int)
problematic_sites = core.get_problematic_sites()

samples = []
with tqdm.tqdm(
samples_md,
desc=f"Preprocess",
disable=not show_progress,
) as bar:
for md in bar:
strain = md["strain"]
try:
alignment = alignment_store[strain]
except KeyError:
logger.debug(f"No alignment stored for {strain}")
continue
sample = Sample(strain, date, metadata=md)
ma = alignments.encode_and_mask(alignment)
# Always mask the problematic_sites as well. We need to do this
# for follow-up matching to inspect recombinants, as tsinfer
# needs us to keep all sites in the table when doing mirrored
# coordinates.
ma.alignment[problematic_sites] = -1
sample.alignment_qc = ma.qc_summary()
sample.masked_sites = ma.masked_sites
sample.alignment = ma.alignment[keep_sites]
samples.append(sample)
num_Ns = ma.original_base_composition.get("N", 0)
non_nuc_counts = dict(ma.original_base_composition)
for nuc in "ACGT":
del non_nuc_counts[nuc]
counts = ",".join(
f"{key}={count}" for key, count in sorted(non_nuc_counts.items())
)
num_masked = len(ma.masked_sites)
logger.debug(f"Mask {strain}: masked={num_masked} {counts}")

return samples


def extend(
*,
alignment_store,
@@ -594,19 +568,22 @@ def extend(
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
)

metadata_matches = list(metadata_db.get(date))
# TODO implement this.
assert max_daily_samples is None

samples = preprocess(
date,
metadata_db=metadata_db,
alignment_store=alignment_store,
base_ts=base_ts,
max_daily_samples=max_daily_samples,
show_progress=show_progress,
metadata_matches, base_ts, date, alignment_store, show_progress=show_progress
)

if len(samples) == 0:
logger.warning(f"Nothing to do for {date}")
return base_ts

logger.info(
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
)

match_samples(
date,
samples,
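With this refactor, preprocess() no longer queries the metadata database or handles max_daily_samples; extend() now fetches the day's metadata rows, passes them in, and performs the empty-day check and logging itself. That makes preprocess() easy to exercise directly against the fixtures defined in tests/conftest.py below. A minimal sketch of such a test, assuming preprocess is importable as sc2ts.inference.preprocess and that the test data includes samples dated 2020-02-10 (the test name and date are illustrative, not part of this commit):

import sc2ts


def test_preprocess_keeps_base_ts_sites(fx_metadata_db, fx_alignment_store):
    base_ts = sc2ts.initial_ts()
    date = "2020-02-10"
    # extend() now performs this lookup itself and hands the rows to preprocess().
    samples_md = list(fx_metadata_db.get(date))
    samples = sc2ts.inference.preprocess(
        samples_md, base_ts, date, fx_alignment_store
    )
    assert len(samples) > 0
    for sample in samples:
        # Each masked alignment is restricted to the sites present in the
        # base tree sequence (keep_sites in preprocess()).
        assert len(sample.alignment) == base_ts.num_sites
        assert sample.alignment_qc is not None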
42 changes: 34 additions & 8 deletions tests/conftest.py
@@ -1,23 +1,24 @@
import pathlib
import shutil
import gzip
import tskit

import pytest

import sc2ts


@pytest.fixture
def data_cache():
def fx_data_cache():
cache_path = pathlib.Path("tests/data/cache")
if not cache_path.exists():
cache_path.mkdir()
return cache_path


@pytest.fixture
def alignments_fasta(data_cache):
cache_path = data_cache / "alignments.fasta"
def fx_alignments_fasta(fx_data_cache):
cache_path = fx_data_cache / "alignments.fasta"
if not cache_path.exists():
with gzip.open("tests/data/alignments.fasta.gz") as src:
with open(cache_path, "wb") as dest:
@@ -26,18 +27,43 @@ def alignments_fasta(data_cache):


@pytest.fixture
def alignments_store(data_cache, alignments_fasta):
cache_path = data_cache / "alignments.db"
def fx_alignment_store(fx_data_cache, fx_alignments_fasta):
cache_path = fx_data_cache / "alignments.db"
if not cache_path.exists():
with sc2ts.AlignmentStore(cache_path, "a") as a:
fasta = sc2ts.core.FastaReader(alignments_fasta)
fasta = sc2ts.core.FastaReader(fx_alignments_fasta)
a.append(fasta, show_progress=False)
return sc2ts.AlignmentStore(cache_path)

@pytest.fixture
def metadata_db(data_cache):
cache_path = data_cache / "metadata.db"
def fx_metadata_db(fx_data_cache):
cache_path = fx_data_cache / "metadata.db"
tsv_path = "tests/data/metadata.tsv"
if not cache_path.exists():
sc2ts.MetadataDb.import_csv(tsv_path, cache_path)
return sc2ts.MetadataDb(cache_path)


@pytest.fixture
def fx_ts_2020_02_10(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store):
target_date = "2020-02-10"
cache_path = fx_data_cache / f"{target_date}.ts"
if not cache_path.exists():
last_ts = sc2ts.initial_ts()
match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db")
for date in fx_metadata_db.date_sample_counts():
print("INFERRING", date)
last_ts = sc2ts.extend(
alignment_store=fx_alignment_store,
metadata_db=fx_metadata_db,
base_ts=last_ts,
date=date,
match_db=match_db,
min_group_size=2,
)
if date == target_date:
break
last_ts.dump(cache_path)
return tskit.load(cache_path)
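
The fx_ts_2020_02_10 fixture above runs the full extend() loop day by day up to 2020-02-10 and caches the result, so matching tests can assert against a realistically built tree sequence without re-running inference every time. A sketch of how a test might consume it, assuming extend() records the processed date and the sample strains in the top-level sc2ts metadata, as check_base_ts() and the logging in inference.py suggest (the test name is illustrative):

def test_ts_2020_02_10_metadata(fx_ts_2020_02_10):
    ts = fx_ts_2020_02_10
    sc2ts_md = ts.metadata["sc2ts"]
    # The cached tree sequence should report the last date that was extended.
    assert sc2ts_md["date"] == "2020-02-10"
    # Mirrors the consistency check made by check_base_ts() in inference.py.
    assert len(sc2ts_md["samples_strain"]) == ts.num_samples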


32 changes: 16 additions & 16 deletions tests/test_alignments.py
@@ -7,27 +7,27 @@


class TestAlignmentsStore:
def test_info(self, alignments_store):
assert "contains" in str(alignments_store)
def test_info(self, fx_alignment_store):
assert "contains" in str(fx_alignment_store)

def test_len(self, alignments_store):
assert len(alignments_store) == 55
def test_len(self, fx_alignment_store):
assert len(fx_alignment_store) == 55

def test_fetch_known(self, alignments_store):
a = alignments_store["SRR11772659"]
def test_fetch_known(self, fx_alignment_store):
a = fx_alignment_store["SRR11772659"]
assert a.shape == (core.REFERENCE_SEQUENCE_LENGTH,)
assert a[0] == "X"
assert a[1] == "N"
assert a[-1] == "N"

def test_keys(self, alignments_store):
keys = list(alignments_store.keys())
assert len(keys) == len(alignments_store)
def test_keys(self, fx_alignment_store):
keys = list(fx_alignment_store.keys())
assert len(keys) == len(fx_alignment_store)
assert "SRR11772659" in keys

def test_in(self, alignments_store):
assert "SRR11772659" in alignments_store
assert "NOT_IN_STORE" not in alignments_store
def test_in(self, fx_alignment_store):
assert "SRR11772659" in fx_alignment_store
assert "NOT_IN_STORE" not in fx_alignment_store


def test_get_gene_coordinates():
@@ -89,8 +89,8 @@ def test_error__examples(self, a):
with pytest.raises(ValueError):
sa.decode_alignment(np.array(a))

def test_encode_real(self, alignments_store):
h = alignments_store["SRR11772659"]
def test_encode_real(self, fx_alignment_store):
h = fx_alignment_store["SRR11772659"]
a = sa.encode_alignment(h)
assert a[0] == -1
assert a[-1] == -1
@@ -145,8 +145,8 @@ def test_bad_window_size(self, w):


class TestEncodeAndMask:
def test_known(self, alignments_store):
a = alignments_store["SRR11772659"]
def test_known(self, fx_alignment_store):
a = fx_alignment_store["SRR11772659"]
ma = sa.encode_and_mask(a)
assert ma.original_base_composition == {
"T": 9566,
8 changes: 4 additions & 4 deletions tests/test_cli.py
@@ -47,11 +47,11 @@ def test_additional_problematic_sites(self, tmp_path, additional):


class TestListDates:
def test_defaults(self, metadata_db):
def test_defaults(self, fx_metadata_db):
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"list-dates {metadata_db.path}",
f"list-dates {fx_metadata_db.path}",
catch_exceptions=False,
)
assert result.exit_code == 0
@@ -78,11 +78,11 @@ def test_defaults(self, metadata_db):
"2020-02-13",
]

def test_counts(self, metadata_db):
def test_counts(self, fx_metadata_db):
runner = ct.CliRunner(mix_stderr=False)
result = runner.invoke(
cli.cli,
f"list-dates {metadata_db.path} --counts",
f"list-dates {fx_metadata_db.path} --counts",
catch_exceptions=False,
)
assert result.exit_code == 0
