Skip to content

Commit

Permalink
Fixup tests
Browse files Browse the repository at this point in the history
realising there are multiple correct answers here
  • Loading branch information
jeromekelleher committed Sep 3, 2024
1 parent 86b8008 commit c7970e0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 37 deletions.
35 changes: 18 additions & 17 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def add(self, samples, date, num_mismatches):
pkl_compressed,
)
data.append(args)
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
logger.debug(
f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
)
logger.debug(f"MatchDB insert: hmm_cost={hmm_cost[j]} {sample.summary()}")
# Batch insert, for efficiency.
with self.conn:
self.conn.executemany(sql, data)
Expand Down Expand Up @@ -150,11 +147,7 @@ def get(self, where_clause):
for row in self.conn.execute(sql):
pkl = row.pop("pickle")
sample = pickle.loads(bz2.decompress(pkl))
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
logger.debug(
f"MatchDb got: {sample.strain} {sample.date} {pango} "
f"hmm_cost={row['hmm_cost']}"
)
logger.debug(f"MatchDb got: {sample.summary()} hmm_cost={row['hmm_cost']}")
# print(row)
yield sample

Expand Down Expand Up @@ -364,6 +357,18 @@ class Sample:
# def __str__(self):
# return f"{self.strain}: {self.path} + {self.mutations}"

def path_summary(self):
return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path)

def mutation_summary(self):
return "[" + ",".join(str(mutation) for mutation in self.mutations) + "]"

def summary(self):
pango = self.metadata.get("Viridian_pangolin", "Unknown")
return (f"{self.strain} {self.date} {pango} path={self.path_summary()} "
f"mutations({len(self.mutations)})={self.mutation_summary()}"
)

@property
def breakpoints(self):
breakpoints = [seg.left for seg in self.path]
Expand Down Expand Up @@ -415,9 +420,7 @@ def match_samples(
exceeding_threshold = []
for sample in run_batch:
cost = sample.get_hmm_cost(num_mismatches)
logger.debug(
f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}"
)
logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
if cost > cost_threshold:
sample.path.clear()
sample.mutations.clear()
Expand All @@ -441,11 +444,9 @@ def match_samples(
show_progress=show_progress,
)
for sample in run_batch:
hmm_cost = sample.get_hmm_cost(num_mismatches)
cost = sample.get_hmm_cost(num_mismatches)
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
logger.debug(
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
)
logger.debug(f"Final HMM pass hmm_cost={cost} {sample.summary()}")
return samples


Expand Down Expand Up @@ -1439,7 +1440,7 @@ def get_closest_mutation(node, site_id):
sample.mutations.append(
MatchMutation(
site_id=site_id,
site_position=site_pos,
site_position=int(site_pos),
derived_state=derived_state,
inherited_state=inherited_state,
is_reversion=is_reversion,
Expand Down
36 changes: 16 additions & 20 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,20 +571,12 @@ def test_2020_02_02(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d
)
assert ts.num_samples == 26
assert np.sum(ts.nodes_time[ts.samples()] == 0) == 4
samples = {}
for u in ts.samples()[-4:]:
node = ts.node(u)
samples[node.metadata["strain"]] = node
smd = node.metadata["sc2ts"]
md = node.metadata
print(md["date"], md["strain"], len(smd["mutations"]))
# print(samples)
# print(fx_ts_map["2020-02-01"])
# print(ts)
# print(fx_ts_map["2020-02-02"])
ts.tables.assert_equals(fx_ts_map["2020-02-02"].tables, ignore_provenance=True)


@pytest.mark.parametrize("date", dates)
def test_date_metadata(self, fx_ts_map, date):
ts = fx_ts_map[date]
Expand All @@ -601,7 +593,11 @@ def test_date_validate(self, fx_ts_map, fx_alignment_store, date):

@pytest.mark.parametrize("date", dates[1:])
def test_node_mutation_counts(self, fx_ts_map, date):
# Basic check to make sure our fixtures are what we expect
# Basic check to make sure our fixtures are what we expect.
# NOTE: this is somewhat fragile as the numbers of nodes does change
# a little depending on the exact solution that the HMM choses, for
# example when there are multiple single-mutation matches at different
# sites.
ts = fx_ts_map[date]
expected = {
"2020-01-19": {"nodes": 3, "mutations": 3},
Expand All @@ -616,13 +612,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
"2020-02-03": {"nodes": 36, "mutations": 42},
"2020-02-04": {"nodes": 41, "mutations": 48},
"2020-02-05": {"nodes": 42, "mutations": 48},
"2020-02-06": {"nodes": 48, "mutations": 51},
"2020-02-07": {"nodes": 50, "mutations": 57},
"2020-02-08": {"nodes": 56, "mutations": 58},
"2020-02-09": {"nodes": 58, "mutations": 61},
"2020-02-10": {"nodes": 59, "mutations": 65},
"2020-02-11": {"nodes": 61, "mutations": 66},
"2020-02-13": {"nodes": 65, "mutations": 68},
"2020-02-06": {"nodes": 49, "mutations": 51},
"2020-02-07": {"nodes": 51, "mutations": 57},
"2020-02-08": {"nodes": 57, "mutations": 58},
"2020-02-09": {"nodes": 59, "mutations": 61},
"2020-02-10": {"nodes": 60, "mutations": 65},
"2020-02-11": {"nodes": 62, "mutations": 66},
"2020-02-13": {"nodes": 66, "mutations": 68},
}
assert ts.num_nodes == expected[date]["nodes"]
assert ts.num_mutations == expected[date]["mutations"]
Expand All @@ -635,9 +631,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
(13, "SRR11597132", 10),
(16, "SRR11597177", 10),
(41, "SRR11597156", 10),
(56, "SRR11597216", 1),
(59, "SRR11597207", 40),
(61, "ERR4205570", 57),
(57, "SRR11597216", 1),
(60, "SRR11597207", 40),
(62, "ERR4205570", 58),
],
)
def test_exact_matches(self, fx_ts_map, node, strain, parent):
Expand Down Expand Up @@ -697,7 +693,7 @@ class TestMatchingDetails:
# assert s.path[0].parent == 37

@pytest.mark.parametrize(
("strain", "parent"), [("SRR11597207", 41), ("ERR4205570", 58)]
("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)]
)
@pytest.mark.parametrize("num_mismatches", [2, 3, 4])
@pytest.mark.parametrize("precision", [0, 1, 2, 12])
Expand Down

0 comments on commit c7970e0

Please sign in to comment.