Skip to content

Commit

Permalink
Refactor and improve progress in CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Sep 13, 2024
1 parent d98c162 commit 3c8cb4a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
1 change: 0 additions & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ metadata=$datadir/metadata.db
matches=$resultsdir/matches.db

dates=`python3 -m sc2ts list-dates $metadata | grep -v 2021-12-31 | head -n 14`
echo $dates

options="--num-threads $num_threads -vv -l $logfile "
# options+="--max-submission-delay $max_submission_delay "
Expand Down
57 changes: 31 additions & 26 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,35 @@


def get_progress(iterable, date, phase, show_progress, total=None):
    """
    Return a tqdm progress bar over ``iterable`` labelled ``"{date}:{phase}"``.

    :param iterable: The items to iterate over, or None when the caller
        updates the bar manually (in which case ``total`` should be given).
    :param date: Value shown before the colon in the bar description.
    :param phase: Value shown after the colon in the bar description.
    :param bool show_progress: If False the bar is created disabled.
    :param total: Expected number of updates, passed through to tqdm.
    :return: The ``tqdm.tqdm`` instance.
    """
    # The description is left-aligned and padded to 22 characters so the
    # percentage column sits at the same offset whatever the phase name is.
    bar_format = (
        "{desc:<22}{percentage:3.0f}%|{bar}"
        "| {n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}{postfix}]"
    )
    return tqdm.tqdm(
        iterable,
        total=total,
        desc=f"{date}:{phase}",
        disable=not show_progress,
        bar_format=bar_format,
        dynamic_ncols=True,
        smoothing=0.01,
        unit_scale=True,
    )


class TsinferProgressMonitor(tsinfer.progress.ProgressMonitor):
    """
    Progress monitor for tsinfer that builds its bars with ``get_progress``,
    so they carry the same ``"{date}:{phase}"`` label and format as the other
    progress bars in this module.
    """

    def __init__(self, date, phase, *args, **kwargs):
        """
        :param date: Value shown before the colon in the bar description.
        :param phase: Value shown after the colon in the bar description.

        Remaining positional and keyword arguments (e.g. ``enabled``) are
        forwarded unchanged to ``tsinfer.progress.ProgressMonitor``.
        """
        self.date = date
        self.phase = phase
        super().__init__(*args, **kwargs)

    def get(self, key, total):
        """
        Create, remember and return the progress bar for one stage.

        ``key`` is accepted for interface compatibility but is not used: the
        label comes from the ``date`` and ``phase`` given at construction.
        """
        # NOTE(review): self.enabled is presumably set by the base class
        # __init__ from the ``enabled`` keyword -- confirm against tsinfer.
        self.current_instance = get_progress(
            None, self.date, phase=self.phase, show_progress=self.enabled, total=total
        )
        return self.current_instance


class MatchDb:
def __init__(self, path):
uri = f"file:{path}"
Expand Down Expand Up @@ -413,6 +413,8 @@ def match_samples(
likelihood_threshold=likelihood_threshold,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match({k})",
)

exceeding_threshold = []
Expand All @@ -437,6 +439,8 @@ def match_samples(
rho=rho,
num_threads=num_threads,
show_progress=show_progress,
date=date,
phase=f"match(F)",
)
recombinants = []
for sample in run_batch:
Expand Down Expand Up @@ -611,6 +615,7 @@ def extend(
date=date,
min_group_size=1,
show_progress=show_progress,
phase="add(close)",
)

logger.info("Looking for retrospective matches")
Expand All @@ -624,6 +629,7 @@ def extend(
min_group_size=min_group_size,
min_different_dates=3, # TODO parametrize
show_progress=show_progress,
phase="add(retro)",
)
return update_top_level_metadata(ts, date)

Expand Down Expand Up @@ -757,6 +763,7 @@ def add_matching_results(
min_group_size=1,
min_different_dates=1,
show_progress=False,
phase=None,
):
logger.info(f"Querying match DB WHERE: {where_clause}")
samples = match_db.get(where_clause)
Expand Down Expand Up @@ -786,8 +793,7 @@ def add_matching_results(

attach_nodes = []
added_samples = []

with get_progress(list(grouped_matches.items()), date, "build", show_progress) as bar:
with get_progress(list(grouped_matches.items()), date, phase, show_progress) as bar:
for (path, reversions), match_samples in bar:
different_dates = set(sample.date for sample in match_samples)
# TODO (1) add group ID from hash of samples (2) better logging of path
Expand Down Expand Up @@ -1250,6 +1256,8 @@ def match_tsinfer(
likelihood_threshold=None,
num_threads=0,
show_progress=False,
date=None,
phase=None,
mirror_coordinates=False,
):
if len(samples) == 0:
Expand All @@ -1269,11 +1277,8 @@ def match_tsinfer(
# Let's say a double break with 5 mutations is the most unlikely thing
# we're interested in solving for exactly.
likelihood_threshold = rho**2 * mu**5
# pm = tsinfer.inference._get_progress_monitor(
pm = TsinferProgressMonitor(
f"Match ({likelihood_threshold:.2g})",
show_progress,
)

pm = TsinferProgressMonitor(date, phase, enabled=show_progress)

# This is just working around tsinfer's input checking logic. The actual value
# we're incrementing by has no effect.
Expand Down

0 comments on commit 3c8cb4a

Please sign in to comment.