diff --git a/hstrat/_auxiliary_lib/_alifestd_prune_extinct_lineages_asexual.py b/hstrat/_auxiliary_lib/_alifestd_prune_extinct_lineages_asexual.py index bd8f70ad..d666a00e 100644 --- a/hstrat/_auxiliary_lib/_alifestd_prune_extinct_lineages_asexual.py +++ b/hstrat/_auxiliary_lib/_alifestd_prune_extinct_lineages_asexual.py @@ -4,6 +4,7 @@ import pandas as pd from ._alifestd_has_contiguous_ids import alifestd_has_contiguous_ids +from ._alifestd_is_topologically_sorted import alifestd_is_topologically_sorted from ._alifestd_try_add_ancestor_id_col import alifestd_try_add_ancestor_id_col from ._alifestd_unfurl_lineage_asexual import alifestd_unfurl_lineage_asexual from ._jit import jit @@ -54,6 +55,20 @@ def _create_has_extant_descendant_contiguous( return has_extant_descendant +@jit(nopython=True) +def _create_has_extant_descendant_contiguous_sorted( + ancestor_ids: np.ndarray, + extant_mask: np.ndarray, +) -> np.ndarray: + """Implementation detail for alifestd_prune_extinct_lineages_asexual.""" + + has_extant_descendant = extant_mask.copy() + for id_ in range(len(ancestor_ids) - 1, -1, -1): + has_extant_descendant[ancestor_ids[id_]] |= has_extant_descendant[id_] + + return has_extant_descendant + + def alifestd_prune_extinct_lineages_asexual( phylogeny_df: pd.DataFrame, mutate: bool = False, @@ -92,7 +107,10 @@ def alifestd_prune_extinct_lineages_asexual( phylogeny_df = phylogeny_df.copy() phylogeny_df = alifestd_try_add_ancestor_id_col(phylogeny_df, mutate=True) - phylogeny_df.set_index("id", drop=False, inplace=True) + if alifestd_has_contiguous_ids(phylogeny_df): + phylogeny_df.reset_index(drop=True, inplace=True) + else: + phylogeny_df.index = phylogeny_df["id"] extant_mask = None if "extant" in phylogeny_df: @@ -105,15 +123,22 @@ def alifestd_prune_extinct_lineages_asexual( else: raise ValueError('Need "extant" or "destruction_time" column.') - if alifestd_has_contiguous_ids(phylogeny_df): + if not alifestd_has_contiguous_ids(phylogeny_df): + has_extant_descendant = _create_has_extant_descendant_noncontiguous( + phylogeny_df, + extant_mask, + ) + elif not alifestd_is_topologically_sorted(phylogeny_df): has_extant_descendant = _create_has_extant_descendant_contiguous( phylogeny_df["ancestor_id"].to_numpy(dtype=np.uint64), extant_mask.to_numpy(dtype=bool), ) else: - has_extant_descendant = _create_has_extant_descendant_noncontiguous( - phylogeny_df, - extant_mask, + has_extant_descendant = ( + _create_has_extant_descendant_contiguous_sorted( + phylogeny_df["ancestor_id"].to_numpy(dtype=np.uint64), + extant_mask.to_numpy(dtype=bool), + ) ) phylogeny_df = phylogeny_df[has_extant_descendant].reset_index(drop=True)