From cca80b968eb6c6a93c40f626c2c645cf4256b8df Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 28 Aug 2024 14:42:03 +0100 Subject: [PATCH 1/3] Various updates getting inference working again Closes #218 Closes #223 --- sc2ts/inference.py | 84 +++++++++++++++++++++------------------------- 1 file changed, 38 insertions(+), 46 deletions(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index 51cd858..feee3fa 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -110,12 +110,7 @@ def create_mask_table(self, ts): # the rows in the DB that *are* in the ts, as a separate # transaction once we know that the trees have been saved to disk. logger.info("Loading used samples into DB") - # TODO this is inefficient - need some logging to see how much time - # we're spending here. - # One thing we can do is to store the list of strain IDs in the - # tree sequence top-level metadata, which we could even store using - # some numpy tricks to make it fast. - samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()] + samples = [(strain,) for strain in ts.metadata["sc2ts"]["samples_strain"]] logger.debug(f"Got {len(samples)} from ts") with self.conn: self.conn.execute("DROP TABLE IF EXISTS used_samples") @@ -224,11 +219,17 @@ def initial_ts(): tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema) tables.reference_sequence.metadata = { "genbank_id": core.REFERENCE_GENBANK, - "notes": "X prepended to alignment to map from 1-based to 0-based coordinates" + "notes": "X prepended to alignment to map from 1-based to 0-based coordinates", } tables.reference_sequence.data = reference tables.metadata_schema = tskit.MetadataSchema(base_schema) + tables.metadata = { + "sc2ts": { + "date": core.REFERENCE_DATE, + "samples_strain": [core.REFERENCE_STRAIN], + } + } # TODO gene annotations to top level # TODO add known fields to the schemas and document them. @@ -245,7 +246,9 @@ def initial_ts(): # in later versions when we remove the dependence on tsinfer. tables.nodes.add_row(time=1, metadata={"strain": "Vestigial_ignore"}) tables.nodes.add_row( - flags=tskit.NODE_IS_SAMPLE, time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE} + flags=tskit.NODE_IS_SAMPLE, + time=0, + metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}, ) tables.edges.add_row(0, L, 0, 1) return tables.tree_sequence() @@ -255,42 +258,8 @@ def parse_date(date): return datetime.datetime.fromisoformat(date) -def filter_samples(samples, alignment_store, max_submission_delay=None): - if max_submission_delay is None: - max_submission_delay = 10**8 # Arbitrary large number of days. - not_in_store = 0 - num_filtered = 0 - ret = [] - for sample in samples: - if sample.strain not in alignment_store: - logger.warn(f"{sample.strain} not in alignment store") - not_in_store += 1 - continue - if sample.submission_delay < max_submission_delay: - ret.append(sample) - else: - num_filtered += 1 - if not_in_store == len(samples): - raise ValueError("All samples for day missing") - logger.info( - f"Filtered {num_filtered} samples with " - f"max_submission_delay >= {max_submission_delay}" - ) - return ret - - def last_date(ts): - if ts.num_samples == 0: - # Special case for the initial ts which contains the - # reference but not as a sample - u = ts.num_nodes - 1 - node = ts.node(u) - # assert node.time == 0 - return parse_date(node.metadata["date"]) - else: - samples = ts.samples() - samples_t0 = samples[ts.nodes_time[samples] == 0] - return max([parse_date(ts.node(u).metadata["date"]) for u in samples_t0]) + return parse_date(ts.metadata["sc2ts"]["date"]) def increment_time(date, ts): @@ -562,6 +531,14 @@ def match_samples( match_db.add(samples, date, num_mismatches) +def check_base_ts(ts): + md = ts.metadata + assert "sc2ts" in md + sc2ts_md = md["sc2ts"] + assert "date" in sc2ts_md + assert len(sc2ts_md["samples_strain"]) == ts.num_samples + + def extend( *, alignment_store, @@ -579,9 +556,10 @@ def extend( precision=None, rng=None, ): + check_base_ts(base_ts) logger.info( - f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};" - f"mutations={base_ts.num_mutations}" + f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};" + f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}" ) # TODO not sure whether we'll keep these params. Making sure they're not # used for now @@ -640,7 +618,21 @@ def extend( min_group_size=min_group_size, show_progress=show_progress, ) - return ts + return update_top_level_metadata(ts, date) + + +def update_top_level_metadata(ts, date): + tables = ts.dump_tables() + md = tables.metadata + md["sc2ts"]["date"] = date + samples_strain = md["sc2ts"]["samples_strain"] + new_samples = ts.samples()[len(samples_strain) :] + for u in new_samples: + node = ts.node(u) + samples_strain.append(node.metadata["strain"]) + md["sc2ts"]["samples_strain"] = samples_strain + tables.metadata = md + return tables.tree_sequence() def match_path_ts(samples, ts, path, reversions): From 80c4a6cfc8e5828d1782c23275c7b3cdba4ba436 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 28 Aug 2024 15:07:54 +0100 Subject: [PATCH 2/3] Add additional-problematic-sites CLI ARG Closes #239 --- sc2ts/cli.py | 26 +++++++++++++++++++++++++- sc2ts/core.py | 25 +------------------------ sc2ts/inference.py | 13 +++++++------ 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/sc2ts/cli.py b/sc2ts/cli.py index 14728f0..971ab7d 100644 --- a/sc2ts/cli.py +++ b/sc2ts/cli.py @@ -11,6 +11,7 @@ import datetime import pickle +import numpy as np import tqdm import tskit import tszip @@ -22,6 +23,8 @@ from . import core from . import inference +logger = logging.getLogger(__name__) + def get_environment(): """ @@ -230,6 +233,12 @@ def dump_samples(samples, output_file): @click.option("--num-threads", default=0, type=int, help="Number of match threads") @click.option("--random-seed", default=42, type=int, help="Random seed for subsampling") @click.option("--stop-date", default="2030-01-01", type=str, help="Stopping date") +@click.option( + "--additional-problematic-sites", + default=None, + type=str, + help="File containing the list of additional problematic sites to exclude.", +) @click.option("-p", "--precision", default=None, type=int, help="Match precision") @click.option("--no-progress", default=False, type=bool, help="Don't show progress") @click.option("-v", "--verbose", count=True) @@ -248,6 +257,7 @@ def daily_extend( num_threads, random_seed, stop_date, + additional_problematic_sites, precision, no_progress, verbose, @@ -259,13 +269,27 @@ def daily_extend( setup_logging(verbose, log_file) rng = random.Random(random_seed) + additional_problematic = [] + if additional_problematic_sites is not None: + additional_problematic = ( + np.loadtxt(additional_problematic_sites).astype(int).tolist() + ) + logger.info( + f"Excluding additional {len(additional_problematic)} problematic sites" + ) + match_db_path = f"{output_prefix}match.db" if base is None: - base_ts = inference.initial_ts() + base_ts = inference.initial_ts(additional_problematic) match_db = inference.MatchDb.initialise(match_db_path) else: base_ts = tskit.load(base) + assert ( + base_ts.metadata["sc2ts"]["additional_problematic_sites"] + == additional_problematic + ) + with contextlib.ExitStack() as exit_stack: alignment_store = exit_stack.enter_context(sc2ts.AlignmentStore(alignments)) metadata_db = exit_stack.enter_context(sc2ts.MetadataDb(metadata)) diff --git a/sc2ts/core.py b/sc2ts/core.py index bbedaef..faf02e1 100644 --- a/sc2ts/core.py +++ b/sc2ts/core.py @@ -50,30 +50,7 @@ def __len__(self): def get_problematic_sites(): - base = np.loadtxt(data_path / "problematic_sites.txt", dtype=np.int64) - # Temporary to try out removing these outliers. See - # https://github.com/jeromekelleher/sc2ts/issues/231#issuecomment-2306665447 - # In reality we'd probably want to provide an additional file of extra sites - # to remove. - additional = [ - 7851, - 10323, - 11750, - 17040, - 21137, - 21846, - 22917, - 22995, - 26681, - 27384, - 27638, - 27752, - 28254, - 28271, - 29614, - ] - full = np.append(base, additional) - return np.sort(full) + return np.loadtxt(data_path / "problematic_sites.txt", dtype=np.int64) __cached_reference = None diff --git a/sc2ts/inference.py b/sc2ts/inference.py index feee3fa..aa058d6 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -4,7 +4,6 @@ import datetime import dataclasses import collections -import json import pickle import os import sqlite3 @@ -77,7 +76,6 @@ def add(self, samples, date, num_mismatches): data = [] hmm_cost = np.zeros(len(samples)) for j, sample in enumerate(samples): - d = sample.asdict() assert sample.date == date # FIXME we want to be more selective about what we're storing # here, as we're including the alignment too. @@ -207,14 +205,17 @@ def mirror_ts_coordinates(ts): return tables.tree_sequence() -def initial_ts(): +def initial_ts(additional_problematic_sites=list()): reference = core.get_reference_sequence() L = core.REFERENCE_SEQUENCE_LENGTH assert L == len(reference) - problematic_sites = set(core.get_problematic_sites()) + problematic_sites = set(core.get_problematic_sites()) | set(additional_problematic_sites) tables = tskit.TableCollection(L) tables.time_units = core.TIME_UNITS + + # TODO add known fields to the schemas and document them. + base_schema = tskit.MetadataSchema.permissive_json().schema tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema) tables.reference_sequence.metadata = { @@ -224,15 +225,15 @@ def initial_ts(): tables.reference_sequence.data = reference tables.metadata_schema = tskit.MetadataSchema(base_schema) + # TODO gene annotations to top level tables.metadata = { "sc2ts": { "date": core.REFERENCE_DATE, "samples_strain": [core.REFERENCE_STRAIN], + "additional_problematic_sites": additional_problematic_sites, } } - # TODO gene annotations to top level - # TODO add known fields to the schemas and document them. tables.nodes.metadata_schema = tskit.MetadataSchema(base_schema) tables.sites.metadata_schema = tskit.MetadataSchema(base_schema) tables.mutations.metadata_schema = tskit.MetadataSchema(base_schema) From 43017c19b3bd7236df8f55fcd84500713e4ccb30 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 28 Aug 2024 15:34:18 +0100 Subject: [PATCH 3/3] Fix validate --- sc2ts/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sc2ts/inference.py b/sc2ts/inference.py index aa058d6..148d9f9 100644 --- a/sc2ts/inference.py +++ b/sc2ts/inference.py @@ -313,7 +313,7 @@ def validate(ts, alignment_store, show_progress=False): Check that all the samples in the specified tree sequence are correctly representing the original alignments. """ - samples = ts.samples() + samples = ts.samples()[1:] chunk_size = 10**3 offset = 0 num_chunks = ts.num_samples // chunk_size