Skip to content

Commit

Permalink
#46 Sample test passes using Sand et al
Browse files Browse the repository at this point in the history
  • Loading branch information
weka511 committed Apr 23, 2023
1 parent 9d1ca9c commit 3e17ea4
Showing 1 changed file with 63 additions and 38 deletions.
101 changes: 63 additions & 38 deletions phylogeny.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,19 @@ def get_descendants(bb):
assert len(b.clades)==2
self.B2 = get_descendants(b.clades[0])
self.B3 = get_descendants(b.clades[1])
self.B1 = set(i for i in range(n) if i not in self.B2 and i not in self.B3)
self.B1 = set(i for i in range(n) if i not in (self.B2 | self.B3))

def __str__(self):
return f'{self.a}->{self.b} {self.B1} {self.B2} {self.B3}'

class PreparedTree:
def __init__(self,T,n,index):
def __init__(self,T,n,index,forward=True):
self.n = n
self.T = read(StringIO(T), 'newick')
self.number_nodes()
self.edges,self.expanded_nodes = self.create_edges()
self.internal_nodes = self.create_internal_nodes(forward=forward)
self.expanded_nodes = self.create_expanded_nodes(forward=forward)
self.edges = self.create_edges(forward=forward)

def number_nodes(self):
m = 0
Expand All @@ -286,20 +288,40 @@ def number_nodes(self):
clade.name = n+m
m+=1

def create_edges(self):
def create_internal_nodes(self,forward=True):
product = [clade for clade in self.T.find_clades(order='postorder',terminal=False)]
return product if forward else product[::-1]

def create_expanded_nodes(self,forward=True):
product = {clade.name : self.get_leaves(clade) for clade in self.T.find_clades(order='postorder',terminal=False)}
if forward:
for clade in self.internal_nodes:
for child in clade.clades:
if child.name in product:
product[clade.name] |= product[child.name]
else:
parent = {}
for clade in self.internal_nodes:
for child in clade.clades:
if child.name>= self.n:
parent[child.name] = clade.name
for clade in self.internal_nodes:
if clade.name in parent:
product[clade.name] |= product[parent[clade.name]]

return product

def get_leaves(self,clade):
return set(child.name for child in clade.clades if child.name<self.n)

def create_edges(self,forward=True):
edges = []
expanded_nodes = {}
for internal_node in self.T.find_clades(order='postorder',terminal=False):
expanded_nodes[internal_node.name] = []
for child in internal_node.clades:
if child.name<self.n:
expanded_nodes[internal_node.name].append(child.name)
else:
edges.append(Edge(internal_node,child, expanded_nodes, self.n ))
for descendant in expanded_nodes[child.name]:
expanded_nodes[internal_node.name].append(descendant)
if child.name>=self.n:
edges.append(Edge(internal_node,child, self.expanded_nodes, self.n ))

return edges,expanded_nodes
return edges

def get_count_for_pair(e1,e2):
F1 = e1.B1
Expand All @@ -308,15 +330,13 @@ def get_count_for_pair(e1,e2):
G1 = e2.B1
G2 = e2.B2
G3 = e2.B3
return comb(len(F1.intersection(G1)),2,exact=True) * \
( len(F2.intersection(G2))*len(F3.intersection(G3))+ \
len(F2.intersection(G3))*len(F3.intersection(G2)))
return comb(len(F1&G1),2,exact=True) * (len(F2&G2)*len(F3&G3)+ len(F2&G3)*len(F3&G2))

n = len(species)
index = {species[i]:i for i in range(n)}
tree1 = PreparedTree(T1,n,index)
tree2 = PreparedTree(T2,n,index)
return 2*comb(n,2) - 2*sum(get_count_for_pair (e1,e2) for e1 in tree1.edges for e2 in tree2.edges)
tree2 = PreparedTree(T2,n,index,forward=False)
return 2*comb(n,4) - 2*sum(get_count_for_pair (e1,e2) for e1 in tree1.edges for e2 in tree2.edges)



Expand Down Expand Up @@ -826,11 +846,12 @@ def is_trivial(self):
def is_singleton(self):
return len(self.taxa)==1

# newick
#
# Convert to string in Newick format

def newick(self):
'''
newick
Convert to string in Newick format
'''
def conv(taxon):
if type(taxon)==int:
return species[taxon]
Expand All @@ -841,14 +862,17 @@ def conv(taxon):
else:
return '(' + ','.join(conv(taxon) for taxon in self.taxa) +')'

# split
#
# Split clade in two using character: list of taxa is replaced by two clades
#
# Returns True if clade has been split into two non-trivial clades
# False if at least one clade would be trivial--in which case clade is unchanged
#

def split(self,character):
'''
split
Split clade in two using character: list of taxa is replaced by two clades
Returns True if clade has been split into two non-trivial clades
False if at least one clade would be trivial--in which case clade is unchanged
'''
left = []
right = []
for i in self.taxa:
Expand All @@ -863,21 +887,24 @@ def split(self,character):
self.taxa = [leftTaxon,rightTaxon]
return True

# splitAll
#
# Split clade using character table

def splitAll(self,characters,depth=0):
'''
splitAll
Split clade using character table
'''
if depth<len(characters):
if self.split(characters[depth]):
for taxon in self.taxa:
taxon.splitAll(characters,depth+1)
else:
self.splitAll(characters,depth+1)


# Calculate entropy of a single character

def get_entropy(freq):
'''
Calculate entropy of a single character
'''
if freq==0 or freq==n: return 0
p1 = freq/n
p2 = 1-p1
Expand Down Expand Up @@ -1031,7 +1058,7 @@ def conflicts_with(c1, c2):
return True

n = len(table)
Conflicts = [0 for _ in range(n)] # Count number of times each row conflicts with another
Conflicts = np.zeros(n) # Count number of times each row conflicts with another
for i in range(n):
for j in range(i+1,n):
if conflicts_with(table[i],table[j]):
Expand All @@ -1040,8 +1067,6 @@ def conflicts_with(c1, c2):

return [table[row] for row in range(n) if row!=np.argmax(Conflicts)]



def cntq(n,newick):
'''CNTQ Counting Quartets'''
def create_adj(tree):
Expand Down

0 comments on commit 3e17ea4

Please sign in to comment.