Merge pull request #246 from jeromekelleher/improvements

Improvements

jeromekelleher authored Aug 28, 2024
2 parents 5c80d4f + 43017c1 commit 83fcb1a
Showing 3 changed files with 72 additions and 78 deletions.
26 changes: 25 additions & 1 deletion sc2ts/cli.py
@@ -11,6 +11,7 @@
import datetime
import pickle

import numpy as np
import tqdm
import tskit
import tszip
@@ -22,6 +23,8 @@
from . import core
from . import inference

logger = logging.getLogger(__name__)


def get_environment():
"""
@@ -230,6 +233,12 @@ def dump_samples(samples, output_file):
@click.option("--num-threads", default=0, type=int, help="Number of match threads")
@click.option("--random-seed", default=42, type=int, help="Random seed for subsampling")
@click.option("--stop-date", default="2030-01-01", type=str, help="Stopping date")
@click.option(
"--additional-problematic-sites",
default=None,
type=str,
help="File containing the list of additional problematic sites to exclude.",
)
@click.option("-p", "--precision", default=None, type=int, help="Match precision")
@click.option("--no-progress", default=False, type=bool, help="Don't show progress")
@click.option("-v", "--verbose", count=True)
@@ -248,6 +257,7 @@ def daily_extend(
num_threads,
random_seed,
stop_date,
additional_problematic_sites,
precision,
no_progress,
verbose,
@@ -259,13 +269,27 @@
setup_logging(verbose, log_file)
rng = random.Random(random_seed)

additional_problematic = []
if additional_problematic_sites is not None:
additional_problematic = (
np.loadtxt(additional_problematic_sites).astype(int).tolist()
)
logger.info(
f"Excluding additional {len(additional_problematic)} problematic sites"
)

match_db_path = f"{output_prefix}match.db"
if base is None:
base_ts = inference.initial_ts()
base_ts = inference.initial_ts(additional_problematic)
match_db = inference.MatchDb.initialise(match_db_path)
else:
base_ts = tskit.load(base)

assert (
base_ts.metadata["sc2ts"]["additional_problematic_sites"]
== additional_problematic
)

with contextlib.ExitStack() as exit_stack:
alignment_store = exit_stack.enter_context(sc2ts.AlignmentStore(alignments))
metadata_db = exit_stack.enter_context(sc2ts.MetadataDb(metadata))
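The new --additional-problematic-sites option is parsed with np.loadtxt, so it expects a plain-text file of integer site positions. A minimal sketch of the loading step as the command performs it; the file name is illustrative, not part of this diff:

    import numpy as np

    # extra_sites.txt: whitespace-separated integer positions,
    # typically one per line, e.g. 7851, 10323, 11750.
    additional_problematic = np.loadtxt("extra_sites.txt").astype(int).tolist()
    print(additional_problematic)  # [7851, 10323, 11750]

Note the assert added above: when resuming from a base tree sequence, the sites passed on the command line must match the list recorded in the base's sc2ts metadata.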
25 changes: 1 addition & 24 deletions sc2ts/core.py
@@ -50,30 +50,7 @@ def __len__(self):


def get_problematic_sites():
base = np.loadtxt(data_path / "problematic_sites.txt", dtype=np.int64)
# Temporary to try out removing these outliers. See
# https://github.com/jeromekelleher/sc2ts/issues/231#issuecomment-2306665447
# In reality we'd probably want to provide an additional file of extra sites
# to remove.
additional = [
7851,
10323,
11750,
17040,
21137,
21846,
22917,
22995,
26681,
27384,
27638,
27752,
28254,
28271,
29614,
]
full = np.append(base, additional)
return np.sort(full)
return np.loadtxt(data_path / "problematic_sites.txt", dtype=np.int64)


__cached_reference = None
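The fifteen outlier sites hard-coded here, and removed in this commit, can now be supplied through the new CLI option instead, as the deleted comment anticipated. A sketch of writing them to a file usable with --additional-problematic-sites; the file name is illustrative:

    outliers = [
        7851, 10323, 11750, 17040, 21137, 21846, 22917, 22995,
        26681, 27384, 27638, 27752, 28254, 28271, 29614,
    ]
    with open("extra_sites.txt", "w") as f:
        f.write("\n".join(str(site) for site in outliers))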
99 changes: 46 additions & 53 deletions sc2ts/inference.py
@@ -4,7 +4,6 @@
import datetime
import dataclasses
import collections
import json
import pickle
import os
import sqlite3
@@ -77,7 +76,6 @@ def add(self, samples, date, num_mismatches):
data = []
hmm_cost = np.zeros(len(samples))
for j, sample in enumerate(samples):
d = sample.asdict()
assert sample.date == date
# FIXME we want to be more selective about what we're storing
# here, as we're including the alignment too.
Expand Down Expand Up @@ -110,12 +108,7 @@ def create_mask_table(self, ts):
# the rows in the DB that *are* in the ts, as a separate
# transaction once we know that the trees have been saved to disk.
logger.info("Loading used samples into DB")
# TODO this is inefficient - need some logging to see how much time
# we're spending here.
# One thing we can do is to store the list of strain IDs in the
# tree sequence top-level metadata, which we could even store using
# some numpy tricks to make it fast.
samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
samples = [(strain,) for strain in ts.metadata["sc2ts"]["samples_strain"]]
logger.debug(f"Got {len(samples)} from ts")
with self.conn:
self.conn.execute("DROP TABLE IF EXISTS used_samples")
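The replaced line in this hunk does what the deleted TODO suggested: instead of decoding the metadata of every sample node, it reads the strain list once from the tree sequence's top-level metadata. Both forms produce the same rows, assuming a tree sequence carrying the sc2ts metadata introduced in this commit:

    # Old: one metadata decode per sample node.
    samples = [(ts.node(u).metadata["strain"],) for u in ts.samples()]
    # New: a single top-level metadata read.
    samples = [(strain,) for strain in ts.metadata["sc2ts"]["samples_strain"]]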
@@ -212,26 +205,35 @@ def mirror_ts_coordinates(ts):
return tables.tree_sequence()


def initial_ts():
def initial_ts(additional_problematic_sites=list()):
reference = core.get_reference_sequence()
L = core.REFERENCE_SEQUENCE_LENGTH
assert L == len(reference)
problematic_sites = set(core.get_problematic_sites())
problematic_sites = set(core.get_problematic_sites()) | set(additional_problematic_sites)

tables = tskit.TableCollection(L)
tables.time_units = core.TIME_UNITS

# TODO add known fields to the schemas and document them.

base_schema = tskit.MetadataSchema.permissive_json().schema
tables.reference_sequence.metadata_schema = tskit.MetadataSchema(base_schema)
tables.reference_sequence.metadata = {
"genbank_id": core.REFERENCE_GENBANK,
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates"
"notes": "X prepended to alignment to map from 1-based to 0-based coordinates",
}
tables.reference_sequence.data = reference

tables.metadata_schema = tskit.MetadataSchema(base_schema)

# TODO gene annotations to top level
# TODO add known fields to the schemas and document them.
tables.metadata = {
"sc2ts": {
"date": core.REFERENCE_DATE,
"samples_strain": [core.REFERENCE_STRAIN],
"additional_problematic_sites": additional_problematic_sites,
}
}

tables.nodes.metadata_schema = tskit.MetadataSchema(base_schema)
tables.sites.metadata_schema = tskit.MetadataSchema(base_schema)
tables.mutations.metadata_schema = tskit.MetadataSchema(base_schema)
@@ -245,7 +247,9 @@ def initial_ts():
# in later versions when we remove the dependence on tsinfer.
tables.nodes.add_row(time=1, metadata={"strain": "Vestigial_ignore"})
tables.nodes.add_row(
flags=tskit.NODE_IS_SAMPLE, time=0, metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE}
flags=tskit.NODE_IS_SAMPLE,
time=0,
metadata={"strain": core.REFERENCE_STRAIN, "date": core.REFERENCE_DATE},
)
tables.edges.add_row(0, L, 0, 1)
return tables.tree_sequence()
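initial_ts now records run state under the "sc2ts" key of the tree sequence's top-level metadata. A minimal sketch of reading it back, assuming the sc2ts package layout shown in this diff:

    from sc2ts import inference

    ts = inference.initial_ts(additional_problematic_sites=[7851, 10323])
    md = ts.metadata["sc2ts"]
    print(md["date"])                          # core.REFERENCE_DATE
    print(md["samples_strain"])                # [core.REFERENCE_STRAIN]
    print(md["additional_problematic_sites"])  # [7851, 10323]

This is the same metadata that last_date and check_base_ts below now rely on.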
@@ -255,42 +259,8 @@ def parse_date(date):
return datetime.datetime.fromisoformat(date)


def filter_samples(samples, alignment_store, max_submission_delay=None):
if max_submission_delay is None:
max_submission_delay = 10**8 # Arbitrary large number of days.
not_in_store = 0
num_filtered = 0
ret = []
for sample in samples:
if sample.strain not in alignment_store:
logger.warn(f"{sample.strain} not in alignment store")
not_in_store += 1
continue
if sample.submission_delay < max_submission_delay:
ret.append(sample)
else:
num_filtered += 1
if not_in_store == len(samples):
raise ValueError("All samples for day missing")
logger.info(
f"Filtered {num_filtered} samples with "
f"max_submission_delay >= {max_submission_delay}"
)
return ret


def last_date(ts):
if ts.num_samples == 0:
# Special case for the initial ts which contains the
# reference but not as a sample
u = ts.num_nodes - 1
node = ts.node(u)
# assert node.time == 0
return parse_date(node.metadata["date"])
else:
samples = ts.samples()
samples_t0 = samples[ts.nodes_time[samples] == 0]
return max([parse_date(ts.node(u).metadata["date"]) for u in samples_t0])
return parse_date(ts.metadata["sc2ts"]["date"])


def increment_time(date, ts):
@@ -343,7 +313,7 @@ def validate(ts, alignment_store, show_progress=False):
Check that all the samples in the specified tree sequence are correctly
representing the original alignments.
"""
samples = ts.samples()
samples = ts.samples()[1:]
chunk_size = 10**3
offset = 0
num_chunks = ts.num_samples // chunk_size
@@ -562,6 +532,14 @@ def match_samples(
match_db.add(samples, date, num_mismatches)


def check_base_ts(ts):
md = ts.metadata
assert "sc2ts" in md
sc2ts_md = md["sc2ts"]
assert "date" in sc2ts_md
assert len(sc2ts_md["samples_strain"]) == ts.num_samples


def extend(
*,
alignment_store,
@@ -579,9 +557,10 @@ def extend(
precision=None,
rng=None,
):
check_base_ts(base_ts)
logger.info(
f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
f"mutations={base_ts.num_mutations}"
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
f"mutations={base_ts.num_mutations};date={base_ts.metadata['sc2ts']['date']}"
)
# TODO not sure whether we'll keep these params. Making sure they're not
# used for now
@@ -640,7 +619,21 @@ def extend(
min_group_size=min_group_size,
show_progress=show_progress,
)
return ts
return update_top_level_metadata(ts, date)


def update_top_level_metadata(ts, date):
tables = ts.dump_tables()
md = tables.metadata
md["sc2ts"]["date"] = date
samples_strain = md["sc2ts"]["samples_strain"]
new_samples = ts.samples()[len(samples_strain) :]
for u in new_samples:
node = ts.node(u)
samples_strain.append(node.metadata["strain"])
md["sc2ts"]["samples_strain"] = samples_strain
tables.metadata = md
return tables.tree_sequence()


def match_path_ts(samples, ts, path, reversions):
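update_top_level_metadata assumes that extend appends new sample nodes after the existing ones, so samples_strain stays index-aligned with ts.samples(). A sketch of the resulting invariant; the per-node comparison is an extrapolation from how the list is built, not something check_base_ts itself verifies:

    md = ts.metadata["sc2ts"]
    assert len(md["samples_strain"]) == ts.num_samples
    for u, strain in zip(ts.samples(), md["samples_strain"]):
        assert ts.node(u).metadata["strain"] == strain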
