From 7d5a5be3109e3c0816da34598f55d4c39d544e49 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 30 Aug 2024 17:00:42 +0100 Subject: [PATCH 1/7] Debugging --- sc2ts/inference.py | 104 +++++++++++++++------------------------------ 1 file changed, 35 insertions(+), 69 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 9854d15..7893117 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -388,51 +388,6 @@ def asdict(self): } -# def daily_extend( -# *, -# alignment_store, -# metadata_db, -# base_ts, -# match_db, -# num_mismatches=None, -# max_hmm_cost=None, -# min_group_size=None, -# num_past_days=None, -# show_progress=False, -# max_submission_delay=None, -# max_daily_samples=None, -# num_threads=None, -# precision=None, -# rng=None, -# excluded_sample_dir=None, -# ): -# assert num_past_days is None -# assert max_submission_delay is None - -# start_day = last_date(base_ts) - -# last_ts = base_ts -# for date in metadata_db.get_days(start_day): -# ts = extend( -# alignment_store=alignment_store, -# metadata_db=metadata_db, -# date=date, -# base_ts=last_ts, -# match_db=match_db, -# num_mismatches=num_mismatches, -# max_hmm_cost=max_hmm_cost, -# min_group_size=min_group_size, -# show_progress=show_progress, -# max_submission_delay=max_submission_delay, -# max_daily_samples=max_daily_samples, -# num_threads=num_threads, -# precision=precision, -# ) -# yield ts, date - -# last_ts = ts - - def match_samples( date, samples, @@ -449,29 +404,15 @@ def match_samples( # Default to no recombination num_mismatches = 1000 - match_tsinfer( - samples=samples, - ts=base_ts, - num_mismatches=num_mismatches, - precision=2, - num_threads=num_threads, - show_progress=show_progress, - mirror_coordinates=mirror_coordinates, - ) - samples_to_rerun = [] - for sample in samples: - hmm_cost = sample.get_hmm_cost(num_mismatches) - logger.debug( - f"First sketch: {sample.strain} hmm_cost={hmm_cost} path={sample.path}" - ) - if hmm_cost >= 2: - sample.path.clear() 
- sample.mutations.clear() - samples_to_rerun.append(sample) + remaining_samples = samples + # FIXME Something wrong here, we don't seem to get precisely the same + # ARG for some reason. Need to track it down + # Also: should only run the things at low precision that have that HMM cost. + # Start out by setting everything to have 0 mutations and work up from there. - if len(samples_to_rerun) > 0: + for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]: match_tsinfer( - samples=samples_to_rerun, + samples=remaining_samples, ts=base_ts, num_mismatches=num_mismatches, precision=precision, @@ -479,11 +420,34 @@ def match_samples( show_progress=show_progress, mirror_coordinates=mirror_coordinates, ) - for sample in samples_to_rerun: + samples_to_rerun = [] + for sample in remaining_samples: hmm_cost = sample.get_hmm_cost(num_mismatches) + # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}") logger.debug( - f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" + f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}" ) + if hmm_cost > cost: + sample.path.clear() + sample.mutations.clear() + samples_to_rerun.append(sample) + remaining_samples = samples_to_rerun + + match_tsinfer( + samples=samples_to_rerun, + ts=base_ts, + num_mismatches=num_mismatches, + precision=12, + num_threads=num_threads, + show_progress=show_progress, + mirror_coordinates=mirror_coordinates, + ) + for sample in samples_to_rerun: + hmm_cost = sample.get_hmm_cost(num_mismatches) + # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}") + logger.debug( + f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" + ) match_db.add(samples, date, num_mismatches) @@ -855,7 +819,7 @@ def solve_num_mismatches(ts, k): # for the optimal value of this parameter such that the magnitude of the # values within the HMM are as large as possible (so that we can truncate # usefully). 
- mu = 1e-3 + mu = 1e-2 denom = (1 - mu) ** k + (n - 1) * mu**k r = n * mu**k / denom assert mu < 0.5 @@ -1312,6 +1276,8 @@ def match_tsinfer( show_progress=False, mirror_coordinates=False, ): + if len(samples) == 0: + return genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T input_ts = ts if mirror_coordinates: From 23ad2d3fd1a5c8c31ac80f328333cb026c09e9f9 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 3 Sep 2024 15:11:05 +0100 Subject: [PATCH 2/7] Impove match testing and infrastructure Add metadata to exact match samples Closes #238 --- run.sh | 6 +- sc2ts/inference.py | 79 ++++++++-------- tests/conftest.py | 45 +++++++--- tests/test_inference.py | 193 ++++++++++++++++++++++++++++++---------- 4 files changed, 221 insertions(+), 102 deletions(-) diff --git a/run.sh b/run.sh index d6dfb14..983e9df 100755 --- a/run.sh +++ b/run.sh @@ -9,7 +9,7 @@ num_threads=8 # Paths datadir=testrun -run_id=tmp-dev +run_id=tmp-dev-hp # run_id=upgma-mds-$max_daily_samples-md-$max_submission_delay-mm-$mismatches resultsdir=results/$run_id results_prefix=$resultsdir/$run_id- @@ -17,9 +17,9 @@ logfile=logs/$run_id.log alignments=$datadir/alignments.db metadata=$datadir/metadata.db -matches=$resultsdir/matces.db +matches=$resultsdir/matches.db -dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31` +dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31 | head -n 14` echo $dates options="--num-threads $num_threads -vv -l $logfile " diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 7893117..b781a81 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -404,35 +404,36 @@ def match_samples( # Default to no recombination num_mismatches = 1000 - remaining_samples = samples # FIXME Something wrong here, we don't seem to get precisely the same # ARG for some reason. Need to track it down # Also: should only run the things at low precision that have that HMM cost. 
# Start out by setting everything to have 0 mutations and work up from there. - for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]: - match_tsinfer( - samples=remaining_samples, - ts=base_ts, - num_mismatches=num_mismatches, - precision=precision, - num_threads=num_threads, - show_progress=show_progress, - mirror_coordinates=mirror_coordinates, - ) - samples_to_rerun = [] - for sample in remaining_samples: - hmm_cost = sample.get_hmm_cost(num_mismatches) - # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}") - logger.debug( - f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}" - ) - if hmm_cost > cost: - sample.path.clear() - sample.mutations.clear() - samples_to_rerun.append(sample) - remaining_samples = samples_to_rerun - + # remaining_samples = samples + # for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]: + # match_tsinfer( + # samples=remaining_samples, + # ts=base_ts, + # num_mismatches=num_mismatches, + # precision=precision, + # num_threads=num_threads, + # show_progress=show_progress, + # mirror_coordinates=mirror_coordinates, + # ) + # samples_to_rerun = [] + # for sample in remaining_samples: + # hmm_cost = sample.get_hmm_cost(num_mismatches) + # # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}") + # logger.debug( + # f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}" + # ) + # if hmm_cost > cost: + # sample.path.clear() + # sample.mutations.clear() + # samples_to_rerun.append(sample) + # remaining_samples = samples_to_rerun + + samples_to_rerun = samples match_tsinfer( samples=samples_to_rerun, ts=base_ts, @@ -605,6 +606,18 @@ def update_top_level_metadata(ts, date): return tables.tree_sequence() +def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0): + metadata = { + **sample.metadata, + "sc2ts": { + "qc": sample.alignment_qc, + "path": [x.asdict() for x in sample.path], + "mutations": [x.asdict() for x in 
sample.mutations], + }, + } + return tables.nodes.add_row(flags=flags, time=time, metadata=metadata) + + def match_path_ts(samples, ts, path, reversions): """ Given the specified list of samples with equal copying paths, @@ -623,17 +636,7 @@ def match_path_ts(samples, ts, path, reversions): ) for sample in samples: assert sample.path == path - metadata = { - **sample.metadata, - "sc2ts": { - "qc": sample.alignment_qc, - "path": [x.asdict() for x in sample.path], - "mutations": [x.asdict() for x in sample.mutations], - }, - } - node_id = tables.nodes.add_row( - flags=tskit.NODE_IS_SAMPLE, time=0, metadata=metadata - ) + node_id = add_sample_to_tables(sample, tables) tables.edges.add_row(0, ts.sequence_length, parent=0, child=node_id) for mut in sample.mutations: if mut.site_id not in site_id_map: @@ -671,10 +674,10 @@ def add_exact_matches(match_db, ts, date): for sample in samples: assert len(sample.path) == 1 assert len(sample.mutations) == 0 - node_id = tables.nodes.add_row( + node_id = add_sample_to_tables( + sample, + tables, flags=tskit.NODE_IS_SAMPLE | core.NODE_IS_EXACT_MATCH, - time=0, - metadata=sample.metadata, ) parent = sample.path[0].parent logger.debug(f"ARG add exact match {sample.strain}:{node_id}->{parent}") diff --git a/tests/conftest.py b/tests/conftest.py index 29834c5..76bf15d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,6 +35,7 @@ def fx_alignment_store(fx_data_cache, fx_alignments_fasta): a.append(fasta, show_progress=False) return sc2ts.AlignmentStore(cache_path) + @pytest.fixture def fx_metadata_db(fx_data_cache): cache_path = fx_data_cache / "metadata.db" @@ -44,26 +45,46 @@ def fx_metadata_db(fx_data_cache): return sc2ts.MetadataDb(cache_path) +# TODO make this a session fixture cacheing the tree sequences. 
@pytest.fixture -def fx_ts_2020_02_10(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store): - target_date = "2020-02-10" - cache_path = fx_data_cache / f"{target_date}.ts" +def fx_ts_map(tmp_path, fx_data_cache, fx_metadata_db, fx_alignment_store): + dates = [ + "2020-01-01", + "2020-01-19", + "2020-01-24", + "2020-01-25", + "2020-01-28", + "2020-01-29", + "2020-01-30", + "2020-01-31", + "2020-02-01", + "2020-02-02", + "2020-02-03", + "2020-02-04", + "2020-02-05", + "2020-02-06", + "2020-02-07", + "2020-02-08", + "2020-02-09", + "2020-02-10", + "2020-02-11", + "2020-02-13", + ] + cache_path = fx_data_cache / f"{dates[-1]}.ts" if not cache_path.exists(): last_ts = sc2ts.initial_ts() match_db = sc2ts.MatchDb.initialise(tmp_path / "match.db") - for date in fx_metadata_db.date_sample_counts(): - print("INFERRING", date) + for date in dates: last_ts = sc2ts.extend( alignment_store=fx_alignment_store, metadata_db=fx_metadata_db, base_ts=last_ts, date=date, match_db=match_db, - min_group_size=2, ) - if date == target_date: - break - last_ts.dump(cache_path) - return tskit.load(cache_path) - - + print( + f"INFERRED {date} nodes={last_ts.num_nodes} mutations={last_ts.num_mutations}" + ) + cache_path = fx_data_cache / f"{date}.ts" + last_ts.dump(cache_path) + return {date: tskit.load(fx_data_cache / f"{date}.ts") for date in dates} diff --git a/tests/test_inference.py b/tests/test_inference.py index fa70e0f..6dda621 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -486,7 +486,30 @@ def test_high_recomb_mutation(self): class TestRealData: - def test_first_day(self, tmp_path, fx_alignment_store, fx_metadata_db): + dates = [ + "2020-01-01", + "2020-01-19", + "2020-01-24", + "2020-01-25", + "2020-01-28", + "2020-01-29", + "2020-01-30", + "2020-01-31", + "2020-02-01", + "2020-02-02", + "2020-02-03", + "2020-02-04", + "2020-02-05", + "2020-02-06", + "2020-02-07", + "2020-02-08", + "2020-02-09", + "2020-02-10", + "2020-02-11", + "2020-02-13", + ] + + 
def test_first_day(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_db): ts = sc2ts.extend( alignment_store=fx_alignment_store, metadata_db=fx_metadata_db, @@ -536,24 +559,127 @@ def test_first_day(self, tmp_path, fx_alignment_store, fx_metadata_db): "original_md5": "e96feaa72c4f4baba73c2e147ede7502", } - def test_2020_02_10_metadata(self, fx_ts_2020_02_10): - ts = fx_ts_2020_02_10 - assert ts.metadata["sc2ts"]["date"] == "2020-02-10" + assert ts.tables.equals(fx_ts_map["2020-01-19"].tables, ignore_provenance=True) + + @pytest.mark.parametrize("date", dates) + def test_date_metadata(self, fx_ts_map, date): + ts = fx_ts_map[date] + assert ts.metadata["sc2ts"]["date"] == date samples_strain = [ts.node(u).metadata["strain"] for u in ts.samples()] assert ts.metadata["sc2ts"]["samples_strain"] == samples_strain # print(ts.tables.mutations) # print(ts.draw_text()) + @pytest.mark.parametrize("date", dates) + def test_date_validate(self, fx_ts_map, fx_alignment_store, date): + ts = fx_ts_map[date] + sc2ts.validate(ts, fx_alignment_store) + + @pytest.mark.parametrize("date", dates[1:]) + def test_node_mutation_counts(self, fx_ts_map, date): + # Basic check to make sure our fixtures are what we expect + ts = fx_ts_map[date] + expected = { + "2020-01-19": {"nodes": 3, "mutations": 3}, + "2020-01-24": {"nodes": 6, "mutations": 4}, + "2020-01-25": {"nodes": 11, "mutations": 6}, + "2020-01-28": {"nodes": 13, "mutations": 11}, + "2020-01-29": {"nodes": 16, "mutations": 15}, + "2020-01-30": {"nodes": 22, "mutations": 19}, + "2020-01-31": {"nodes": 23, "mutations": 21}, + "2020-02-01": {"nodes": 28, "mutations": 27}, + "2020-02-02": {"nodes": 34, "mutations": 36}, + "2020-02-03": {"nodes": 37, "mutations": 42}, + "2020-02-04": {"nodes": 42, "mutations": 48}, + "2020-02-05": {"nodes": 43, "mutations": 48}, + "2020-02-06": {"nodes": 49, "mutations": 51}, + "2020-02-07": {"nodes": 51, "mutations": 57}, + "2020-02-08": {"nodes": 57, "mutations": 58}, + "2020-02-09": {"nodes": 
59, "mutations": 61}, + "2020-02-10": {"nodes": 60, "mutations": 65}, + "2020-02-11": {"nodes": 62, "mutations": 66}, + "2020-02-13": {"nodes": 66, "mutations": 68}, + } + assert ts.num_nodes == expected[date]["nodes"] + assert ts.num_mutations == expected[date]["mutations"] + + @pytest.mark.parametrize( + ["node", "strain", "parent"], + [ + (6, "SRR11397726", 5), + (7, "SRR11397729", 5), + (13, "SRR11597132", 10), + (16, "SRR11597177", 10), + (42, "SRR11597156", 10), + (57, "SRR11597216", 1), + (60, "SRR11597207", 41), + (62, "ERR4205570", 58), + ], + ) + def test_exact_matches(self, fx_ts_map, node, strain, parent): + ts = fx_ts_map[self.dates[-1]] + x = ts.node(node) + assert x.flags == (tskit.NODE_IS_SAMPLE | sc2ts.core.NODE_IS_EXACT_MATCH) + md = x.metadata + assert md["strain"] == strain + sc2ts_md = md["sc2ts"] + assert len(sc2ts_md["path"]) == 1 + assert len(sc2ts_md["mutations"]) == 0 + assert sc2ts_md["path"][0] == { + "parent": parent, + "left": 0, + "right": ts.sequence_length, + } + edges = np.where(ts.edges_child == node)[0] + assert len(edges) == 1 + e = edges[0] + assert ts.edges_parent[e] == parent + assert ts.edges_left[e] == 0 + assert ts.edges_right[e] == ts.sequence_length + assert np.sum(ts.mutations_node == node) == 0 + class TestMatchingDetails: + # # @pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4]) + # # @pytest.mark.parametrize("precision", [0, 1, 2, 12]) + # def test_problematic_match( + # self, + # fx_ts_map, + # fx_alignment_store, + # fx_metadata_db, + # # strain, + # # parent, + # # num_mismatches, + # # precision, + # ): + # ts = fx_ts_map["2020-02-05"] + # strain = "SRR11597178" + # samples = sc2ts.preprocess( + # [fx_metadata_db[strain]], ts, "2020-02-06", fx_alignment_store + # ) + # sc2ts.match_tsinfer( + # samples=samples, + # ts=ts, + # num_mismatches=3, + # precision=1, + # num_threads=1, + # ) + # s = samples[0] + # # assert len(s.mutations) == 0 + # assert len(s.path) == 1 + # print(s.path) + # print("num utations 
=", len(s.mutations)) + # # print(s.metadata["sc2ts"]) + # assert s.path[0].parent == 37 + @pytest.mark.parametrize( - ("strain", "parent"), [("SRR11597207", 42), ("ERR4205570", 62)] + ("strain", "parent"), [("SRR11597207", 41), ("ERR4205570", 58)] ) @pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4]) @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_exact_matches( self, - fx_ts_2020_02_10, + fx_ts_map, fx_alignment_store, fx_metadata_db, strain, @@ -561,12 +687,13 @@ def test_exact_matches( num_mismatches, precision, ): + ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( - [fx_metadata_db[strain]], fx_ts_2020_02_10, "2020-02-20", fx_alignment_store + [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) sc2ts.match_tsinfer( samples=samples, - ts=fx_ts_2020_02_10, + ts=ts, num_mismatches=num_mismatches, precision=precision, num_threads=0, @@ -578,13 +705,13 @@ def test_exact_matches( @pytest.mark.parametrize( ("strain", "parent", "position", "derived_state"), - [("SRR11597218", 10, 289, "T"), ("ERR4206593", 62, 26994, "T")], + [("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")], ) @pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4]) @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_one_mismatch( self, - fx_ts_2020_02_10, + fx_ts_map, fx_alignment_store, fx_metadata_db, strain, @@ -594,12 +721,13 @@ def test_one_mismatch( num_mismatches, precision, ): + ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( - [fx_metadata_db[strain]], fx_ts_2020_02_10, "2020-02-20", fx_alignment_store + [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) sc2ts.match_tsinfer( samples=samples, - ts=fx_ts_2020_02_10, + ts=ts, num_mismatches=num_mismatches, precision=precision, num_threads=0, @@ -615,19 +743,20 @@ def test_one_mismatch( @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_two_mismatches( self, - fx_ts_2020_02_10, + fx_ts_map, fx_alignment_store, fx_metadata_db, num_mismatches, 
precision, ): strain = "ERR4204459" + ts = fx_ts_map["2020-02-10"] samples = sc2ts.preprocess( - [fx_metadata_db[strain]], fx_ts_2020_02_10, "2020-02-20", fx_alignment_store + [fx_metadata_db[strain]], ts, "2020-02-20", fx_alignment_store ) sc2ts.match_tsinfer( samples=samples, - ts=fx_ts_2020_02_10, + ts=ts, num_mismatches=num_mismatches, precision=precision, num_threads=0, @@ -638,37 +767,3 @@ def test_two_mismatches( assert len(s.mutations) == 2 # assert s.mutations[0].site_position == position # assert s.mutations[0].derived_state == derived_state - - -# def test_stuff( -# self, -# fx_ts_2020_02_10, -# fx_alignment_store, -# fx_metadata_db): - -# # date = "2020-02-11" # 2 samples -# date = "2020-02-13" # 4 samples - -# # metadata_matches = -# samples = sc2ts.preprocess( -# list(fx_metadata_db.get(date)), -# fx_ts_2020_02_10, -# date, fx_alignment_store -# ) -# # print(samples) - -# num_mismatches = 3 -# sc2ts.match_tsinfer( -# samples=samples, -# ts=fx_ts_2020_02_10, -# num_mismatches=3, -# precision=12, -# num_threads=0, -# ) -# for sample in samples: -# print( -# sample.strain, -# sample.get_hmm_cost(num_mismatches), -# sample.path[0].parent, -# len(sample.mutations), -# ) From ac8280e309b2ea5ebd257b7cde86168c76fc42b7 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 3 Sep 2024 16:14:12 +0100 Subject: [PATCH 3/7] Intermediate update with identical result at full precision --- sc2ts/inference.py | 107 ++++++++++++++++++++++------------------ tests/test_inference.py | 28 ++++++++++- 2 files changed, 84 insertions(+), 51 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index b781a81..2781627 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -392,22 +392,51 @@ def match_samples( date, samples, *, - match_db, base_ts, num_mismatches=None, show_progress=False, num_threads=None, - precision=None, - mirror_coordinates=False, ): - if num_mismatches is None: - # Default to no recombination - num_mismatches = 1000 - - # FIXME Something 
wrong here, we don't seem to get precisely the same - # ARG for some reason. Need to track it down - # Also: should only run the things at low precision that have that HMM cost. - # Start out by setting everything to have 0 mutations and work up from there. + # First pass, compute the matches at precision=0. + # precision = 0 + # match_tsinfer( + # samples=samples, + # ts=base_ts, + # num_mismatches=num_mismatches, + # precision=precision, + # num_threads=num_threads, + # show_progress=show_progress, + # ) + + # cost_threshold = 1 + # rerun_batch = [] + # for sample in samples: + # cost = sample.get_hmm_cost(num_mismatches) + # logger.debug( + # f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}" + # ) + # if cost > cost_threshold: + # sample.path.clear() + # sample.mutations.clear() + # rerun_batch.append(sample) + + rerun_batch = samples + precision = 12 + logger.info(f"Rerunning batch of {len(rerun_batch)} at p={precision}") + match_tsinfer( + samples=rerun_batch, + ts=base_ts, + num_mismatches=num_mismatches, + precision=12, + num_threads=num_threads, + show_progress=show_progress, + ) + # for sample in samples_to_rerun: + # hmm_cost = sample.get_hmm_cost(num_mismatches) + # # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}") + # logger.debug( + # f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" + # ) # remaining_samples = samples # for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]: @@ -433,24 +462,8 @@ def match_samples( # samples_to_rerun.append(sample) # remaining_samples = samples_to_rerun - samples_to_rerun = samples - match_tsinfer( - samples=samples_to_rerun, - ts=base_ts, - num_mismatches=num_mismatches, - precision=12, - num_threads=num_threads, - show_progress=show_progress, - mirror_coordinates=mirror_coordinates, - ) - for sample in samples_to_rerun: - hmm_cost = sample.get_hmm_cost(num_mismatches) - # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} 
path={sample.path}") - logger.debug( - f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" - ) - - match_db.add(samples, date, num_mismatches) + # Return in sorted order so that results are deterministic + return sorted(samples, key=lambda s: s.strain) def check_base_ts(ts): @@ -526,7 +539,6 @@ def extend( min_group_size = 10 # TMP - precision = 6 check_base_ts(base_ts) logger.info( f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};" @@ -549,17 +561,16 @@ def extend( f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata" ) - match_samples( + samples = match_samples( date, samples, base_ts=base_ts, - match_db=match_db, num_mismatches=num_mismatches, show_progress=show_progress, num_threads=num_threads, - precision=precision, ) + match_db.add(samples, date, num_mismatches) match_db.create_mask_table(base_ts) ts = increment_time(date, base_ts) @@ -810,23 +821,21 @@ def solve_num_mismatches(ts, k): NOTE! This is NOT taking into account the spatial distance along the genome, and so is not a very good model in some ways. """ + # We can match against any node in tsinfer m = ts.num_sites - n = ts.num_nodes # We can match against any node in tsinfer - if k == 0: - # Pathological things happen when k=0 - r = 1e-3 - mu = 1e-20 - else: - # NOTE: the magnitude of mu matters because it puts a limit - # on how low we can push the HMM precision. We should be able to solve - # for the optimal value of this parameter such that the magnitude of the - # values within the HMM are as large as possible (so that we can truncate - # usefully). - mu = 1e-2 - denom = (1 - mu) ** k + (n - 1) * mu**k - r = n * mu**k / denom - assert mu < 0.5 - assert r < 0.5 + n = ts.num_nodes + # values of k <= 1 are not relevant for SC2 and lead to awkward corner cases + assert k > 1 + + # NOTE: the magnitude of mu matters because it puts a limit + # on how low we can push the HMM precision. 
We should be able to solve + # for the optimal value of this parameter such that the magnitude of the + # values within the HMM are as large as possible (so that we can truncate + # usefully). + # mu = 1e-2 + mu = 0.125 + denom = (1 - mu) ** k + (n - 1) * mu**k + r = n * mu**k / denom # Add a little bit of extra mass for recombination so that we deterministically # chose to recombine over k mutations diff --git a/tests/test_inference.py b/tests/test_inference.py index 6dda621..47bd114 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -561,6 +561,30 @@ def test_first_day(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_db assert ts.tables.equals(fx_ts_map["2020-01-19"].tables, ignore_provenance=True) + def test_2020_02_02(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_db): + ts = sc2ts.extend( + alignment_store=fx_alignment_store, + metadata_db=fx_metadata_db, + base_ts=fx_ts_map["2020-02-01"], + date="2020-02-02", + match_db=sc2ts.MatchDb.initialise(tmp_path / "match.db"), + ) + assert ts.num_samples == 26 + assert np.sum(ts.nodes_time[ts.samples()] == 0) == 4 + samples = {} + for u in ts.samples()[-4:]: + node = ts.node(u) + samples[node.metadata["strain"]] = node + smd = node.metadata["sc2ts"] + md = node.metadata + print(md["date"], md["strain"], len(smd["mutations"])) + # print(samples) + # print(fx_ts_map["2020-02-01"]) + # print(ts) + # print(fx_ts_map["2020-02-02"]) + ts.tables.assert_equals(fx_ts_map["2020-02-02"].tables, ignore_provenance=True) + + @pytest.mark.parametrize("date", dates) def test_date_metadata(self, fx_ts_map, date): ts = fx_ts_map[date] @@ -675,7 +699,7 @@ class TestMatchingDetails: @pytest.mark.parametrize( ("strain", "parent"), [("SRR11597207", 41), ("ERR4205570", 58)] ) - @pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4]) + @pytest.mark.parametrize("num_mismatches", [2, 3, 4]) @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_exact_matches( self, @@ -707,7 +731,7 @@ def 
test_exact_matches( ("strain", "parent", "position", "derived_state"), [("SRR11597218", 10, 289, "T"), ("ERR4206593", 58, 26994, "T")], ) - @pytest.mark.parametrize("num_mismatches", [1, 2, 3, 4]) + @pytest.mark.parametrize("num_mismatches", [2, 3, 4]) @pytest.mark.parametrize("precision", [0, 1, 2, 12]) def test_one_mismatch( self, From 375e2021dc25153711880466be2f9c1b40f1e179 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 3 Sep 2024 16:46:17 +0100 Subject: [PATCH 4/7] Adjust test data for high-precision run --- sc2ts/inference.py | 19 ++++++++++--------- tests/test_inference.py | 30 +++++++++++++++--------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 2781627..da91cf3 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -421,22 +421,22 @@ def match_samples( # rerun_batch.append(sample) rerun_batch = samples - precision = 12 + precision = 6 logger.info(f"Rerunning batch of {len(rerun_batch)} at p={precision}") match_tsinfer( samples=rerun_batch, ts=base_ts, num_mismatches=num_mismatches, - precision=12, + precision=precision, num_threads=num_threads, show_progress=show_progress, ) - # for sample in samples_to_rerun: - # hmm_cost = sample.get_hmm_cost(num_mismatches) - # # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}") - # logger.debug( - # f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" - # ) + for sample in rerun_batch: + hmm_cost = sample.get_hmm_cost(num_mismatches) + # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}") + logger.debug( + f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" + ) # remaining_samples = samples # for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]: @@ -463,7 +463,8 @@ def match_samples( # remaining_samples = samples_to_rerun # Return in sorted order so that results are deterministic - return sorted(samples, key=lambda s: s.strain) + # 
return sorted(samples, key=lambda s: s.strain) + return samples def check_base_ts(ts): diff --git a/tests/test_inference.py b/tests/test_inference.py index 47bd114..2df4012 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -612,17 +612,17 @@ def test_node_mutation_counts(self, fx_ts_map, date): "2020-01-30": {"nodes": 22, "mutations": 19}, "2020-01-31": {"nodes": 23, "mutations": 21}, "2020-02-01": {"nodes": 28, "mutations": 27}, - "2020-02-02": {"nodes": 34, "mutations": 36}, - "2020-02-03": {"nodes": 37, "mutations": 42}, - "2020-02-04": {"nodes": 42, "mutations": 48}, - "2020-02-05": {"nodes": 43, "mutations": 48}, - "2020-02-06": {"nodes": 49, "mutations": 51}, - "2020-02-07": {"nodes": 51, "mutations": 57}, - "2020-02-08": {"nodes": 57, "mutations": 58}, - "2020-02-09": {"nodes": 59, "mutations": 61}, - "2020-02-10": {"nodes": 60, "mutations": 65}, - "2020-02-11": {"nodes": 62, "mutations": 66}, - "2020-02-13": {"nodes": 66, "mutations": 68}, + "2020-02-02": {"nodes": 33, "mutations": 36}, + "2020-02-03": {"nodes": 36, "mutations": 42}, + "2020-02-04": {"nodes": 41, "mutations": 48}, + "2020-02-05": {"nodes": 42, "mutations": 48}, + "2020-02-06": {"nodes": 48, "mutations": 51}, + "2020-02-07": {"nodes": 50, "mutations": 57}, + "2020-02-08": {"nodes": 56, "mutations": 58}, + "2020-02-09": {"nodes": 58, "mutations": 61}, + "2020-02-10": {"nodes": 59, "mutations": 65}, + "2020-02-11": {"nodes": 61, "mutations": 66}, + "2020-02-13": {"nodes": 65, "mutations": 68}, } assert ts.num_nodes == expected[date]["nodes"] assert ts.num_mutations == expected[date]["mutations"] @@ -634,10 +634,10 @@ def test_node_mutation_counts(self, fx_ts_map, date): (7, "SRR11397729", 5), (13, "SRR11597132", 10), (16, "SRR11597177", 10), - (42, "SRR11597156", 10), - (57, "SRR11597216", 1), - (60, "SRR11597207", 41), - (62, "ERR4205570", 58), + (41, "SRR11597156", 10), + (56, "SRR11597216", 1), + (59, "SRR11597207", 40), + (61, "ERR4205570", 57), ], ) def 
test_exact_matches(self, fx_ts_map, node, strain, parent): From 86b800868bc4470830a55690ba418fb431bddf4e Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 3 Sep 2024 17:07:49 +0100 Subject: [PATCH 5/7] Work in progress --- sc2ts/inference.py | 88 ++++++++++++++++++---------------------------- 1 file changed, 35 insertions(+), 53 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index da91cf3..1c1869b 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -398,72 +398,54 @@ def match_samples( num_threads=None, ): # First pass, compute the matches at precision=0. - # precision = 0 - # match_tsinfer( - # samples=samples, - # ts=base_ts, - # num_mismatches=num_mismatches, - # precision=precision, - # num_threads=num_threads, - # show_progress=show_progress, - # ) - - # cost_threshold = 1 - # rerun_batch = [] - # for sample in samples: - # cost = sample.get_hmm_cost(num_mismatches) - # logger.debug( - # f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}" - # ) - # if cost > cost_threshold: - # sample.path.clear() - # sample.mutations.clear() - # rerun_batch.append(sample) - - rerun_batch = samples + run_batch = samples + + # WIP + for precision, cost_threshold in [(0, 0), (1, 1)]: # , (2, 2)]: + logger.info(f"Running batch of {len(run_batch)} at p={precision}") + match_tsinfer( + samples=run_batch, + ts=base_ts, + num_mismatches=num_mismatches, + precision=precision, + num_threads=num_threads, + show_progress=show_progress, + ) + + exceeding_threshold = [] + for sample in run_batch: + cost = sample.get_hmm_cost(num_mismatches) + logger.debug( + f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}" + ) + if cost > cost_threshold: + sample.path.clear() + sample.mutations.clear() + exceeding_threshold.append(sample) + + num_matches_found = len(run_batch) - len(exceeding_threshold) + logger.info( + f"{num_matches_found} final matches for found p={precision}; " + f"{len(exceeding_threshold)} remain" + ) + 
run_batch = exceeding_threshold + precision = 6 - logger.info(f"Rerunning batch of {len(rerun_batch)} at p={precision}") + logger.info(f"Running final batch of {len(run_batch)} at p={precision}") match_tsinfer( - samples=rerun_batch, + samples=run_batch, ts=base_ts, num_mismatches=num_mismatches, precision=precision, num_threads=num_threads, show_progress=show_progress, ) - for sample in rerun_batch: + for sample in run_batch: hmm_cost = sample.get_hmm_cost(num_mismatches) # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}") logger.debug( f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" ) - - # remaining_samples = samples - # for cost, precision in [(0, 0), (1, 2)]: #, (2, 3)]: - # match_tsinfer( - # samples=remaining_samples, - # ts=base_ts, - # num_mismatches=num_mismatches, - # precision=precision, - # num_threads=num_threads, - # show_progress=show_progress, - # mirror_coordinates=mirror_coordinates, - # ) - # samples_to_rerun = [] - # for sample in remaining_samples: - # hmm_cost = sample.get_hmm_cost(num_mismatches) - # # print(f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}") - # logger.debug( - # f"HMM@p={precision}: {sample.strain} hmm_cost={hmm_cost} path={sample.path}" - # ) - # if hmm_cost > cost: - # sample.path.clear() - # sample.mutations.clear() - # samples_to_rerun.append(sample) - # remaining_samples = samples_to_rerun - - # Return in sorted order so that results are deterministic - # return sorted(samples, key=lambda s: s.strain) return samples From c7970e046081e6779f5122daa25d9cfef62c13c8 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 3 Sep 2024 23:08:15 +0100 Subject: [PATCH 6/7] Fixup tests realising there are multiple correct answers here --- sc2ts/inference.py | 35 ++++++++++++++++++----------------- tests/test_inference.py | 36 ++++++++++++++++-------------------- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/sc2ts/inference.py 
b/sc2ts/inference.py index 1c1869b..3a9d2e9 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -102,10 +102,7 @@ def add(self, samples, date, num_mismatches): pkl_compressed, ) data.append(args) - pango = sample.metadata.get("Viridian_pangolin", "Unknown") - logger.debug( - f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}" - ) + logger.debug(f"MatchDB insert: hmm_cost={hmm_cost[j]} {sample.summary()}") # Batch insert, for efficiency. with self.conn: self.conn.executemany(sql, data) @@ -150,11 +147,7 @@ def get(self, where_clause): for row in self.conn.execute(sql): pkl = row.pop("pickle") sample = pickle.loads(bz2.decompress(pkl)) - pango = sample.metadata.get("Viridian_pangolin", "Unknown") - logger.debug( - f"MatchDb got: {sample.strain} {sample.date} {pango} " - f"hmm_cost={row['hmm_cost']}" - ) + logger.debug(f"MatchDb got: {sample.summary()} hmm_cost={row['hmm_cost']}") # print(row) yield sample @@ -364,6 +357,18 @@ class Sample: # def __str__(self): # return f"{self.strain}: {self.path} + {self.mutations}" + def path_summary(self): + return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path) + + def mutation_summary(self): + return "[" + ",".join(str(mutation) for mutation in self.mutations) + "]" + + def summary(self): + pango = self.metadata.get("Viridian_pangolin", "Unknown") + return (f"{self.strain} {self.date} {pango} path={self.path_summary()} " + f"mutations({len(self.mutations)})={self.mutation_summary()}" + ) + @property def breakpoints(self): breakpoints = [seg.left for seg in self.path] @@ -415,9 +420,7 @@ def match_samples( exceeding_threshold = [] for sample in run_batch: cost = sample.get_hmm_cost(num_mismatches) - logger.debug( - f"HMM@p={precision}: {sample.strain} hmm_cost={cost} path={sample.path}" - ) + logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}") if cost > cost_threshold: sample.path.clear() sample.mutations.clear() @@ -441,11 +444,9 @@ def match_samples( 
show_progress=show_progress, ) for sample in run_batch: - hmm_cost = sample.get_hmm_cost(num_mismatches) + cost = sample.get_hmm_cost(num_mismatches) # print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}") - logger.debug( - f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}" - ) + logger.debug(f"Final HMM pass hmm_cost={cost} {sample.summary()}") return samples @@ -1439,7 +1440,7 @@ def get_closest_mutation(node, site_id): sample.mutations.append( MatchMutation( site_id=site_id, - site_position=site_pos, + site_position=int(site_pos), derived_state=derived_state, inherited_state=inherited_state, is_reversion=is_reversion, diff --git a/tests/test_inference.py b/tests/test_inference.py index 2df4012..9e4550d 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -571,20 +571,12 @@ def test_2020_02_02(self, tmp_path, fx_ts_map, fx_alignment_store, fx_metadata_d ) assert ts.num_samples == 26 assert np.sum(ts.nodes_time[ts.samples()] == 0) == 4 - samples = {} - for u in ts.samples()[-4:]: - node = ts.node(u) - samples[node.metadata["strain"]] = node - smd = node.metadata["sc2ts"] - md = node.metadata - print(md["date"], md["strain"], len(smd["mutations"])) # print(samples) # print(fx_ts_map["2020-02-01"]) # print(ts) # print(fx_ts_map["2020-02-02"]) ts.tables.assert_equals(fx_ts_map["2020-02-02"].tables, ignore_provenance=True) - @pytest.mark.parametrize("date", dates) def test_date_metadata(self, fx_ts_map, date): ts = fx_ts_map[date] @@ -601,7 +593,11 @@ def test_date_validate(self, fx_ts_map, fx_alignment_store, date): @pytest.mark.parametrize("date", dates[1:]) def test_node_mutation_counts(self, fx_ts_map, date): - # Basic check to make sure our fixtures are what we expect + # Basic check to make sure our fixtures are what we expect. 
+        # NOTE: this is somewhat fragile as the number of nodes does change
+        # a little depending on the exact solution that the HMM chooses, for
+        # example when there are multiple single-mutation matches at different
+        # sites.
         ts = fx_ts_map[date]
         expected = {
             "2020-01-19": {"nodes": 3, "mutations": 3},
@@ -616,13 +612,13 @@ def test_node_mutation_counts(self, fx_ts_map, date):
             "2020-02-03": {"nodes": 36, "mutations": 42},
             "2020-02-04": {"nodes": 41, "mutations": 48},
             "2020-02-05": {"nodes": 42, "mutations": 48},
-            "2020-02-06": {"nodes": 48, "mutations": 51},
-            "2020-02-07": {"nodes": 50, "mutations": 57},
-            "2020-02-08": {"nodes": 56, "mutations": 58},
-            "2020-02-09": {"nodes": 58, "mutations": 61},
-            "2020-02-10": {"nodes": 59, "mutations": 65},
-            "2020-02-11": {"nodes": 61, "mutations": 66},
-            "2020-02-13": {"nodes": 65, "mutations": 68},
+            "2020-02-06": {"nodes": 49, "mutations": 51},
+            "2020-02-07": {"nodes": 51, "mutations": 57},
+            "2020-02-08": {"nodes": 57, "mutations": 58},
+            "2020-02-09": {"nodes": 59, "mutations": 61},
+            "2020-02-10": {"nodes": 60, "mutations": 65},
+            "2020-02-11": {"nodes": 62, "mutations": 66},
+            "2020-02-13": {"nodes": 66, "mutations": 68},
         }
         assert ts.num_nodes == expected[date]["nodes"]
         assert ts.num_mutations == expected[date]["mutations"]
@@ -635,9 +631,9 @@ def test_node_mutation_counts(self, fx_ts_map, date):
             (13, "SRR11597132", 10),
             (16, "SRR11597177", 10),
             (41, "SRR11597156", 10),
-            (56, "SRR11597216", 1),
-            (59, "SRR11597207", 40),
-            (61, "ERR4205570", 57),
+            (57, "SRR11597216", 1),
+            (60, "SRR11597207", 40),
+            (62, "ERR4205570", 58),
         ],
     )
     def test_exact_matches(self, fx_ts_map, node, strain, parent):
@@ -697,7 +693,7 @@ class TestMatchingDetails:
     #     assert s.path[0].parent == 37

     @pytest.mark.parametrize(
-        ("strain", "parent"), [("SRR11597207", 41), ("ERR4205570", 58)]
+        ("strain", "parent"), [("SRR11597207", 40), ("ERR4205570", 58)]
    )
    @pytest.mark.parametrize("num_mismatches", [2, 3, 4])
    @pytest.mark.parametrize("precision",
[0, 1, 2, 12])

From dffa62c77d1b2cc9e8cfd2e24148bc7a83cbd717 Mon Sep 17 00:00:00 2001
From: Jerome Kelleher
Date: Tue, 3 Sep 2024 23:18:33 +0100
Subject: [PATCH 7/7] Try more aggressive hmm cost thresholds

---
 sc2ts/inference.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/sc2ts/inference.py b/sc2ts/inference.py
index 3a9d2e9..f534297 100644
--- a/sc2ts/inference.py
+++ b/sc2ts/inference.py
@@ -404,9 +404,10 @@ def match_samples(
 ):
     # First pass, compute the matches at precision=0.
     run_batch = samples
-
-    # WIP
-    for precision, cost_threshold in [(0, 0), (1, 1)]:  # , (2, 2)]:
+
+    # Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
+    # but somewhat arbitrary.
+    for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
        logger.info(f"Running batch of {len(run_batch)} at p={precision}")
        match_tsinfer(
            samples=run_batch,