diff --git a/phylogeny.py b/phylogeny.py index 1da364e..8226e82 100644 --- a/phylogeny.py +++ b/phylogeny.py @@ -265,17 +265,19 @@ def get_descendants(bb): assert len(b.clades)==2 self.B2 = get_descendants(b.clades[0]) self.B3 = get_descendants(b.clades[1]) - self.B1 = set(i for i in range(n) if i not in self.B2 and i not in self.B3) + self.B1 = set(i for i in range(n) if i not in (self.B2 | self.B3)) def __str__(self): return f'{self.a}->{self.b} {self.B1} {self.B2} {self.B3}' class PreparedTree: - def __init__(self,T,n,index): + def __init__(self,T,n,index,forward=True): self.n = n self.T = read(StringIO(T), 'newick') self.number_nodes() - self.edges,self.expanded_nodes = self.create_edges() + self.internal_nodes = self.create_internal_nodes(forward=forward) + self.expanded_nodes = self.create_expanded_nodes(forward=forward) + self.edges = self.create_edges(forward=forward) def number_nodes(self): m = 0 @@ -286,20 +288,40 @@ def number_nodes(self): clade.name = n+m m+=1 - def create_edges(self): + def create_internal_nodes(self,forward=True): + product = [clade for clade in self.T.find_clades(order='postorder',terminal=False)] + return product if forward else product[::-1] + + def create_expanded_nodes(self,forward=True): + product = {clade.name : self.get_leaves(clade) for clade in self.T.find_clades(order='postorder',terminal=False)} + if forward: + for clade in self.internal_nodes: + for child in clade.clades: + if child.name in product: + product[clade.name] |= product[child.name] + else: + parent = {} + for clade in self.internal_nodes: + for child in clade.clades: + if child.name>= self.n: + parent[child.name] = clade.name + for clade in self.internal_nodes: + if clade.name in parent: + product[clade.name] |= product[parent[clade.name]] + + return product + + def get_leaves(self,clade): + return set(child.name for child in clade.clades if child.name=self.n: + edges.append(Edge(internal_node,child, self.expanded_nodes, self.n )) - return edges,expanded_nodes + return edges def get_count_for_pair(e1,e2): F1 = e1.B1 @@ -308,15 +330,13 @@ def get_count_for_pair(e1,e2): G1 = e2.B1 G2 = e2.B2 G3 = e2.B3 - return comb(len(F1.intersection(G1)),2,exact=True) * \ - ( len(F2.intersection(G2))*len(F3.intersection(G3))+ \ - len(F2.intersection(G3))*len(F3.intersection(G2))) + return comb(len(F1&G1),2,exact=True) * (len(F2&G2)*len(F3&G3)+ len(F2&G3)*len(F3&G2)) n = len(species) index = {species[i]:i for i in range(n)} tree1 = PreparedTree(T1,n,index) - tree2 = PreparedTree(T2,n,index) - return 2*comb(n,2) - 2*sum(get_count_for_pair (e1,e2) for e1 in tree1.edges for e2 in tree2.edges) + tree2 = PreparedTree(T2,n,index,forward=False) + return 2*comb(n,4) - 2*sum(get_count_for_pair (e1,e2) for e1 in tree1.edges for e2 in tree2.edges) @@ -826,11 +846,12 @@ def is_trivial(self): def is_singleton(self): return len(self.taxa)==1 - # newick - # - # Convert to string in Newick format - def newick(self): + ''' + newick + + Convert to string in Newick format + ''' def conv(taxon): if type(taxon)==int: return species[taxon] @@ -841,14 +862,17 @@ def conv(taxon): else: return '(' + ','.join(conv(taxon) for taxon in self.taxa) +')' - # split - # - # Split clade in two using character: list of taxa is replaced by two clades - # - # Returns True if clade has been split into two non-trivial clades - # False if at least one clade would be trivial--in which case clade is unchanged - # + def split(self,character): + ''' + split + + Split clade in two using character: list of taxa is replaced by two clades + + Returns True if clade has been split into two non-trivial clades + False if at least one clade would be trivial--in which case clade is unchanged + + ''' left = [] right = [] for i in self.taxa: @@ -863,10 +887,13 @@ def split(self,character): self.taxa = [leftTaxon,rightTaxon] return True - # splitAll - # - # Split clade using character table + def splitAll(self,characters,depth=0): + ''' + splitAll + + Split clade using character table + ''' if depth