diff --git a/sc2ts/utils.py b/sc2ts/utils.py index 6328ce9..5b58344 100644 --- a/sc2ts/utils.py +++ b/sc2ts/utils.py @@ -26,6 +26,16 @@ from . import core from . import lineages +def pairwise(iterable): + """ + s -> (s0,s1), (s1,s2), (s2, s3), ... + We can replace this with itertools.pairwise once using min Python 3.10 + """ + + a, b = itertools.tee(iterable) + next(b, None) + return zip(a, b) + @numba.njit def _get_root_path(parent, node): @@ -247,6 +257,15 @@ def is_parent_lineage_consistent(self): == arg.parent_imputed_lineages ) + def fwd_bkwd_hmm_parents(self): + """ + Returns successive tuples of (forward parent 1, backward parent 1), etc + """ + for parents in itertools.zip_longest( + self.hmm_runs["forward"].parents, self.hmm_runs["backward"].parents + ): + yield np.array(parents) + def asdict(self): return dataclasses.asdict(self) @@ -806,13 +825,30 @@ def export_recombinant_breakpoints(self): break, and their MRCA. Recombinants with multiple breaks are represented by multiple rows. - You can only list the breakpoints in recombination nodes that have 2 parents by + You can list the breakpoints in recombination nodes that have only 2 parents by doing e.g. df.drop_duplicates('node', keep=False) """ recombs = self.combine_recombinant_info() data = [] + parents = set() + # much quicker to get all the haplotype calculations in a single sweep: + for rec in recombs: + if rec.is_path_length_consistent(): + parents.update(u for fwd_bkwd in rec.fwd_bkwd_hmm_parents() for u in fwd_bkwd) + parents = {k: index for index, k in enumerate(parents)} # Convert to a mapping + H = self.ts.genotype_matrix(samples=list(parents.keys()), isolated_as_missing=False).T + for rec in recombs: arg = rec.arg_info + fwd_bck_parents_max_dist = np.full(len(arg.parents) -1, np.nan) + if rec.is_path_length_consistent(): + # Calculate the sequence difference between fwd and bkwd parents + seq_diff = [ + np.sum(H[parents[fwd], :] != H[parents[bwd], :]) + for fwd, bwd in rec.fwd_bkwd_hmm_parents() + ] + fwd_bck_parents_max_dist = np.array( + [max(parent_pair) for parent_pair in pairwise(seq_diff)], dtype=int) for j in range(len(arg.parents) - 1): row = rec.data_summary() mrca = arg.mrcas[j] @@ -825,6 +861,7 @@ def export_recombinant_breakpoints(self): row["right_parent_imputed_lineage"] = arg.parent_imputed_lineages[j + 1] row["mrca"] = mrca row["mrca_date"] = self.nodes_date[mrca] + row["fwd_bck_parents_max_dist"] = fwd_bck_parents_max_dist[j] data.append(row) return pd.DataFrame(sorted(data, key=operator.itemgetter("node")))