Merge pull request #217 from jeromekelleher/run-full-viridian
Run full viridian
jeromekelleher authored Aug 1, 2024
2 parents f4855f2 + 726f717 commit c24cd33
Showing 3 changed files with 113 additions and 67 deletions.
7 changes: 6 additions & 1 deletion sc2ts/cli.py
@@ -75,7 +75,12 @@ def setup_logging(verbosity, log_file=None):
# at the console output. For development this is better than having
# to go to the log to see the traceback, but for production it may
# be better to let daiquiri record the errors as well.
daiquiri.setup(level=log_level, outputs=outputs, set_excepthook=False)
daiquiri.setup(outputs=outputs, set_excepthook=False)
# Only show stuff coming from sc2ts. Sometimes it's handy to look
# at the tsinfer logs too, so we could add an option to set its
# levels
logger = logging.getLogger("sc2ts")
logger.setLevel(log_level)


# TODO add options to list keys, dump specific alignments etc
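In isolation, the pattern this hunk introduces is: let daiquiri install handlers and formatting globally, but raise verbosity only on the sc2ts logger so dependencies stay quiet. A minimal sketch of that pattern (the verbosity-to-level mapping and the "stderr" output name are assumptions, not taken from the diff):

import logging

import daiquiri

def setup_logging_sketch(verbosity):
    # Assumed mapping; the real setup_logging also handles log files.
    log_level = "WARNING"
    if verbosity == 1:
        log_level = "INFO"
    elif verbosity >= 2:
        log_level = "DEBUG"
    # Install handlers/formatting; the root logger keeps its default level.
    daiquiri.setup(outputs=["stderr"], set_excepthook=False)
    # Opt in to verbose output for sc2ts only, so chatty dependencies
    # (e.g. tsinfer) are not promoted to DEBUG along with it.
    logging.getLogger("sc2ts").setLevel(log_level)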
146 changes: 93 additions & 53 deletions sc2ts/inference.py
Expand Up @@ -64,6 +64,8 @@ def add(self, samples, date, num_mismatches):
for j, sample in enumerate(samples):
d = sample.asdict()
assert sample.date == date
# FIXME we want to be more selective about what we're storing
# here, as we're including the alignment too.
pkl = pickle.dumps(sample)
# BZ2 compressing drops this by ~10X, so worth it.
pkl_compressed = bz2.compress(pkl)
@@ -75,6 +77,10 @@ def add(self, samples, date, num_mismatches):
pkl_compressed,
)
data.append(args)
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
logger.debug(
f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
)
# Batch insert, for efficiency.
with self.conn:
self.conn.executemany(sql, data)
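The storage scheme being logged here is: pickle each Sample, bz2-compress the pickle (the ~10X saving noted above), and write a whole day's worth of rows in one executemany call. A self-contained sketch with a hypothetical, simplified table (the real schema has more columns):

import bz2
import pickle
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE samples (strain TEXT, hmm_cost REAL, pickle BLOB)")

records = [{"strain": "s1", "hmm_cost": 2.0}, {"strain": "s2", "hmm_cost": 0.0}]
data = []
for rec in records:
    pkl = pickle.dumps(rec)             # stand-in for the Sample object
    pkl_compressed = bz2.compress(pkl)  # ~10X smaller for real samples
    data.append((rec["strain"], rec["hmm_cost"], pkl_compressed))
# One batched round-trip instead of an INSERT per sample.
with conn:
    conn.executemany("INSERT INTO samples VALUES (?, ?, ?)", data)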
@@ -124,7 +130,11 @@ def get(self, where_clause):
for row in self.conn.execute(sql):
pkl = row.pop("pickle")
sample = pickle.loads(bz2.decompress(pkl))
logger.debug(f"MatchDb got: {row}")
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
logger.debug(
f"MatchDb got: {sample.strain} {sample.date} {pango} "
f"hmm_cost={row['hmm_cost']}"
)
# print(row)
yield sample
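Reading takes the inverse path, and the debug line uses a defensive .get() because not every sample carries a Viridian_pangolin annotation. Continuing the toy sketch above (the real MatchDb uses a dict row factory and Sample objects rather than plain tuples and dicts):

for strain, hmm_cost, blob in conn.execute("SELECT * FROM samples"):
    sample = pickle.loads(bz2.decompress(blob))
    # Mirrors the logging above: tolerate missing annotations.
    pango = sample.get("Viridian_pangolin", "Unknown")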

@@ -149,19 +159,20 @@ def initialise(db_path):
)
return MatchDb(db_path)


def print_all(self):
"""
Debug method to print out full state of the DB.
"""
import pandas as pd

data = []
with self.conn:
for row in self.conn.execute("SELECT * from samples"):
data.append(row)
df = pd.DataFrame(data)
print(df)


def mirror(x, L):
return L - x

@@ -253,7 +264,7 @@ def last_date(ts):
# reference but not as a sample
u = ts.num_nodes - 1
node = ts.node(u)
assert node.time == 0
# assert node.time == 0
return parse_date(node.metadata["date"])
else:
samples = ts.samples()
@@ -336,6 +347,10 @@ class Sample:
mutations: List = dataclasses.field(default_factory=list)
alignment_qc: Dict = dataclasses.field(default_factory=dict)
masked_sites: List = dataclasses.field(default_factory=list)
# FIXME need a better name for this, as it's a different thing from
# the original alignment. Haplotype is probably a good name, as it's
# what it would be in the tskit/tsinfer world.
alignment: List = None

# def __repr__(self):
# return self.strain
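As the FIXME notes, the new alignment field holds the encoded haplotype restricted to the sites of the base tree sequence, not the raw alignment string. A toy construction, assuming (as the calls elsewhere in this diff suggest) that the dataclass's leading fields are strain, date and metadata:

import numpy as np

from sc2ts.inference import Sample

s = Sample("hyp_strain_1", "2020-06-01", {"Viridian_pangolin": "B.1"})
# One int8 allele index per kept site; -1 marks masked/missing data.
s.alignment = np.array([0, 1, -1, 2], dtype=np.int8)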
@@ -352,18 +367,6 @@ def breakpoints(self):
def parents(self):
return [seg.parent for seg in self.path]

# @property
# def date(self):
# return parse_date(self.metadata["date"])

# @property
# def submission_date(self):
# return parse_date(self.metadata["date_submitted"])

# @property
# def submission_delay(self):
# return (self.submission_date - self.date).days

def get_hmm_cost(self, num_mismatches):
# Note that Recombinant objects have total_cost.
# This bit of code is sort of repeated.
@@ -424,70 +427,84 @@ def daily_extend(
last_ts = ts


def preprocess_and_match_alignments(
def preprocess(
date,
*,
base_ts,
metadata_db,
alignment_store,
match_db,
base_ts,
num_mismatches=None,
show_progress=False,
num_threads=None,
precision=None,
max_daily_samples=None,
mirror_coordinates=False,
show_progress=False,
):
if num_mismatches is None:
# Default to no recombination
num_mismatches = 1000

samples = []
for md in metadata_db.get(date):
samples.append(Sample(md["strain"], md["date"], md))
if len(samples) == 0:
logger.warn(f"Zero samples for {date}")
return
metadata_matches = list(metadata_db.get(date))

if len(metadata_matches) == 0:
logger.warn(f"Zero metadata matches for {date}")
return []

if date.endswith("01-01"):
logger.warning(f"Skipping {len(metadata_matches)} samples for {date}")
return []

# TODO implement this.
assert max_daily_samples is None

# Note: there's not a lot of point in making the G matrix here,
# we should just pass on the encoded alignments to the matching
# algorithm directly through the Sample class, and let it
# do the low-level haplotype storage.
G = np.zeros((base_ts.num_sites, len(samples)), dtype=np.int8)
keep_sites = base_ts.sites_position.astype(int)
problematic_sites = core.get_problematic_sites()
samples = []

samples_iter = enumerate(samples)
with tqdm.tqdm(
samples_iter,
desc=f"Fetch:{date}",
total=len(samples),
metadata_matches,
desc=f"Preprocess:{date}",
disable=not show_progress,
) as bar:
for j, sample in bar:
logger.debug(f"Getting alignment for {sample.strain}")
alignment = alignment_store[sample.strain]
sample.alignment = alignment
logger.debug("Encoding alignment")
for md in bar:
strain = md["strain"]
logger.debug(f"Getting alignment for {strain}")
try:
alignment = alignment_store[strain]
except KeyError:
logger.debug(f"No alignment stored for {strain}")
continue

sample = Sample(strain, date, metadata=md)
ma = alignments.encode_and_mask(alignment)
# Always mask the problematic_sites as well. We need to do this
# for follow-up matching to inspect recombinants, as tsinfer
# needs us to keep all sites in the table when doing mirrored
# coordinates.
ma.alignment[problematic_sites] = -1
G[:, j] = ma.alignment[keep_sites]
sample.alignment_qc = ma.qc_summary()
sample.masked_sites = ma.masked_sites
sample.alignment = ma.alignment[keep_sites]
samples.append(sample)

masked_per_sample = np.mean([len(s.masked_sites) for s in samples])
logger.info(f"Masked average of {masked_per_sample:.2f} nucleotides per sample")
logger.info(
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
)
return samples
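The masking in this loop composes two filters, in a deliberate order: first blank the known problematic sites (so mirrored-coordinate matching can keep every site in the tables), then restrict to the positions the base tree sequence actually tracks. A toy NumPy illustration of that order of operations:

import numpy as np

encoded = np.array([0, 1, 2, 3, 0, 1], dtype=np.int8)  # toy encoded alignment
problematic = np.array([2, 4])  # stand-in for core.get_problematic_sites()
encoded[problematic] = -1       # mask first; -1 == missing
keep_sites = np.array([0, 1, 3, 5])  # stand-in for base_ts.sites_position
haplotype = encoded[keep_sites]      # -> array([0, 1, 3, 1], dtype=int8)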


def match_samples(
date,
samples,
*,
match_db,
base_ts,
num_mismatches=None,
show_progress=False,
num_threads=None,
precision=None,
mirror_coordinates=False,
):
if num_mismatches is None:
# Default to no recombination
num_mismatches = 1000

match_tsinfer(
samples=samples,
ts=base_ts,
genotypes=G,
num_mismatches=num_mismatches,
precision=precision,
num_threads=num_threads,
@@ -515,21 +532,36 @@ def extend(
precision=None,
rng=None,
):
logger.info(
f"Extend {date}; ts:nodes={base_ts.num_nodes};edges={base_ts.num_edges};"
f"mutations={base_ts.num_mutations}"
)
# TODO not sure whether we'll keep these params. Making sure they're not
# used for now
assert max_submission_delay is None

preprocess_and_match_alignments(
samples = preprocess(
date,
metadata_db=metadata_db,
alignment_store=alignment_store,
base_ts=base_ts,
max_daily_samples=max_daily_samples,
show_progress=show_progress,
)

if len(samples) == 0:
logger.warning(f"Nothing to do for {date}")
return base_ts

match_samples(
date,
samples,
base_ts=base_ts,
match_db=match_db,
num_mismatches=num_mismatches,
show_progress=show_progress,
num_threads=num_threads,
precision=precision,
max_daily_samples=max_daily_samples,
)

match_db.create_mask_table(base_ts)
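With the split in place, extend reads as a three-stage pipeline: preprocess, match, record. A condensed sketch of the new control flow (argument plumbing and the later stages elided):

from sc2ts.inference import preprocess, match_samples

def extend_sketch(date, *, metadata_db, alignment_store, base_ts, match_db):
    # 1. Fetch the day's metadata and alignments; encode, QC and mask.
    samples = preprocess(date, metadata_db=metadata_db,
                         alignment_store=alignment_store, base_ts=base_ts)
    if len(samples) == 0:
        return base_ts  # nothing stored for this day
    # 2. Run HMM matching; paths and mutations land on each Sample.
    match_samples(date, samples, base_ts=base_ts, match_db=match_db)
    # 3. Record matches and build the day's updated tree sequence.
    match_db.create_mask_table(base_ts)
    ...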
@@ -574,6 +606,10 @@ def match_path_ts(samples, ts, path, reversions):
path = samples[0].path
site_id_map = {}
first_sample = len(tables.nodes)
logger.debug(
f"Adding group of {len(samples)} with path={path} and "
f"reversions={reversions}"
)
for sample in samples:
assert sample.path == path
metadata = {
@@ -596,6 +632,10 @@ def match_path_ts(samples, ts, path, reversions):
# Now add the mutations
for node_id, sample in enumerate(samples, first_sample):
# metadata = {**sample.metadata, "sc2ts_qc": sample.alignment_qc}
logger.debug(
f"Adding {sample.strain}:{sample.date} with "
f"{len(sample.mutations)} mutations"
)
for mut in sample.mutations:
tables.mutations.add_row(
site=site_id_map[mut.site_id],
@@ -1210,14 +1250,14 @@ def resize_copy(array, new_size):
def match_tsinfer(
samples,
ts,
genotypes,
*,
num_mismatches,
precision=None,
num_threads=0,
show_progress=False,
mirror_coordinates=False,
):
genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
input_ts = ts
if mirror_coordinates:
ts = mirror_ts_coordinates(ts)
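The genotype matrix callers previously had to pass in is now derived inside match_tsinfer by stacking each sample's stored haplotype; tsinfer-style matrices are sites by samples, hence the transpose. A toy check of the shape convention:

import numpy as np

h1 = np.array([0, 1, 2, 0], dtype=np.int8)
h2 = np.array([0, 1, 3, 0], dtype=np.int8)
h3 = np.array([0, 0, 2, -1], dtype=np.int8)  # -1 == missing/masked
G = np.array([h1, h2, h3], dtype=np.int8).T  # stack rows, then transpose
assert G.shape == (4, 3)                     # (num_sites, num_samples)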
27 changes: 14 additions & 13 deletions tests/test_inference.py
@@ -175,11 +175,9 @@ def test_two_samples_one_mutation_one_filtered(self, tmp_path):


class TestMatchTsinfer:
def match_tsinfer(self, samples, ts, haplotypes, **kwargs):
assert len(samples) == len(haplotypes)
G = np.array(haplotypes).T
def match_tsinfer(self, samples, ts, **kwargs):
sc2ts.inference.match_tsinfer(
samples=samples, ts=ts, genotypes=G, num_mismatches=1000, **kwargs
samples=samples, ts=ts, num_mismatches=1000, **kwargs
)

@pytest.mark.parametrize("mirror", [False, True])
@@ -189,10 +187,11 @@ def test_match_reference(self, mirror):
tables.sites.truncate(20)
ts = tables.tree_sequence()
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
samples[0].alignment = sc2ts.core.get_reference_sequence()
ma = sc2ts.alignments.encode_and_mask(samples[0].alignment)
alignment = sc2ts.core.get_reference_sequence()
ma = sc2ts.alignments.encode_and_mask(alignment)
h = ma.alignment[ts.sites_position.astype(int)]
self.match_tsinfer(samples, ts, [h], mirror_coordinates=mirror)
samples[0].alignment = h
self.match_tsinfer(samples, ts, mirror_coordinates=mirror)
assert samples[0].breakpoints == [0, ts.sequence_length]
assert samples[0].parents == [ts.num_nodes - 1]
assert len(samples[0].mutations) == 0
@@ -205,12 +204,13 @@ def test_match_reference_one_mutation(self, mirror, site_id):
tables.sites.truncate(20)
ts = tables.tree_sequence()
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
samples[0].alignment = sc2ts.core.get_reference_sequence()
ma = sc2ts.alignments.encode_and_mask(samples[0].alignment)
alignment = sc2ts.core.get_reference_sequence()
ma = sc2ts.alignments.encode_and_mask(alignment)
h = ma.alignment[ts.sites_position.astype(int)]
# Mutate to gap
h[site_id] = sc2ts.core.ALLELES.index("-")
self.match_tsinfer(samples, ts, [h], mirror_coordinates=mirror)
samples[0].alignment = h
self.match_tsinfer(samples, ts, mirror_coordinates=mirror)
assert samples[0].breakpoints == [0, ts.sequence_length]
assert samples[0].parents == [ts.num_nodes - 1]
assert len(samples[0].mutations) == 1
@@ -230,11 +230,12 @@ def test_match_reference_all_same(self, mirror, allele):
tables.sites.truncate(20)
ts = tables.tree_sequence()
samples = util.get_samples(ts, [[(0, ts.sequence_length, 1)]])
samples[0].alignment = sc2ts.core.get_reference_sequence()
ma = sc2ts.alignments.encode_and_mask(samples[0].alignment)
alignment = sc2ts.core.get_reference_sequence()
ma = sc2ts.alignments.encode_and_mask(alignment)
ref = ma.alignment[ts.sites_position.astype(int)]
h = np.zeros_like(ref) + allele
self.match_tsinfer(samples, ts, [h], mirror_coordinates=mirror)
samples[0].alignment = h
self.match_tsinfer(samples, ts, mirror_coordinates=mirror)
assert samples[0].breakpoints == [0, ts.sequence_length]
assert samples[0].parents == [ts.num_nodes - 1]
muts = samples[0].mutations
