Skip to content

Commit

Permalink
Add code for paleogeo event annotation, update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
binho authored and binho committed Apr 2, 2024
1 parent 4770e69 commit 2dd502b
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 102 deletions.
183 changes: 150 additions & 33 deletions src/phylojunction/functionality/event_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,12 @@ def __init__(self,
if verbose:
print(" ... finished reading paleogeographic events.")

# annotate paleogeographic events
self.annotate_paleogeo_events()

if verbose:
print(" ... finished annotating paleogeographic events.")

self.populate_trunc_event_series()

if verbose:
Expand Down Expand Up @@ -803,7 +809,7 @@ def parse_range_contraction(ch1_set,
# if range is stable before the contraction and unstable
# after, then this is 'speciation by extinction'

def do_anagenetic_smap(ch1_set, ch2_set, ana_smap):
def do_anagenetic_smap(ch1_set, ch2_set, ana_smap) -> None:
# gathering initial information to annotate
# anagenetic stochastic map

Expand Down Expand Up @@ -844,11 +850,11 @@ def do_anagenetic_smap(ch1_set, ch2_set, ana_smap):
ana_conn_graph)

def recursively_populate_event_series_dict(nd: dp.Node,
it_idx: int):
it_idx: int) -> None:
"""Populate self._event_series_dict recursively.
Disambiguation of range expansion (dispersal) events
happens here.
Leaf nodes are ignored. Disambiguation of range expansion
(dispersal) events happens here.
"""

# do current node
Expand Down Expand Up @@ -897,18 +903,20 @@ def recursively_populate_event_series_dict(nd: dp.Node,
conn_graph_list[clado_smap_time_slice_idx]

# child 1
ch1_bp = clado_smap.to_state_bit_patt
# ch1_bp = clado_smap.to_state_bit_patt
# child 2 (will be None if no range split at speciation)
ch2_bp = clado_smap.to_state2_bit_patt
# ch2_bp = clado_smap.to_state2_bit_patt
# get sets of region indices for the two mutually
# exclusive ranges
ch1_set = \
set([idx for idx, b in enumerate(ch1_bp) if b == "1"])
ch1_set = clado_smap.to_state_idx_set
# ch1_set = \
# set([idx for idx, b in enumerate(ch1_bp) if b == "1"])

ch2_set = ch1_set
if ch2_bp is not None:
ch2_set = \
set([idx for idx, b in enumerate(ch2_bp) if b == "1"])
ch2_set = clado_smap.to_state2_idx_set
# ch2_set = ch1_set
# if ch2_bp is not None:
# ch2_set = \
# set([idx for idx, b in enumerate(ch2_bp) if b == "1"])

# only speciation events with range splitting
# if clado_smap.to_state2_bit_patt is not None:
Expand Down Expand Up @@ -969,16 +977,17 @@ def recursively_populate_event_series_dict(nd: dp.Node,

# iterating over each MCMC iteration when stochastic maps were logged
for it_idx, smap in self._smap_collection.stoch_maps_tree_dict.items():
ann_tr = smap.ann_tr
seed_nd = smap.ann_tr.origin_node if smap.ann_tr.with_origin \
else smap.ann_tr.root_node

# populate self._self._event_series_dict
# (terminal nodes are ignored)
#
# key1: node name
# value1 iteration dict:
# key2: iteration index
# value2 event series object
# value1 iteration dict
# key2: iteration index
# value2 event series object (or empty dictionary if internal
# node but no events, like origin node)
recursively_populate_event_series_dict(seed_nd,
it_idx)

Expand Down Expand Up @@ -1036,9 +1045,9 @@ def insort_barrier(from_region_idx,

for nd_label, it_event_series_dict in self._event_series_dict.items():
for it_idx, event_series in it_event_series_dict.items():
# event_series will be an empty dictionary if range
# did not split at node -- we do not care about those!
# (even if subtending branch has maps!)
# node at the start of process may have an empty event series
# if stoch mapping file for some reason did not have an entry
# those nodes
if isinstance(event_series, EvolRelevantEventSeries):
event_list = event_series.event_list

Expand Down Expand Up @@ -1080,6 +1089,7 @@ def insort_barrier(from_region_idx,
else range((from_region_idx+1), n_regions)

for to_region_idx in inner_loop_idxs:
# ignore diagonal elements
if directed and from_region_idx == to_region_idx:
continue

Expand All @@ -1103,6 +1113,109 @@ def insort_barrier(from_region_idx,
to_region_idx,
event_list)

def annotate_paleogeo_events(self) -> None:
"""Annotate paleogeo events as (de)stabilizing or not.
After paleogeographic events have been added to the event series
of each internal node, we visit each one of them and annotate
them according to their (de)stabilizing status with respect to
the range splitting event happening at speciation. If the range
does not split at speciation, the annotation is 'N/A'.
"""

# only internal nodes in _event_series_dict
for nd_label, it_event_series_dict in self._event_series_dict.items():
for it_idx, event_series in it_event_series_dict.items():
# node at the start of process may have an empty event series
# if stoch mapping file for some reason did not have an entry
# those nodes
if isinstance(event_series, EvolRelevantEventSeries):
event_list = event_series.event_list
n_events = len(event_list)
range_split_smap = event_list[-1]
assert (isinstance(range_split_smap, pjsmap.RangeSplitOrBirth))

# gathering splitting range info
ch1_set = range_split_smap.to_state_idx_set
ch2_set = range_split_smap.to_state2_idx_set

for ev_idx, ev in enumerate(event_list):
if isinstance(ev, (pjfio.BarrierAppearance,
pjfio.BarrierDisappearance)):

# gathering paleogeo event info
#
# pair of regions dis/reconnected by barrier
# appearing or disappearing
reg1_idx = ev.from_node_idx
reg2_idx = ev.to_node_idx

# if the range splits at the internal node, then we
# need to do things
#
# grab range bit pattern at paleogeographic event
range_bp_at_paleogeo_ev = ""

for next_ev_idx in range((ev_idx + 1), n_events):
next_ev = event_list[next_ev_idx]

if isinstance(next_ev, pjsmap.StochMap):
range_bp_at_paleogeo_ev = \
next_ev.from_state_bit_patt
# now get region indices of occupied range
range_bp_at_paleogeo_set = \
set([idx for idx, b in \
enumerate(range_bp_at_paleogeo_ev) \
if b == "1"])
break

assert (range_bp_at_paleogeo_ev != "")

# annotate paleogeo event with occupied range
ev.range_bit_patt = range_bp_at_paleogeo_ev

##################################################
# Scenario 1 in which (de)stabilization does not #
# make sense: if there is no range split #
##################################################
if ch1_set == ch2_set:
if isinstance(ev, pjfio.BarrierDisappearance):
ev.restabilized_range = "N/A"

elif isinstance(ev, pjfio.BarrierAppearance):
ev.destabilized_range = "N/A"

continue

##################################################
# Scenario 2 in which (de)stabilization does not #
# make sense: when the regions (dis)reconnected #
# are not occupied by lineage #
##################################################
if reg1_idx not in range_bp_at_paleogeo_set or \
reg2_idx not in range_bp_at_paleogeo_set:
if isinstance(ev, pjfio.BarrierDisappearance):
ev.restabilized_range = "N/A"

elif isinstance(ev, pjfio.BarrierAppearance):
ev.destabilized_range = "N/A"

continue

# if paleogeo eventis split-relevant
if (reg1_idx in ch1_set and reg2_idx in ch2_set) or \
(reg1_idx in ch2_set and reg2_idx in ch1_set) and \
ch1_set != ch2_set:

# if barrier disappearing is split-relevant
# (if split is within-region, we will never get inside
# this if block)
if isinstance(ev, pjfio.BarrierDisappearance):
ev.restabilized_range = "stab"

elif isinstance(ev, pjfio.BarrierAppearance):
ev.destabilized_range = "destab"

def populate_trunc_event_series(self) -> None:
"""Populate _truncated_event_series_dict
Expand Down Expand Up @@ -1133,12 +1246,14 @@ def populate_trunc_event_series(self) -> None:
assert (isinstance(range_split_smap, pjsmap.RangeSplitOrBirth))

# gathering splitting range info
ch1_bp = range_split_smap.to_state_bit_patt
ch2_bp = range_split_smap.to_state2_bit_patt
ch1_set = set([idx for idx, b in enumerate(ch1_bp) if b == "1"])
ch2_set = ch1_set
if ch2_bp is not None:
ch2_set = set([idx for idx, b in enumerate(ch2_bp) if b == "1"])
ch1_set = range_split_smap.to_state_idx_set
ch2_set = range_split_smap.to_state2_idx_set
# ch1_bp = range_split_smap.to_state_bit_patt
# ch2_bp = range_split_smap.to_state2_bit_patt
# ch1_set = set([idx for idx, b in enumerate(ch1_bp) if b == "1"])
# ch2_set = ch1_set
# if ch2_bp is not None:
# ch2_set = set([idx for idx, b in enumerate(ch2_bp) if b == "1"])

# now preparing truncated event list
#
Expand All @@ -1163,7 +1278,7 @@ def populate_trunc_event_series(self) -> None:
############################################
if isinstance(ev, pjsmap.RangeExpansion) and \
ev.stabilized_range_wrt_split and \
ch2_bp is not None:
ch1_set != ch2_set:
truncate_here = True

#############################################
Expand All @@ -1173,18 +1288,20 @@ def populate_trunc_event_series(self) -> None:
# regions are on opposing sides of the #
# split #
#############################################
elif isinstance(ev, pjfio.BarrierDisappearance):
elif isinstance(ev, pjfio.BarrierDisappearance) and \
ev.restabilized_range == "stab" and \
ch1_set != ch2_set:
# pair of regions reconnected by barrier disappearing
reg1_idx = ev.from_node_idx
reg2_idx = ev.to_node_idx
# reg1_idx = ev.from_node_idx
# reg2_idx = ev.to_node_idx

# if barrier disappearing is split-relevant
# (if split is within-region, we will never get inside
# this if block)
if (reg1_idx in ch1_set and reg2_idx in ch2_set) or \
(reg1_idx in ch2_set and reg2_idx in ch1_set) and \
ch2_bp is not None:
truncate_here = True
# if (reg1_idx in ch1_set and reg2_idx in ch2_set) or \
# (reg1_idx in ch2_set and reg2_idx in ch1_set) and \
# ch2_bp is not None:
truncate_here = True

# actually truncate!
if truncate_here:
Expand Down
58 changes: 54 additions & 4 deletions src/phylojunction/functionality/feature_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,11 @@ class BarrierAppearance(pjev.EvolRelevantEvent):
in barrier.
"""

from_node_idx: int
to_node_idx: int
destabilized_range: ty.Optional[str]
range_bit_patt: ty.Optional[str]

def __init__(self,
n_chars: int,
age: ty.Optional[float],
Expand All @@ -263,17 +268,37 @@ def __init__(self,

self.from_node_idx = from_node_idx
self.to_node_idx = to_node_idx
self.destabilized_range = None
self.range_bit_patt = None

def short_str(self) -> str:
short_str = "b+(" + str(self.age) + ")"
short_str = "b+(" + str(self.age) + ")" \
+ "_reg(" + str(self.from_node_idx) + "|" \
+ str(self.to_node_idx) + ")"

if self.destabilized_range is not None and \
self.range_bit_patt is not None:
short_str += "_" + self.destabilized_range \
+ "(" + self.range_bit_patt + ")"

return short_str

def __str__(self) -> str:
return "Barrier (at age = " + str(self.age) \
str_representation = \
"Barrier (at age = " + str(self.age) \
+ ") between region " + str(self.from_node_idx) \
+ " and " + str(self.to_node_idx)

if self.destabilized_range is not None and \
self.range_bit_patt is not None:
str1 = "\n Destabilized range (" + self.range_bit_patt + "): "
str2 = "True" if self.destabilized_range == "destab" \
else "N/A"
suffix = str1 + str2
str_representation += suffix

return str_representation

def __lt__(self, other) -> bool:
return super().__lt__(other)

Expand All @@ -288,6 +313,11 @@ class BarrierDisappearance(pjev.EvolRelevantEvent):
in barrier.
"""

from_node_idx: int
to_node_idx: int
restabilized_range: ty.Optional[str]
range_bit_patt: ty.Optional[str] # range at time of event

def __init__(self,
n_chars: int,
age: ty.Optional[float],
Expand All @@ -300,18 +330,38 @@ def __init__(self,

self.from_node_idx = from_node_idx
self.to_node_idx = to_node_idx
self.restabilized_range = None
self.range_bit_patt = None

def short_str(self) -> str:
short_str = "b-(" + str(self.age) + ")"
short_str = "b-(" + str(self.age) + ")" \
+ "_reg(" + str(self.from_node_idx) + "|" \
+ str(self.to_node_idx) + ")"

if self.restabilized_range is not None and \
self.range_bit_patt is not None:
short_str += "_" + self.restabilized_range \
+ "(" + self.range_bit_patt + ")"

return short_str

def __str__(self) -> str:
return "Connectivity restablished (at age = " \
str_representation = \
"Connectivity restablished (at age = " \
+ str(self.age) + ") between region " \
+ str(self.from_node_idx) + " and " \
+ str(self.to_node_idx)

if self.restabilized_range is not None and \
self.range_bit_patt is not None:
str1 = "\n Restabilized range (" + self.range_bit_patt + "): "
str2 = "True" if self.restabilized_range == "stab" \
else "N/A"
suffix = str1 + str2
str_representation += suffix

return str_representation

def __lt__(self, other) -> bool:
return super().__lt__(other)

Expand Down
Loading

0 comments on commit 2dd502b

Please sign in to comment.