Skip to content

Commit

Permalink
Merge pull request #256 from jeromekelleher/dynamic-precision
Browse files Browse the repository at this point in the history
Dynamic precision
  • Loading branch information
jeromekelleher authored Sep 3, 2024
2 parents e996bfd + dffa62c commit 6956365
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 191 deletions.
6 changes: 3 additions & 3 deletions run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ num_threads=8

# Paths
datadir=testrun
run_id=tmp-dev
run_id=tmp-dev-hp
# run_id=upgma-mds-$max_daily_samples-md-$max_submission_delay-mm-$mismatches
resultsdir=results/$run_id
results_prefix=$resultsdir/$run_id-
logfile=logs/$run_id.log

alignments=$datadir/alignments.db
metadata=$datadir/metadata.db
matches=$resultsdir/matces.db
matches=$resultsdir/matches.db

dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31`
dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31 | head -n 14`
echo $dates

options="--num-threads $num_threads -vv -l $logfile "
Expand Down
213 changes: 88 additions & 125 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def add(self, samples, date, num_mismatches):
pkl_compressed,
)
data.append(args)
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
logger.debug(
f"MatchDB insert: {sample.strain} {date} {pango} hmm_cost={hmm_cost[j]}"
)
logger.debug(f"MatchDB insert: hmm_cost={hmm_cost[j]} {sample.summary()}")
# Batch insert, for efficiency.
with self.conn:
self.conn.executemany(sql, data)
Expand Down Expand Up @@ -150,11 +147,7 @@ def get(self, where_clause):
for row in self.conn.execute(sql):
pkl = row.pop("pickle")
sample = pickle.loads(bz2.decompress(pkl))
pango = sample.metadata.get("Viridian_pangolin", "Unknown")
logger.debug(
f"MatchDb got: {sample.strain} {sample.date} {pango} "
f"hmm_cost={row['hmm_cost']}"
)
logger.debug(f"MatchDb got: {sample.summary()} hmm_cost={row['hmm_cost']}")
# print(row)
yield sample

Expand Down Expand Up @@ -364,6 +357,18 @@ class Sample:
# def __str__(self):
# return f"{self.strain}: {self.path} + {self.mutations}"

def path_summary(self):
return ",".join(f"({seg.left}:{seg.right}, {seg.parent})" for seg in self.path)

def mutation_summary(self):
return "[" + ",".join(str(mutation) for mutation in self.mutations) + "]"

def summary(self):
pango = self.metadata.get("Viridian_pangolin", "Unknown")
return (f"{self.strain} {self.date} {pango} path={self.path_summary()} "
f"mutations({len(self.mutations)})={self.mutation_summary()}"
)

@property
def breakpoints(self):
breakpoints = [seg.left for seg in self.path]
Expand All @@ -388,104 +393,62 @@ def asdict(self):
}


# def daily_extend(
# *,
# alignment_store,
# metadata_db,
# base_ts,
# match_db,
# num_mismatches=None,
# max_hmm_cost=None,
# min_group_size=None,
# num_past_days=None,
# show_progress=False,
# max_submission_delay=None,
# max_daily_samples=None,
# num_threads=None,
# precision=None,
# rng=None,
# excluded_sample_dir=None,
# ):
# assert num_past_days is None
# assert max_submission_delay is None

# start_day = last_date(base_ts)

# last_ts = base_ts
# for date in metadata_db.get_days(start_day):
# ts = extend(
# alignment_store=alignment_store,
# metadata_db=metadata_db,
# date=date,
# base_ts=last_ts,
# match_db=match_db,
# num_mismatches=num_mismatches,
# max_hmm_cost=max_hmm_cost,
# min_group_size=min_group_size,
# show_progress=show_progress,
# max_submission_delay=max_submission_delay,
# max_daily_samples=max_daily_samples,
# num_threads=num_threads,
# precision=precision,
# )
# yield ts, date

# last_ts = ts


def match_samples(
date,
samples,
*,
match_db,
base_ts,
num_mismatches=None,
show_progress=False,
num_threads=None,
precision=None,
mirror_coordinates=False,
):
if num_mismatches is None:
# Default to no recombination
num_mismatches = 1000

match_tsinfer(
samples=samples,
ts=base_ts,
num_mismatches=num_mismatches,
precision=2,
num_threads=num_threads,
show_progress=show_progress,
mirror_coordinates=mirror_coordinates,
)
samples_to_rerun = []
for sample in samples:
hmm_cost = sample.get_hmm_cost(num_mismatches)
logger.debug(
f"First sketch: {sample.strain} hmm_cost={hmm_cost} path={sample.path}"
)
if hmm_cost >= 2:
sample.path.clear()
sample.mutations.clear()
samples_to_rerun.append(sample)

if len(samples_to_rerun) > 0:
# First pass, compute the matches at precision=0.
run_batch = samples

# Values based on https://github.com/jeromekelleher/sc2ts/issues/242,
# but somewhat arbitrary.
for precision, cost_threshold in [(0, 1), (1, 2), (2, 3)]:
logger.info(f"Running batch of {len(run_batch)} at p={precision}")
match_tsinfer(
samples=samples_to_rerun,
samples=run_batch,
ts=base_ts,
num_mismatches=num_mismatches,
precision=precision,
num_threads=num_threads,
show_progress=show_progress,
mirror_coordinates=mirror_coordinates,
)
for sample in samples_to_rerun:
hmm_cost = sample.get_hmm_cost(num_mismatches)
logger.debug(
f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}"
)

match_db.add(samples, date, num_mismatches)
exceeding_threshold = []
for sample in run_batch:
cost = sample.get_hmm_cost(num_mismatches)
logger.debug(f"HMM@p={precision}: hmm_cost={cost} {sample.summary()}")
if cost > cost_threshold:
sample.path.clear()
sample.mutations.clear()
exceeding_threshold.append(sample)

num_matches_found = len(run_batch) - len(exceeding_threshold)
logger.info(
f"{num_matches_found} final matches for found p={precision}; "
f"{len(exceeding_threshold)} remain"
)
run_batch = exceeding_threshold

precision = 6
logger.info(f"Running final batch of {len(run_batch)} at p={precision}")
match_tsinfer(
samples=run_batch,
ts=base_ts,
num_mismatches=num_mismatches,
precision=precision,
num_threads=num_threads,
show_progress=show_progress,
)
for sample in run_batch:
cost = sample.get_hmm_cost(num_mismatches)
# print(f"Final HMM pass:{sample.strain} hmm_cost={hmm_cost} path={sample.path}")
logger.debug(f"Final HMM pass hmm_cost={cost} {sample.summary()}")
return samples


def check_base_ts(ts):
Expand Down Expand Up @@ -561,7 +524,6 @@ def extend(
min_group_size = 10

# TMP
precision = 6
check_base_ts(base_ts)
logger.info(
f"Extend {date}; ts:nodes={base_ts.num_nodes};samples={base_ts.num_samples};"
Expand All @@ -584,17 +546,16 @@ def extend(
f"Got alignments for {len(samples)} of {len(metadata_matches)} in metadata"
)

match_samples(
samples = match_samples(
date,
samples,
base_ts=base_ts,
match_db=match_db,
num_mismatches=num_mismatches,
show_progress=show_progress,
num_threads=num_threads,
precision=precision,
)

match_db.add(samples, date, num_mismatches)
match_db.create_mask_table(base_ts)
ts = increment_time(date, base_ts)

Expand Down Expand Up @@ -641,6 +602,18 @@ def update_top_level_metadata(ts, date):
return tables.tree_sequence()


def add_sample_to_tables(sample, tables, flags=tskit.NODE_IS_SAMPLE, time=0):
metadata = {
**sample.metadata,
"sc2ts": {
"qc": sample.alignment_qc,
"path": [x.asdict() for x in sample.path],
"mutations": [x.asdict() for x in sample.mutations],
},
}
return tables.nodes.add_row(flags=flags, time=time, metadata=metadata)


def match_path_ts(samples, ts, path, reversions):
"""
Given the specified list of samples with equal copying paths,
Expand All @@ -659,17 +632,7 @@ def match_path_ts(samples, ts, path, reversions):
)
for sample in samples:
assert sample.path == path
metadata = {
**sample.metadata,
"sc2ts": {
"qc": sample.alignment_qc,
"path": [x.asdict() for x in sample.path],
"mutations": [x.asdict() for x in sample.mutations],
},
}
node_id = tables.nodes.add_row(
flags=tskit.NODE_IS_SAMPLE, time=0, metadata=metadata
)
node_id = add_sample_to_tables(sample, tables)
tables.edges.add_row(0, ts.sequence_length, parent=0, child=node_id)
for mut in sample.mutations:
if mut.site_id not in site_id_map:
Expand Down Expand Up @@ -707,10 +670,10 @@ def add_exact_matches(match_db, ts, date):
for sample in samples:
assert len(sample.path) == 1
assert len(sample.mutations) == 0
node_id = tables.nodes.add_row(
node_id = add_sample_to_tables(
sample,
tables,
flags=tskit.NODE_IS_SAMPLE | core.NODE_IS_EXACT_MATCH,
time=0,
metadata=sample.metadata,
)
parent = sample.path[0].parent
logger.debug(f"ARG add exact match {sample.strain}:{node_id}->{parent}")
Expand Down Expand Up @@ -843,23 +806,21 @@ def solve_num_mismatches(ts, k):
NOTE! This is NOT taking into account the spatial distance along
the genome, and so is not a very good model in some ways.
"""
# We can match against any node in tsinfer
m = ts.num_sites
n = ts.num_nodes # We can match against any node in tsinfer
if k == 0:
# Pathological things happen when k=0
r = 1e-3
mu = 1e-20
else:
# NOTE: the magnitude of mu matters because it puts a limit
# on how low we can push the HMM precision. We should be able to solve
# for the optimal value of this parameter such that the magnitude of the
# values within the HMM are as large as possible (so that we can truncate
# usefully).
mu = 1e-3
denom = (1 - mu) ** k + (n - 1) * mu**k
r = n * mu**k / denom
assert mu < 0.5
assert r < 0.5
n = ts.num_nodes
# values of k <= 1 are not relevant for SC2 and lead to awkward corner cases
assert k > 1

# NOTE: the magnitude of mu matters because it puts a limit
# on how low we can push the HMM precision. We should be able to solve
# for the optimal value of this parameter such that the magnitude of the
# values within the HMM are as large as possible (so that we can truncate
# usefully).
# mu = 1e-2
mu = 0.125
denom = (1 - mu) ** k + (n - 1) * mu**k
r = n * mu**k / denom

# Add a little bit of extra mass for recombination so that we deterministically
# chose to recombine over k mutations
Expand Down Expand Up @@ -1312,6 +1273,8 @@ def match_tsinfer(
show_progress=False,
mirror_coordinates=False,
):
if len(samples) == 0:
return
genotypes = np.array([sample.alignment for sample in samples], dtype=np.int8).T
input_ts = ts
if mirror_coordinates:
Expand Down Expand Up @@ -1478,7 +1441,7 @@ def get_closest_mutation(node, site_id):
sample.mutations.append(
MatchMutation(
site_id=site_id,
site_position=site_pos,
site_position=int(site_pos),
derived_state=derived_state,
inherited_state=inherited_state,
is_reversion=is_reversion,
Expand Down
Loading

0 comments on commit 6956365

Please sign in to comment.