diff --git a/src/phylojunction/functionality/event_series.py b/src/phylojunction/functionality/event_series.py index c27c715..3781cf5 100644 --- a/src/phylojunction/functionality/event_series.py +++ b/src/phylojunction/functionality/event_series.py @@ -139,7 +139,10 @@ def populate_param_value_dict(self, param_value_dict[int(it)] = param_val_mat - # adding this time slice's dict to list + # NOTE: adding this time slice's dict to list + # append will place larger time_slice_idx at the end + # which can have different meanings depending on what + # user does! Canonically, this should mean young -> old self._time_slice_dict_list.append(param_value_dict) @property @@ -165,7 +168,8 @@ def sample_from_region_idx(self, Args: it_idx (int): Index of iteration we are looking at. This is parsed from the stochastic map table file. - time_slice_idx (int): Index of time slice. + time_slice_idx (int): Index of time slice, with 0 being + the present epoch.. potential_from_region_idx: List of indices for regions from which dispersal (range expansion) may have happened. @@ -186,7 +190,6 @@ def sample_from_region_idx(self, tot = sum(weight_list) weight_list = [i / tot for i in weight_list] - sampled_region_idx = \ choice(potential_from_region_idx, 1, @@ -477,7 +480,7 @@ def disambiguate_range_expansion(self, type. This disambiguation is necessary in cases where the source - range includes multiple atomic regions, e.g,. ABC -> ABCD. + range includes multiple atomic regions, e.g., ABC -> ABCD. We do not know if the dispersal to D came from A, B or C. Args: @@ -499,7 +502,7 @@ def disambiguate_range_expansion(self, potential_from_region_idx = \ [idx for idx, b in enumerate(bp) if b == '1'] - # from 0 to (number of time slices - 1) + # from 0 to (number of time slices - 1), with 0 being the present time_slice_idx = self._geofeat_query.find_epoch_idx(smap.age) # if the class member exists, nothing needs to be done, @@ -507,6 +510,10 @@ def disambiguate_range_expansion(self, if smap.from_region_idx == None: # sample proportional to some scheme (e.g., proportional # to FIG rate scalers, m_d) + # + # note that time_slice_idx of 0 means present! + # so inside sample_from_region_idx, we should take that + # into account! if self._from_region_sampler != None: sampled_idx = \ self._from_region_sampler.sample_from_region_idx( @@ -532,27 +539,62 @@ def disambiguate_range_expansion(self, smap.from_region_idx = \ random.choice(potential_from_region_idx) + def is_range_expansion_split_relevant(self, + from_region_idx: int, + to_region_idx: int, + range1_idxs: ty.Set[int], + range2_idxs: ty.Set[int]) -> bool: + """Determine if range expansion split-relevant. + + If the regions involved in the dispersal are + both on the same side of a range split, they are not split- + relevant. + + Args: + from_region_idx (int): + to_region_idx (int): Index of region receiving migrants. + range1_idxs (set): Set of region indices on one side of + range split. + range2_idxs (set): Set of region indices on other side of + range split. + + Returns: + (bool): Boolean for whether range expasion is + split-relevant. 
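# --- Editor's sketch (illustrative only, not part of this patch) ----------
# The weighted draw described in sample_from_region_idx can be summarized
# as: normalize the candidate regions' weights (e.g., m_d rate scalers for
# the focal iteration and time slice, with slice 0 being the present) and
# draw a single source-region index. All names below are hypothetical.
import numpy as np

def sample_source_region_sketch(potential_from_region_idx, weight_list, seed=None):
    rng = np.random.default_rng(seed)
    total = sum(weight_list)
    probs = [w / total for w in weight_list]
    # one draw among candidate source regions, proportional to their weights
    return int(rng.choice(potential_from_region_idx, size=1, p=probs)[0])

# e.g., sample_source_region_sketch([0, 1, 2], [0.2, 0.6, 0.2], seed=42)
# ---------------------------------------------------------------------------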
+ """ + + if (from_region_idx in range1_idxs and to_region_idx in range2_idxs) \ + or \ + (from_region_idx in range2_idxs and to_region_idx in range1_idxs): + return True + + return False + def is_fragile_wrt_split(self, - range1_idxs: ty.Set[int], - range2_idxs: ty.Set[int], + splitting_range1_idxs: ty.Set[int], + splitting_range2_idxs: ty.Set[int], conn_graph: pjfio.GeoGraph, - expanding_range_idxs: \ - ty.Optional[ty.List[int]] = None) -> bool: + range_idxs: \ + ty.Optional[ty.List[int]] = None) \ + -> bool: """Determine if splitting range is fragile. Checks that every region in range1_idxs is in a different communicating class from every region in range2_idxs. Args: - range1_idxs (list): List of indices (int) of one of + splitting_range1_idxs (list): List of indices (int) of one of ranges resulting from the split. - range2_idxs (list): List of indices (int) of the other + splitting_range2_idxs (list): List of indices (int) of the other range resulting from the split. conn_graph (GeoGraph): Connectivity graph, with each node being a region, and each edge representing the possibility of migration between two regions (i.e., gene flow). - expanding_range_idxs (, optional). Defaults to None. + range_idxs (list, optional). List of indices + of all regions constituting a (i) expanding range, (ii) + contracting range, or (iii) contracted range. Defaults + to None. """ def find_edge_pairwise(range1_idxs: ty.Set[int], @@ -568,43 +610,61 @@ def find_edge_pairwise(range1_idxs: ty.Set[int], # classes, the range is not fragile if edge_one_way in conn_graph.edge_set or \ edge_another_way in conn_graph.edge_set: - return False + return True - return True + return False # this part of the method finds out if splitting range # is fragile at the cladogenetic event # # if there are no edges at all, the range must be fragile - if expanding_range_idxs is None: + if range_idxs is None: if len(conn_graph.edge_set) == 0: return True # if there are edges, we check pairwise - is_fragile = \ - find_edge_pairwise(range1_idxs, - range2_idxs, - conn_graph) + at_least_one_edge = \ + find_edge_pairwise(splitting_range1_idxs, + splitting_range2_idxs, + conn_graph) + is_fragile = not at_least_one_edge + return is_fragile # this part of the method is used to determine if # a range prior to a dispersal is already fragile # with respect to a splitting event happening in # the future along this branch - else: - expanding_in_range1_idx = set([]) - expanding_in_range2_idx = set([]) + elif range_idxs is not None: + range1_idx = set([]) + range2_idx = set([]) + + for region_idx in range_idxs: + if region_idx in splitting_range1_idxs: + range1_idx.add(region_idx) + + elif region_idx in splitting_range2_idxs: + range1_idx.add(region_idx) + + # at least one region on both sides of the split + # must be occupied for a range to be considered + # 'previously fragile' + if len(range1_idx) == 0 or \ + len(range2_idx) == 0: + + is_fragile = False + + return is_fragile + + at_least_one_edge = \ + find_edge_pairwise(range1_idx, + range2_idx, + conn_graph) - for region_idx in expanding_range_idxs: - if region_idx in range1_idxs: - expanding_in_range1_idx.add(region_idx) + is_fragile = not at_least_one_edge - elif region_idx in range2_idxs: - expanding_in_range1_idx.add(region_idx) + return is_fragile - return find_edge_pairwise(expanding_in_range1_idx, - expanding_in_range2_idx, - conn_graph) def initialize_event_series_dict(self) -> None: @@ -658,32 +718,39 @@ def recursively_populate_event_series_dict(nd: dp.Node, self._geofeat_query. 
\ conn_graph_list[clado_smap_time_slice_idx] - ch1_bp = clado_smap.to_state_bit_patt # child 1 - ch2_bp = clado_smap.to_state2_bit_patt # child 2 + # child 1 + ch1_bp = clado_smap.to_state_bit_patt + # child 2 (will be None if no range split at speciation) + ch2_bp = clado_smap.to_state2_bit_patt # get sets of region indices for the two mutually # exclusive ranges ch1_set = \ set([idx for idx, b in enumerate(ch1_bp) if b == "1"]) - ch2_set = \ - set([idx for idx, b in enumerate(ch2_bp) if b == "1"]) + + ch2_set = set() + if ch2_bp: + ch2_set = \ + set([idx for idx, b in enumerate(ch2_bp) if b == "1"]) if not ch1_set.isdisjoint(ch2_set): exit(("Error during truncation of event series. " "At range split, the two mutually exclusive ranges shared" " regions. Exiting...")) - # annotate cladogenetic stochastic map depending on - # whether splitting range is or not fragile at splitting - # moment - splitting_range_is_fragile = \ - self.is_fragile_wrt_split(ch1_set, - ch2_set, - clado_conn_graph) - clado_smap.splitting_range_fragile = \ - splitting_range_is_fragile - # only speciation events with range splitting if clado_smap.to_state2_bit_patt is not None: + # cladogenetic + # + # annotate cladogenetic stochastic map depending on + # whether splitting range is or not fragile at splitting + # moment + splitting_range_is_fragile = \ + self.is_fragile_wrt_split(ch1_set, + ch2_set, + clado_conn_graph) + clado_smap.splitting_range_fragile = \ + splitting_range_is_fragile + # anagenetic # NOTE: assumes stochastic maps are sorted in chronological order!!! # (old first, young later) @@ -705,6 +772,41 @@ def recursively_populate_event_series_dict(nd: dp.Node, self._geofeat_query.\ conn_graph_list[ana_smap_time_slice_idx] + if ana_smap.map_type == "contraction": + # check if range before contraction + # was unstable (and annotate ana_smap + # accordingly) + pre_contr_range_bp = ana_smap.from_state_bit_patt + pre_contr_range_set = \ + set([idx for idx, b \ + in enumerate(pre_contr_range_bp) if b == "1"]) + pre_contr_range_fragile = \ + self.is_fragile_wrt_split(ch1_set, + ch2_set, + ana_conn_graph, + pre_contr_range_set) + ana_smap.previously_fragile_wrt_split = \ + pre_contr_range_fragile + + # check if range after contraction + # was unstable (and annotate ana_smap + # accordingly) + post_contr_range_bp = ana_smap.to_state_bit_patt + post_contr_range_set = \ + set([idx for idx, b \ + in enumerate(post_contr_range_bp) if b == "1"]) + post_contr_range_fragile = \ + self.is_fragile_wrt_split(ch1_set, + ch2_set, + ana_conn_graph, + post_contr_range_set) + ana_smap.fragile_after_contr_wrt_split = \ + post_contr_range_fragile + + # NOTE: + # if range is stable before the contraction and unstable + # after, then this is 'speciation by extinction' + if ana_smap.map_type == "expansion": # disambiguate the source region # (this updates from_region_idx inside stoch map) @@ -714,6 +816,16 @@ def recursively_populate_event_series_dict(nd: dp.Node, from_region_idx = ana_smap.from_region_idx to_region_idx = ana_smap.region_gained_idx + # check if dispersal is split-relevant + is_split_relevant = \ + self.is_range_expansion_split_relevant( + from_region_idx, + to_region_idx, + ch1_set, + ch2_set + ) + ana_smap.split_relevant = is_split_relevant + # check if regions involved in dispersals # belong to the same communication class # (i.e., there is a path of connectivity @@ -736,15 +848,13 @@ def recursively_populate_event_series_dict(nd: dp.Node, if not self._directed_edges: over_barrier = \ over_barrier or \ - 
(to_region_idx, from_region_idx) in \ + (to_region_idx, from_region_idx) not in \ ana_conn_graph.edge_set ana_smap.over_barrier = over_barrier - # TODO: see if from_region_bit_patt is - # fragile, and annotate ana_smap as - # split_relevant_fragile_pre_dispersal = True - # will need parent information to determine + # see if pre-dispersal range is fragile, + # and annotate ana_smap accordingly expanding_range_bp = ana_smap.from_state_bit_patt expanding_range_set = set([idx for idx, b \ in enumerate(expanding_range_bp) if b == "1"]) @@ -784,17 +894,6 @@ def recursively_populate_event_series_dict(nd: dp.Node, recursively_populate_event_series_dict(seed_nd, it_idx) - # debugging - # for nd_label, it_event_series_dict in self._event_series_dict.items(): - # # if nd_label in ("nd5"): - # for it_idx, event_series in it_event_series_dict.items(): - # if isinstance(event_series, EvolRelevantEventSeries): - # if it_idx == 1: - # print(nd_label) - # for ev in event_series.event_list: - # print(ev) - # print("\n") - def add_paleogeo_event_series_dict(self, geo_cond_name: str, @@ -809,9 +908,9 @@ def add_paleogeo_event_series_dict(self, about edge direction or not. Defaults to 'False'. """ - cond_change_epoch_start_ages_mat = self._geofeat_query.\ + conn_gained_epoch_start_ages_mat = self._geofeat_query.\ geo_cond_change_times_dict[geo_cond_name] - cond_change_back_epoch_start_ages_mat = self._geofeat_query.\ + conn_lost_epoch_start_ages_mat = self._geofeat_query.\ geo_cond_change_back_times_dict[geo_cond_name] for nd_label, it_event_series_dict in self._event_series_dict.items(): @@ -839,9 +938,10 @@ def add_paleogeo_event_series_dict(self, # update event_list, going over connectivity graph for # each epoch, old to young - for epoch_idx in range(self._geofeat_query.n_time_slices): - epoch_start_age = self._geofeat_query.feat_coll.\ - epoch_age_start_list_old2young[epoch_idx] + # for epoch_idx in range(self._geofeat_query.n_time_slices): + for epoch_idx in reversed(range(self._geofeat_query.n_time_slices)): + epoch_start_age = self._geofeat_query.feat_coll. \ + epoch_age_start_list_young2old[epoch_idx] # if epoch starts after the range split at speciation, # we do not care about it @@ -871,7 +971,7 @@ def add_paleogeo_event_series_dict(self, # _populate_conn_graph_list() in GeoFeatureQuery, # which assumes 1 means connectivity!) if epoch_start_age in \ - cond_change_epoch_start_ages_mat\ + conn_gained_epoch_start_ages_mat\ [from_region_idx][to_region_idx]: bd = pjfio.BarrierDisappearance( self._n_char, @@ -891,7 +991,7 @@ def add_paleogeo_event_series_dict(self, # and if barrier appeared in this epoch if epoch_start_age in \ - cond_change_back_epoch_start_ages_mat\ + conn_lost_epoch_start_ages_mat\ [from_region_idx][to_region_idx]: ba = pjfio.BarrierAppearance( self._n_char, @@ -905,16 +1005,6 @@ def add_paleogeo_event_series_dict(self, # EvolRelevantEvent) insort(event_list, ba) - for nd_label, it_event_series_dict in self._event_series_dict.items(): - # if nd_label in ("nd5"): - for it_idx, event_series in it_event_series_dict.items(): - if isinstance(event_series, EvolRelevantEventSeries): - if it_idx == 1: - print(nd_label) - for ev in event_series.event_list: - print(ev) - print("\n") - def initialize_truncated_event_series_dict(self) -> None: """Populate _truncated_event_series_dict @@ -926,10 +1016,50 @@ def initialize_truncated_event_series_dict(self) -> None: stable range. 
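# --- Editor's sketch (illustrative only, not part of this patch) ----------
# The paleogeographic pass above interleaves BarrierAppearance /
# BarrierDisappearance events into each node's event list with insort,
# keeping the list in chronological order (old -> young, i.e., decreasing
# age). A minimal stand-in for the event objects' ordering is sketched
# below; the real EvolRelevantEvent comparison logic may differ.
from bisect import insort
from dataclasses import dataclass

@dataclass
class AgedEventSketch:
    age: float
    label: str

    def __lt__(self, other: "AgedEventSketch") -> bool:
        # larger age = older; older events sort first (old -> young)
        return self.age > other.age

events = [AgedEventSketch(4.25, "barrier appears"),
          AgedEventSketch(1.5, "dispersal")]
insort(events, AgedEventSketch(2.0, "barrier disappears"))
# ages now read 4.25, 2.0, 1.5 (old -> young)
# ---------------------------------------------------------------------------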
""" - for nd_label, it_event_series_dict in self._event_series_dict.items(): - # rda = RegionDispersalAncestry() + # truncated event series data structure + # + # it splitting_nd_idx event_abbrev_csv_list hyp + # 1 5 b+_0_1,d_0_1_o_r,b-_0_1 vicariance + # 2 5 ... founder-event + for nd_label, it_event_series_dict in self._event_series_dict.items(): for it_idx, event_series in it_event_series_dict.items(): + # all events, paleogeographic and biogeographic + event_list = event_series.event_list + + # we make deep copy because the annotation we carry out below + # (.split_relevant) will be different for each node in the tree, + # for the same dispersal event + trunc_event_list = [copy.deepcopy(event) for event in event_list] + + # last event should be a cladogenetic split + clado_smap = event_list[-1] + + ############################################### + # First rule: the range being split is stable # + # # + # Classification: Ambiguous # + ############################################### + if not clado_smap.splitting_range_fragile: + # TODO: + # add code for producing truncated event series + # (make member in event series with hypothesis + # it supports), and annotate it as ambiguous + pass + + assert (len(trunc_event_list) > 1) + for ev_idx, ev in enumerate(event_list): + # TODO: + # if isinstance(ev, pjfio.BarrierDisappearance) and \ + # if barrier disappearance is split relevant + # truncate at youngest stable range! + trunc_event_list = trunc_event_list[ev_idx:] + + + + + + # event_series will be an empty dictionary if range # did not split at node -- we do not care about those! # (even if subtending branch has maps!) @@ -1061,6 +1191,10 @@ def tabulate_hyp_support(self) -> None: pass + @property + def event_series_dict(self) -> ty.Dict[str, ty.Dict[int, EvolRelevantEventSeries]]: + return self._event_series_dict + @property def hyp_support_dict(self) -> ty.Dict[Hypothesis, int]: return self._hyp_support_dict @@ -1077,8 +1211,8 @@ def _is_split_relevant_event(): if __name__ == "__main__": - n_chars = 2 - # n_chars = 4 + # n_chars = 2 + n_chars = 4 state2bit_lookup = pjbio.State2BitLookup(n_chars, 2, geosse=True) @@ -1101,34 +1235,36 @@ def _is_split_relevant_event(): # '0111': 13 # '1111': 14 - # ann_tr_list = [pjr.read_nwk_tree_str("examples/trees_maps_files/geosse_dummy_tree3.tre", - ann_tr_list = [pjr.read_nwk_tree_str("examples/trees_maps_files/geosse_dummy_tree2.tre", + # ann_tr_list = [pjr.read_nwk_tree_str("examples/trees_maps_files/geosse_dummy_tree2.tre", + ann_tr_list = [pjr.read_nwk_tree_str("examples/trees_maps_files/geosse_dummy_tree3.tre", "read_tree", node_names_attribute="index", n_states=n_states, in_file=True)] - # node_states_file_path = "examples/trees_maps_files/geosse_dummy_tree3_tip_states.tsv" - # pjsmap.StochMapsOnTreeCollection("examples/trees_maps_files/geosse_dummy_tree3_maps.tsv", + # node_states_file_path = "examples/trees_maps_files/geosse_dummy_tree2_tip_states.tsv" + # pjsmap.StochMapsOnTreeCollection("examples/trees_maps_files/geosse_dummy_tree2_maps.tsv", smap_coll = \ - pjsmap.StochMapsOnTreeCollection("examples/trees_maps_files/geosse_dummy_tree2_maps.tsv", + pjsmap.StochMapsOnTreeCollection("examples/trees_maps_files/geosse_dummy_tree3_maps.tsv", ann_tr_list, state2bit_lookup, - node_states_file_path="examples/trees_maps_files/geosse_dummy_tree2_tip_states.tsv", + node_states_file_path="examples/trees_maps_files/geosse_dummy_tree3_tip_states.tsv", stoch_map_attr_name="state") + # param_log_dir = 
"examples/feature_files/two_regions_feature_set_event_series" + param_log_dir = "examples/feature_files/four_regions_feature_set_event_series" frs = FromRegionSampler( n_chars, - "examples/feature_files/two_regions_feature_set_event_series", - "epoch_", + param_log_dir, + "epoch_age_", "_rel_rates", "m_d" ) - feature_summary_fp = "examples/feature_files/two_regions_feature_set_event_series/feature_summary.csv" - age_summary_fp = "examples/feature_files/two_regions_feature_set_event_series/age_summary.csv" - # feature_summary_fp = "examples/feature_files/four_regions_feature_set_event_series/feature_summary.csv" - # age_summary_fp = "examples/feature_files/four_regions_feature_set_event_series/age_summary.csv" + # feature_summary_fp = "examples/feature_files/two_regions_feature_set_event_series/feature_summary.csv" + # age_summary_fp = "examples/feature_files/two_regions_feature_set_event_series/age_summary.csv" + feature_summary_fp = "examples/feature_files/four_regions_feature_set_event_series/feature_summary.csv" + age_summary_fp = "examples/feature_files/four_regions_feature_set_event_series/age_summary.csv" fc = pjfio.GeoFeatureCollection( feature_summary_fp, @@ -1142,10 +1278,22 @@ def _is_split_relevant_event(): # all members, including graph, are populated here! fq.populate_geo_cond_member_dicts("land_bridge", requirement_fn) - event_series_tabulator = \ + est = \ EvolRelevantEventSeriesTabulator( ann_tr_list, smap_coll, fq, from_region_sampler=frs - ) \ No newline at end of file + ) + + # looking at things + it_to_look_at = [1] + for nd_label, it_event_series_dict in est.event_series_dict.items(): + # if nd_label in ("nd5"): + for it_idx, event_series in it_event_series_dict.items(): + if isinstance(event_series, EvolRelevantEventSeries): + if it_idx in it_to_look_at: + print(nd_label) + for ev in event_series.event_list: + print(ev) + print("\n") \ No newline at end of file diff --git a/src/phylojunction/functionality/feature_io.py b/src/phylojunction/functionality/feature_io.py index ad5ce46..5e60899 100644 --- a/src/phylojunction/functionality/feature_io.py +++ b/src/phylojunction/functionality/feature_io.py @@ -1223,9 +1223,9 @@ def populate_geo_cond_member_dicts( self.geo_cond_bit_dict[geo_cond_name] = requirement_fn - self._populate_oldest_geo_cond_bit_dict(geo_cond_name) - # DEPRECATED + # self._populate_oldest_geo_cond_bit_dict(geo_cond_name) + self._populate_geo_cond_change_times_dict(geo_cond_name) self._populate_conn_graph_list(geo_cond_name, is_directed) @@ -1264,8 +1264,14 @@ def _populate_geo_cond_change_times_dict(self, # between if type(region1_str_or_list) == list: - # old to young bit + # region2_str is a bit that will be in the order + # epochs are specified in the age summary file (which + # should be time index 1 -- top of file -- for the youngest + # epoch, and a larger time index for oldest epochs) for region2_str in region1_str_or_list: + # make it old -> young + region2_str = region2_str[::-1] + # scanning the geographic condition bit pattern # (sliding window of size 2) for a '01' string # which indicates when the geographic condition @@ -1275,6 +1281,7 @@ def _populate_geo_cond_change_times_dict(self, # we get the position in the bit pattern where # we had the departing '0', and use that position # to grab the start of the epoch with the '1' + # [age_starts_young2old[idx + 1] for idx, (i, j) region2_times_list = \ [age_starts_old2young[idx + 1] for idx, (i, j) \ in enumerate(zip(region2_str, @@ -1282,6 +1289,7 @@ def 
_populate_geo_cond_change_times_dict(self, if (i, j) == ('0', '1')] # now doing change back times ('10') + # [age_starts_young2old[idx + 1] for idx, (i, j) region2_back_times_list = \ [age_starts_old2young[idx + 1] for idx, (i, j) \ in enumerate(zip(region2_str, @@ -1347,7 +1355,8 @@ def _populate_oldest_geo_cond_bit_dict(self, geo_cond_name: str) -> None: # between if type(region1_str_or_list) == list: for region2_str in region1_str_or_list: - region2_oldest_bit = region2_str[0] + # region2_oldest_bit = region2_str[0] + region2_oldest_bit = region2_str[-1] region1_oldest_bit_list.append(region2_oldest_bit) @@ -1355,8 +1364,9 @@ def _populate_oldest_geo_cond_bit_dict(self, geo_cond_name: str) -> None: # within else: - # old to young bit - region1_oldest_bit_list = region1_str_or_list[0] + # time directionality is given by user + # region1_oldest_bit_list = region1_str_or_list[0] + region1_oldest_bit_list = region1_str_or_list[-1] self.geo_oldest_cond_bit_dict[geo_cond_name].append(region1_oldest_bit_list) @@ -1384,6 +1394,15 @@ def _populate_conn_graph_list(self, geo_cond_name: str, is_directed: bool=False) self.conn_graph_list.append(g) def find_epoch_idx(self, an_age: float) -> int: + """ + + Args: + an_age (float): The age of an event. + + Returns: + (int): Index of epoch containing 'an_age' (with 0 being + present). + """ time_slice_index = 0 for time_slice_index in range(0, self.n_time_slices): diff --git a/src/phylojunction/functionality/stoch_map.py b/src/phylojunction/functionality/stoch_map.py index 6eb9b5a..f5d76e7 100644 --- a/src/phylojunction/functionality/stoch_map.py +++ b/src/phylojunction/functionality/stoch_map.py @@ -237,7 +237,6 @@ def _init_str_representation(self) -> None: + "\n To state " + str(self.to_state) + ", bits \'" \ + self.to_state_bit_patt + "\', " \ + "range size " + str(self.size_of_final_range) \ - + "\n Dispersal from region (index): " + str(self.from_region_idx) \ + "\n Region gained (index): " + str(self.region_gained_idx)) @property @@ -304,7 +303,7 @@ def _update_str_representation_previously_fragile(self) -> None: def _update_str_representation_split_relevant(self) -> None: self.str_representation += \ - + "\n Dispersal relevant to split: " + str(self._split_relevant) \ + "\n Dispersal relevant to split: " + str(self._split_relevant) \ # setter def gained_region_is_lost_in_future(self): @@ -336,6 +335,9 @@ class RangeContraction(StochMap): size_of_final_range: int region_lost_idx: int + _previously_fragile_wrt_split: ty.Optional[bool] + _fragile_after_contr_wrt_split: ty.Optional[bool] + def __init__(self, region_lost_idx: int, n_regions: int, @@ -367,8 +369,29 @@ def __init__(self, self.size_of_final_range = \ sum(int(i) for i in to_state_bit_patt) + self._previously_fragile_wrt_split = None + self._fragile_after_contr_wrt_split = None + self._init_str_representation() + @property + def previously_fragile_wrt_split(self) -> bool: + return self._previously_fragile_wrt_split + + @property + def fragile_after_contr_wrt_split(self) -> bool: + return self._fragile_after_contr_wrt_split + + @previously_fragile_wrt_split.setter + def previously_fragile_wrt_split(self, val: bool) -> None: + self._previously_fragile_wrt_split = val + self._update_str_representation_previously_fragile() + + @fragile_after_contr_wrt_split.setter + def fragile_after_contr_wrt_split(self, val: bool) -> None: + self._fragile_after_contr_wrt_split = val + self._update_str_representation_fragile_after_contr() + def _init_str_representation(self) -> None: self.str_representation = 
\ "Range contraction / Local extinction (at age = " + str(self.age) + ")" \ @@ -381,6 +404,16 @@ def _init_str_representation(self) -> None: + "range size " + str(self.size_of_final_range) \ + "\n Region lost (index): " + str(self.region_lost_idx) + def _update_str_representation_previously_fragile(self) -> None: + self.str_representation += \ + "\n Contracting range previously fragile (w.r.t. split): " \ + + str(self._previously_fragile_wrt_split) + + def _update_str_representation_fragile_after_contr(self) -> None: + self.str_representation += \ + "\n Contracted range is fragile (w.r.t. split): " \ + + str(self._fragile_after_contr_wrt_split) + def __str__(self) -> str: return self.str_representation @@ -911,14 +944,13 @@ def update_tree_attributes(self, This function is called outside of StochMapsOnTree. It goes through the two stochastic maps dictionaries, and updates the AnnotatedTree member with respect to - its (i) node's attributes, and (ii) node_attr_dict member + its (i) node's attributes, and (ii) node_attr_dict member. + + This function has no return and only side-effects. Args: stoch_map_attr_name (str): Name of the node attribute the stochastic maps carry a value for (e.g., 'state') - - Returns: - None, this function has a side-effect """ # debugging @@ -1264,6 +1296,8 @@ def _read_stoch_maps_file(self, sorted(self.sorted_it_idxs) + print("self.stoch_maps_tree_dict[1]", self.stoch_maps_tree_dict[1]) + ############################################# # Make every SMOT update its tree's members # # according to the stochastic maps of their # diff --git a/tests/functionality/test_feature_io.py b/tests/functionality/test_feature_io.py index 3ec15ba..f669560 100644 --- a/tests/functionality/test_feature_io.py +++ b/tests/functionality/test_feature_io.py @@ -38,7 +38,7 @@ def setUpClass(cls) -> None: cls.two_cb_is_1_requirement_fn = \ pjgeo.GeoFeatureQuery.cb_feature_equals_value_is_connected(cls.two_region_geo_coll, - 0, + 1, feat_name="cb_1") cls.four_cb_is_1_requirement_fn = \ @@ -59,17 +59,17 @@ def test_read_features_epoch_ages_two_regions(self) -> None: self.assertEqual(self.two_region_geo_coll.epoch_age_end_list_old2young, [4.25, 1.5, 0.0]) self.assertEqual(self.two_region_geo_coll.epoch_age_start_list_young2old, - [1.5, 4.25, -math.inf]) + [1.5, 4.25, math.inf]) self.assertEqual(self.two_region_geo_coll.epoch_age_start_list_old2young, - [-math.inf, 4.25, 1.5]) + [math.inf, 4.25, 1.5]) self.assertEqual(self.two_region_geo_coll.epoch_mid_age_list_young2old, - [0.75, 2.875, -math.inf]) + [0.75, 2.875, math.inf]) self.assertEqual(self.two_region_geo_coll.epoch_mid_age_list_old2young, - [-math.inf, 2.875, 0.75]) + [math.inf, 2.875, 0.75]) self.assertEqual(self.two_region_geo_coll.epoch_age_start_list_young2old, - [1.5, 4.25, -math.inf]) + [1.5, 4.25, math.inf]) self.assertEqual(self.two_region_geo_coll.epoch_age_start_list_old2young, - [-math.inf, 4.25, 1.5]) + [math.inf, 4.25, 1.5]) def test_read_features_epoch_ages_four_regions(self) -> None: """ @@ -87,16 +87,16 @@ def test_read_features_epoch_ages_four_regions(self) -> None: [20.0, 10.0, 0.0]) self.assertEqual( self.four_region_geo_coll.epoch_age_start_list_young2old, - [10.0, 20.0, -math.inf]) + [10.0, 20.0, math.inf]) self.assertEqual( self.four_region_geo_coll.epoch_age_start_list_old2young, - [-math.inf, 20.0, 10.0]) + [math.inf, 20.0, 10.0]) self.assertEqual( self.four_region_geo_coll.epoch_mid_age_list_young2old, - [5.0, 15.0, -math.inf]) + [5.0, 15.0, math.inf]) self.assertEqual( 
self.four_region_geo_coll.epoch_mid_age_list_old2young, - [-math.inf, 15.0, 5.0]) + [math.inf, 15.0, 5.0]) def test_epoch_find(self) -> None: """Test that auxiliary method finds the epoch a time belongs. @@ -168,11 +168,11 @@ def test_read_features_values_and_types_two_regions(self) -> None: # from feature_summary.csv, epoch 1 is youngest, epoch 3 is oldest exp1 = ('Feature (cb_1) | between-categorical 1 | Epoch 1 ' - '| 2 regions\n[[0 0]\n [1 0]]') - exp2 = ('Feature (cb_1) | between-categorical 1 | Epoch 2 ' '| 2 regions\n[[0 1]\n [1 0]]') + exp2 = ('Feature (cb_1) | between-categorical 1 | Epoch 2 ' + '| 2 regions\n[[0 0]\n [0 0]]') exp3 = ('Feature (cb_1) | between-categorical 1 | Epoch 3 ' - '| 2 regions\n[[0 1]\n [0 0]]') + '| 2 regions\n[[0 1]\n [1 0]]') self.assertEqual(str(self.two_region_geo_coll.feat_name_epochs_dict['cb_1'][1]), exp1) @@ -220,27 +220,30 @@ def test_geo_query_two_regions(self) -> None: # A -> A, A -> B, B -> A, B -> B # each bit is for an epoch, from young to old (see feature_summary.csv) geo_cond_bit_patterns = " ".join(i for i in geo_cond_bit_patterns_list) - self.assertEqual(geo_cond_bit_patterns, "111 100 001 111") + self.assertEqual(geo_cond_bit_patterns, "000 101 101 000") + + # self.assertEqual( + # self.two_geo_query.geo_oldest_cond_bit_dict, + # {'land_bridge': [['0', '1'], ['0', '1']]} + # ) + # A -> A, A -> B, B -> A, B -> B + # when connectivity exists! self.assertEqual( - self.two_geo_query.geo_oldest_cond_bit_dict, - {'land_bridge': [['1', '1'], ['0', '1']]} + self.two_geo_query.get_geo_condition_change_times("land_bridge"), + [[[], [1.5]], [[1.5], []]] ) - # self.assertEqual( - # self.two_geo_query.get_geo_condition_change_times("land_bridge"), - # [[[], []], [[1.5], []]] - # ) - # - # self.assertEqual( - # self.two_geo_query.get_geo_condition_change_back_times("land_bridge"), - # [[[], [4.25]], [[], []]] - # ) + # when connectivity does not exist! + self.assertEqual( + self.two_geo_query.get_geo_condition_change_back_times("land_bridge"), + [[[], [4.25]], [[4.25], []]] + ) # epoch 1, idx = 0 (youngest, the index comes from feat_summary.csv) self.assertSetEqual( self.two_geo_query.conn_graph_list[0].edge_set, - {(0, 1)} + {(0, 1), (1, 0)} ) # epoch 2, idx = 1 (no edges!) @@ -252,7 +255,7 @@ def test_geo_query_two_regions(self) -> None: # epoch 3, idx = 2 (oldest, the index comes from feat_summary.csv) self.assertSetEqual( self.two_geo_query.conn_graph_list[2].edge_set, - {(1, 0)} + {(0, 1), (1, 0)} ) self.assertEqual(self.two_geo_query.get_comm_classes(0.1),
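# --- Editor's sketch (illustrative only, not part of this patch) ----------
# How the expected change times above follow from a geographic-condition bit
# pattern: with bits and epoch start ages both ordered old -> young, a
# '0' -> '1' step marks the epoch in which connectivity is gained and a
# '1' -> '0' step marks where it is lost (cf.
# _populate_geo_cond_change_times_dict). Minimal version, assuming the
# two-region epoch start ages [inf, 4.25, 1.5]:
import math

def condition_change_times_sketch(bits_old2young, age_starts_old2young):
    pairs = list(zip(bits_old2young, bits_old2young[1:]))
    gained = [age_starts_old2young[i + 1]
              for i, p in enumerate(pairs) if p == ("0", "1")]
    lost = [age_starts_old2young[i + 1]
            for i, p in enumerate(pairs) if p == ("1", "0")]
    return gained, lost

# A -> B pattern '101' (old -> young): connected, disconnected, reconnected
print(condition_change_times_sketch("101", [math.inf, 4.25, 1.5]))
# -> ([1.5], [4.25]): connectivity regained at 1.5, lost at 4.25
# ---------------------------------------------------------------------------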