Skip to content

Commit

Permalink
Refactor GUI a bit
Browse files Browse the repository at this point in the history
Cleaned up a few more sse unit tests
  • Loading branch information
binho authored and binho committed Dec 13, 2023
1 parent ef9eb2d commit f3a517c
Show file tree
Hide file tree
Showing 6 changed files with 569 additions and 282 deletions.
177 changes: 128 additions & 49 deletions src/phylojunction/distribution/dn_discrete_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def _check_input_health(self) -> None:
######################

def _get_next_event_time(self, total_rate: float) -> float:
"""Draw next exponentially distributed event time
"""Draw next exponentially distributed event time.
Args:
current_node_target_count (int): Number of nodes that
Expand All @@ -405,9 +405,10 @@ def _get_next_event_time(self, total_rate: float) -> float:
"""

next_time = \
float(dnpar.DnExponential.draw_exp(1,
total_rate,
rate_parameterization=True))
float(
dnpar.DnExponential.draw_exp(1,
total_rate,
rate_parameterization=True))

return next_time

Expand All @@ -420,29 +421,64 @@ def _execute_birth(
state_transition_dict: ty.Dict[str, ty.List[AttributeTransition]],
untargetable_node_set: ty.Set[str],
cumulative_node_count: int,
macroevol_atomic_param: sseobj.DiscreteStateDependentRate,
sse_birth_rate_object: sseobj.DiscreteStateDependentRate,
event_t: float,
debug=False) -> ty.Tuple[dp.Node, int]:
"""Execute lineage birth (side-effect and return)
"""Execute lineage birth.
This method has both side-effects and a return.
The side-effects include the updating of class members. These
members are passed as arguments in the signature so that we
can have a better idea of what to expect as behavior:
(i) self.tr.tr_namespace,
(ii) self.state_representation_dict
(iii) self.sa_lineage_dict
(iv) self.state_transition_dict
(v) self.root_is_born
(vi) self.untargetable_node_set
(Note that Python passes arguments 'by assignment',
which means that for mutable arguments, we are doing something
akin to 'pass by reference'; we can mutate -- but not
reassign! -- the argument variable and changes will be
reflected outside the function.)
Args:
tr_namespace (dendropy.TaxonNamespace): Dendropy object recording taxa in the tree.
chosen_node (dendropy.Node): Node that will undergo speciation.
state_representation_dict (dict): Dictionary that keeps track of all states currently represented by living lineages.
sa_lineage_dict (dict): Dictionary that keeps track of nodes that have sampled ancestor children..
state_transition_dict (dict): Dictionary that keeps track of nodes subtending which state transitions happen (used for plotting)
untargetable_node_set (set): Set of Node labels that cannot be targeted for events anymore (went extinct).
cumulative_node_count (int): Total number of nodes in the tree (to be used in labeling).
macroevol_atomic_param (AtomicRateParameter): Atomic rate parameter containing departing/arriving state.
tr_namespace (dendropy.TaxonNamespace): Dendropy object
recording taxa in the tree.
chosen_node (dendropy.Node): Node that will undergo
birth event.
state_representation_dict (dict): Dictionary that keeps
track of all states currently represented by living
lineages. Keys are integer representing the states,
values are sets of taxon names (strings).
sa_lineage_dict (dict): Dictionary that keeps track of
nodes that have direct (sampled) ancestor children.
Keys are taxon names (strings), values are lists
of SampledAncestor objects.
state_transition_dict (dict): Dictionary that keeps track
of nodes subtending which state transitions happen
(used for plotting). Keys are taxon names (strings)
and values are lists of AttributeTransition objects.
untargetable_node_set (set): Set of Node labels that
cannot be targeted for events anymore (went extinct).
cumulative_node_count (int): Total number of nodes in the
tree (to be used in labeling).
sse_birth_rate_object (DiscreteStateDependentRate): SSE
rate parameter object holding the departing/arriving
states.
event_t (float): Time of birth event taking place.
debug (bool): If 'true', prints debugging messages. Defaults to False.
debug (bool): Flag for printing debugging messages.
If 'True', prints messages. Defaults to 'False'.
Returns:
(dendropy.Node, int): Tuple with last node to under go event and total (cumulative) node count.
(tuple): Tuple with two objects, the last node to speciate
(dendropy.Node) and the tree's cumulative node count (int).
"""

left_arriving_state, right_arriving_state = \
macroevol_atomic_param.arriving_states
sse_birth_rate_object.arriving_states

if debug:
print("> SPECIATION of node " + chosen_node.label
Expand All @@ -469,7 +505,6 @@ def _execute_birth(
root_node.is_sa_lineage = chosen_node.is_sa_lineage
tr_namespace.add_taxon(root_node.taxon)
self.root_is_born = True
# state_representation_dict[root_node.state].add(root_node.label)

# origin/brosc cannot be selected anymore
chosen_node.alive = False # origin/brosc node is no longer alive
Expand Down Expand Up @@ -505,10 +540,12 @@ def _execute_birth(
# and
# [ori] ---> [root], respectively
elif chosen_node.label == "brosc":
# must add the evolution leading up to the brosc_node to the root node edge length
# (note that the brosc_node edge length will always be 0.0 if it resulted from an
# ancestor sampling event, but it could be > 0.0 if a state transition event happened)
# root_node.edge_length += chosen_node.edge_length # I think this is wrong,
# must add the evolution leading up to the brosc_node to the
# root node edge length
#
# (note that the brosc_node edge length will always be 0.0 if
# it resulted from an ancestor sampling event, but it could
# be > 0.0 if a state transition event happened)

root_node.edge_length = chosen_node.edge_length

Expand Down Expand Up @@ -637,7 +674,11 @@ def _execute_birth(

# (4) if chosen node was on a lineage with SAs, we update the SAs info
if chosen_node.is_sa_lineage:
self._update_sa_lineage_dict(event_t, sa_lineage_dict, [chosen_node.label], debug=debug)
self._update_sa_lineage_dict(
event_t,
sa_lineage_dict,
[chosen_node.label],
debug=debug)

return last_node2speciate, cumulative_node_count

Expand All @@ -650,17 +691,47 @@ def _execute_death(
untargetable_node_set: ty.Set[dp.Node],
event_t: float,
debug=False) -> dp.Node:
"""Execute lineage death (side-effect and return)
"""Execute lineage death.
This method has both side-effects and a return.
The side-effects include the updating of class members. These
members are passed as arguments in the signature so that we
can have a better idea of what to expect as behavior:
(i) self.tr.tr_namespace,
(ii) self.state_representation_dict
(iii) self.sa_lineage_dict
(iv) self.untargetable_node_set
(Note that Python passes arguments 'by assignment',
which means that for mutable arguments, we are doing something
akin to 'pass by reference'; we can mutate -- but not
reassign! -- the argument variable and changes will be
reflected outside the function.)
Args:
tr_namespace (dendropy.TaxonNamespace): Dendropy object recording taxa in the tree.
chosen_node (dendropy.Node): Node that will undergo extinction.
state_representation_dict (dict): Dictionary that keeps track of all states currently represented by living lineages.
sa_lineage_dict (dict): Dictionary that keeps track of nodes that have sampled ancestor children..
untargetable_node_set (set): Set of Node labels that cannot be targeted for events anymore (went extinct).
cumulative_node_count (int): Total number of nodes in the tree (to be used in labeling).
tr_namespace (dendropy.TaxonNamespace): Dendropy object
recording taxa in the tree.
chosen_node (dendropy.Node): Node that will undergo
extinction.
state_representation_dict (dict): Dictionary that keeps
track of all states currently represented by living
lineages. Keys are integer representing the states,
values are sets of taxon names (strings).
sa_lineage_dict (dict): Dictionary that keeps track of
nodes that have direct (sampled) ancestor children.
Keys are taxon names (strings), values are lists
of SampledAncestor objects.
untargetable_node_set (set): Set of Node labels that
cannot be targeted for events anymore (went extinct).
cumulative_node_count (int): Total number of nodes in the
tree (to be used in labeling).
event_t (float): Time of death event taking place.
debug (bool): If 'true', prints debugging messages. Defaults to False.
debug (bool): Flag for printing debugging messages.
If 'True', prints messages. Defaults to 'False'.
Returns:
(dendropy.Node) Last node to die.
"""

if debug:
Expand All @@ -679,20 +750,26 @@ def _execute_death(
# if chosen node was on a lineage with SAs,
# we update the SAs info
if chosen_node.is_sa_lineage:
self._update_sa_lineage_dict(event_t, sa_lineage_dict, [chosen_node.label])
self._update_sa_lineage_dict(event_t,
sa_lineage_dict,
[chosen_node.label])

######################################
# Special case: origin went extinct, #
# we slap a brosc node #
######################################
if chosen_node.label == "origin":
# at this point, the origin was chosen to die, but this node will never be extended
# (the origin always has an origin_edge_length = 0.0); for us to account for the
# evolution (branch length) that has happened before this death -- between the origin
# and the brosc_node being added -- we must add event_t as the brosc_node edge_length
# (other nodes are always added with edge_length = 0.0, and have their edges extended
# at this point, the origin was chosen to die, but this node will
# never be extended (the origin always has an
# origin_edge_length = 0.0); for us to account for the evolution
# (branch length) that has happened before this death -- between
# the origin and the brosc_node being added -- we must add
# 'event_t' as the brosc_node edge_length (other nodes are always
# added with edge_length = 0.0, and have their edges extended
# when a new event takes place)
brosc_node = dp.Node(taxon=dp.Taxon(label="brosc"), label="brosc", edge_length=event_t)
brosc_node = dp.Node(taxon=dp.Taxon(label="brosc"),
label="brosc",
edge_length=event_t)
brosc_node.alive = False
brosc_node.is_sa = False
brosc_node.is_sa_dummy_parent = False
Expand Down Expand Up @@ -740,7 +817,10 @@ def _execute_anatrans(
"""

if debug:
print("TRANSITION of node " + chosen_node.label + " from state " + str(chosen_node.state) + " to state " + str(macroevol_rate_param.arriving_states[0]))
print("TRANSITION of node " + chosen_node.label \
+ " from state " + str(chosen_node.state) \
+ " to state " \
+ str(macroevol_rate_param.arriving_states[0]))

# new state
arriving_state = macroevol_rate_param.arriving_states[0]
Expand Down Expand Up @@ -1329,7 +1409,6 @@ def _extend_all_living_nodes(branch_length, end=False):
# through the "age stop condition"
#
# next_max_t will be None if self.slice_t_ends is empty
# print("dn_discrete_sse.py: at (4)")
excess_t = 0.0
if self.stop == "age" and self.n_time_slices > 1 and latest_t > next_max_t:
excess_t = latest_t - next_max_t
Expand Down Expand Up @@ -1383,16 +1462,20 @@ def _extend_all_living_nodes(branch_length, end=False):
# with a single time slice)
if (self.stop == "age" and (latest_t > t_stop)) or \
latest_t < 0.0:
_extend_all_living_nodes(t_stop - (latest_t - t_to_next_event), end=True)
_extend_all_living_nodes(
t_stop - (latest_t - t_to_next_event),
end=True)

sa_lineage_node_labels = [nd.label for nd in living_nodes if nd.is_sa_lineage]
sa_lineage_node_labels = \
[nd.label for nd in living_nodes if nd.is_sa_lineage]
# updates SA info for plotting
self._update_sa_lineage_dict(
t_stop,
sa_lineage_dict,
sa_lineage_node_labels)

# if origin is the only node (root always has children), we slap brosc node at end of process
# if origin is the only node (root always has children)
# we slap brosc node at end of process
if self.with_origin and tr.seed_node.alive and \
len(tr.seed_node.child_nodes()) == 0:
brosc_node.edge_length = t_stop
Expand All @@ -1419,7 +1502,6 @@ def _extend_all_living_nodes(branch_length, end=False):
# (6) draw a node we'll apply the event to
#
# a lineage will be chosen in proportion to the total rate of its state
# print("dn_discrete_sse.py: at (5.1)")
lineage_weights = \
[state_total_rates[nd.state] for nd in living_nodes]
chosen_node = \
Expand Down Expand Up @@ -1457,10 +1539,8 @@ def _extend_all_living_nodes(branch_length, end=False):
# are called, so we do not need to traverse the tree and
# update living_nodes all the time (below, in step (9))
# print("dn_discrete_sse.py: at (5.3)")
last_chosen_node, \
cumulative_node_count, \
cumulative_sa_count = \
self._execute_event(
last_chosen_node, cumulative_node_count, cumulative_sa_count \
= self._execute_event(
tr.taxon_namespace,
macroevol_atomic_param,
chosen_node,
Expand All @@ -1481,7 +1561,6 @@ def _extend_all_living_nodes(branch_length, end=False):
# TODO: at some point, make 'living_nodes' a set that is passed to all execute_birth
# and execute_death functions so that we don't need to traverse the tree every event
# to list living nodes that can be targeted
# print("dn_discrete_sse.py: at (5.4)")
living_nodes = [nd for nd in tr if nd.alive]
current_node_target_count = len(living_nodes)

Expand Down
12 changes: 6 additions & 6 deletions src/phylojunction/distribution/dn_discrete_sse.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ class DnSSE(pgm.DistributionPGM):
debug: bool
def __init__(self,
sse_stash: sseobj.SSEStash = ...,
stop_value: ty.List[float] = ...,
n: int = ...,
n_replicates: int = ...,
stop: ty.Optional[str] = ...,
origin: bool = ...,
start_states_list: ty.List[int] = ...,
stop: str = "",
stop_value: ty.List[float] = [],
condition_on_speciation: bool = ...,
condition_on_survival: bool = ...,
condition_on_obs_both_sides_root: bool = ...,
min_rec_taxa: int = ...,
max_rec_taxa: int = ...,
abort_at_obs: int = ...,
seeds_list: ty.Optional[ty.List[int]] = ...,
epsilon: float = ...,
runtime_limit: int = ...,
debug: bool = ...) -> None: ...
epsilon: float = 1e-12,
runtime_limit: int = 5,
rng_seed: ty.Optional[int] = None,
debug: ty.Optional[bool] = False) -> None: ...
def _initialize_missing_prob_handler(self) -> None: ...
def _check_sample_size(self) -> None: ...
def get_next_event_time(self, total_rate: float, a_seed: ty.Optional[int] = ...) -> float: ...
Expand Down
Loading

0 comments on commit f3a517c

Please sign in to comment.