Skip to content

Commit

Permalink
add test and refactor main function
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanMainguy committed Dec 4, 2023
1 parent 975e37b commit a76dfc2
Show file tree
Hide file tree
Showing 11 changed files with 1,240 additions and 255 deletions.
92 changes: 14 additions & 78 deletions binette/bin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,30 @@ def __and__(self, other: 'Bin') -> 'Bin':
return Bin(contigs, origin, name)


def add_length(self, length: int) -> None:
    """
    Add the length attribute to the Bin object if the provided length is a positive integer.

    :param length: The length value to add.
    :raises ValueError: If length is not an integer strictly greater than zero.
    :return: None
    """
    # Validate first so that a rejected value is never stored on the object.
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(length, bool) or not isinstance(length, int) or length <= 0:
        raise ValueError("Length should be a positive integer.")
    self.length = length

def add_N50(self, n50: int) -> None:
    """
    Add the N50 attribute to the Bin object if the provided value is a non-negative integer.

    :param n50: The N50 value to add. Zero is accepted.
    :raises ValueError: If n50 is not an integer greater than or equal to zero.
    :return: None
    """
    # Validate first so that a rejected value is never stored on the object.
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(n50, bool) or not isinstance(n50, int) or n50 < 0:
        # Zero passes the check, so the message says "non-negative" rather than "positive".
        raise ValueError("N50 should be a non-negative integer.")
    self.N50 = n50


def add_quality(self, completeness: float, contamination: float, contamination_weight: float) -> None:
"""
Expand All @@ -106,7 +113,6 @@ def add_quality(self, completeness: float, contamination: float, contamination_w
:param completeness: The completeness value.
:param contamination: The contamination value.
:param contamination_weight: The weight assigned to contamination in the score calculation.
:return: None
"""
self.completeness = completeness
self.contamination = contamination
Expand Down Expand Up @@ -260,50 +266,6 @@ def from_bin_sets_to_bin_graph(bin_name_to_bin_set: Dict[str, set]) -> nx.Graph:
return G


def get_bin_graph(bins: List[Bin]) -> nx.Graph:
    """
    Creates a bin graph from a list of Bin objects.

    Each bin contributes one node keyed by its ``id``. An edge is added between
    two nodes whenever ``bin1.overlaps_with(bin2)`` is true for that pair.

    :param bins: A list of Bin objects representing bins.
    :return: A networkx Graph representing the bin graph of overlapping bins.
    """
    G = nx.Graph()
    G.add_nodes_from(b.id for b in bins)

    # Each unordered pair is visited exactly once.
    G.add_edges_from(
        (bin1.id, bin2.id)
        for bin1, bin2 in itertools.combinations(bins, 2)
        if bin1.overlaps_with(bin2)
    )
    return G


def get_bin_graph_with_attributes(bins: List[Bin], contig_to_length: Dict[str, int]) -> nx.Graph:
    """
    Creates a graph from a list of Bin objects with additional attributes.

    Each bin contributes one node keyed by its ``id``. For every pair of bins for
    which ``overlaps_with`` is true, an edge is added whose ``weight`` is
    ``100 - 100 * shared_length / min(bin1.length, bin2.length)``, where
    ``shared_length`` is the summed length of the contigs common to both bins.
    The weight is therefore 0 when the shorter bin is entirely shared and grows
    as the shared fraction shrinks.

    :param bins: A list of Bin objects representing bins.
    :param contig_to_length: A dictionary mapping contig names to their lengths.
    :raises KeyError: If a shared contig is missing from ``contig_to_length``.
    :raises ZeroDivisionError: If the shorter of two overlapping bins has length 0.
    :return: A networkx Graph representing the bin graph with attributes.
    """
    G = nx.Graph()
    G.add_nodes_from(b.id for b in bins)

    for bin1, bin2 in itertools.combinations(bins, 2):
        if not bin1.overlaps_with(bin2):
            continue

        shared_contigs = bin1.contigs & bin2.contigs
        shared_length = sum(contig_to_length[c] for c in shared_contigs)

        # 100 minus the shared length expressed as a percentage of the shorter bin.
        weight = 100 - 100 * (shared_length / min(bin1.length, bin2.length))

        G.add_edge(bin1.id, bin2.id, weight=weight)
    return G


def get_all_possible_combinations(clique: Iterable) -> Iterable[Tuple]:
"""
Expand Down Expand Up @@ -353,8 +315,7 @@ def get_difference_bins(G: nx.Graph) -> Set[Bin]:
difference_bins = set()

for clique in nx.clique.find_cliques(G):
# TODO should not use combinations but another method of itertools
# to get all possible combination in all possible order.

bins_combinations = get_all_possible_combinations(clique)
for bins in bins_combinations:

Expand All @@ -376,7 +337,7 @@ def get_union_bins(G: nx.Graph, max_conta: int = 50) -> Set[Bin]:
"""
Retrieves the union bins from a given graph.
:param G: A networkx Graph representing the graph.
:param G: A networkx Graph representing the graph of bins.
:param max_conta: Maximum allowed contamination value for a bin to be included in the union.
:return: A set of Bin objects representing the union bins.
Expand All @@ -400,31 +361,6 @@ def get_union_bins(G: nx.Graph, max_conta: int = 50) -> Set[Bin]:
return union_bins


def create_intersec_diff_bins(G: nx.Graph) -> Set[Bin]:
    """
    Builds the intersection and difference bins for every clique of a graph.

    For each combination of bins produced by ``get_all_possible_combinations``
    on a clique, the intersection of all bins of the combination is collected,
    as well as, for every bin of the combination, its difference with all the
    other bins of that combination.

    :param G: A networkx Graph representing the graph.
    :return: A set of Bin objects representing the intersection and difference bins.
    """
    derived_bins = set()

    for clique in nx.clique.find_cliques(G):
        for combo in get_all_possible_combinations(clique):
            # Intersection of the first bin with all remaining ones.
            derived_bins.add(combo[0].intersection(*combo[1:]))

            # Difference of each bin against every other bin of the combination.
            for current in combo:
                others = [other for other in combo if other != current]
                derived_bins.add(current.difference(*others))

    return derived_bins

def select_best_bins(bins: List[Bin]) -> List[Bin]:
"""
Selects the best bins from a list of bins based on their scores, N50 values, and IDs.
Expand Down
52 changes: 4 additions & 48 deletions binette/bin_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,51 +135,6 @@ def add_bin_size_and_N50(bins: Iterable[Bin], contig_to_size: Dict[str,int]):
bin_obj.add_N50(n50)



def get_bin_size_and_N50(bin_obj, contig_to_size: Dict[str, int]):
    """
    Compute the total size and the N50 of a bin and store both on the bin object.

    Nothing is returned: the values are recorded through ``bin_obj.add_length``
    and ``bin_obj.add_N50``.

    :param bin_obj: The bin object to calculate size and N50 for.
    :type bin_obj: Any
    :param contig_to_size: Dictionary mapping contig names to their sizes.
    :type contig_to_size: Dict[str, int]
    """
    contig_sizes = [contig_to_size[contig] for contig in bin_obj.contigs]
    bin_n50 = compute_N50(contig_sizes)
    total_size = sum(contig_sizes)

    bin_obj.add_length(total_size)
    bin_obj.add_N50(bin_n50)


def add_bin_metrics_in_parallel(bins: List, contig_info: Dict, threads: int, contamination_weight: float):
    """
    Add bin metrics in parallel for a list of bins.

    The bins are split with ``chunks`` and each chunk is submitted to a process
    pool running ``add_bin_metrics``; the bins returned by all workers are
    gathered into a single set.

    :param bins: List of bin objects.
    :type bins: List
    :param contig_info: Dictionary containing contig information.
    :type contig_info: Dict
    :param threads: Number of worker processes to use for parallel processing (must be >= 1).
    :type threads: int
    :param contamination_weight: Weight for contamination assessment.
    :type contamination_weight: float
    :raises ValueError: If threads is lower than 1.
    :return: Set of processed bin objects.
    :rtype: Set
    """
    if threads < 1:
        raise ValueError("threads should be a positive integer.")

    # Floor division plus one keeps the chunk size at least 1, even with fewer bins than workers.
    chunk_size = len(bins) // threads + 1
    logging.debug(f"chunk size to parallelize: {chunk_size}")

    futures = []
    with cf.ProcessPoolExecutor(max_workers=threads) as executor:
        for i, bins_chunk in enumerate(chunks(bins, chunk_size)):
            logging.debug(f"chunk {i}, {len(bins_chunk)} bins")
            futures.append(executor.submit(add_bin_metrics, bins_chunk, contig_info, contamination_weight))

        # result() blocks until the worker is done and re-raises any exception it hit.
        processed_bins = {bin_o for future in futures for bin_o in future.result()}
    return processed_bins



def add_bin_metrics(bins: List, contig_info: Dict, contamination_weight: float, threads: int = 1):
"""
Add metrics to a list of bins.
Expand Down Expand Up @@ -235,7 +190,8 @@ def assess_bins_quality_by_chunk(bins: List,
contig_to_aa_length: Dict,
contamination_weight: float,
postProcessor:modelPostprocessing.modelProcessor = None,
threads: int = 1,):
threads: int = 1,
chunk_size: int = 2500):
"""
Assess the quality of bins in chunks.
Expand All @@ -249,10 +205,10 @@ def assess_bins_quality_by_chunk(bins: List,
:param contamination_weight: Weight for contamination assessment.
:param postProcessor: post-processor from checkm2
:param threads: Number of threads for parallel processing (default is 1).
:param chunk_size: The size of each chunk.
"""
n = 2500

for i, chunk_bins_iter in enumerate(chunks(bins, n)):
for i, chunk_bins_iter in enumerate(chunks(bins, chunk_size)):
chunk_bins = set(chunk_bins_iter)
logging.debug(f"chunk {i}: assessing quality of {len(chunk_bins)}")
assess_bins_quality(
Expand Down
Loading

0 comments on commit a76dfc2

Please sign in to comment.