Skip to content

Commit

Permalink
Update pysam, perform multiple alignment, activate multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
aghozlane committed Oct 4, 2024
1 parent e12e8c2 commit 10276fa
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 226 deletions.
142 changes: 87 additions & 55 deletions meteor/phylogeny.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
from collections import OrderedDict
from datetime import datetime
from typing import Iterable, Tuple
from cogent3 import load_unaligned_seqs # , load_aligned_seqs

# from cogent3.evolve.distance import EstimateDistances
# from cogent3.evolve.models import GTR
# from cogent3.cluster.UPGMA import upgma
from cogent3.align.progressive import tree_align
from cogent3 import load_unaligned_seqs, make_aligned_seqs
from cogent3.evolve.distance import EstimateDistances
from cogent3.evolve.models import GTR
from cogent3.cluster.UPGMA import upgma
from concurrent.futures import ProcessPoolExecutor, as_completed


@dataclass
Expand Down Expand Up @@ -88,9 +87,7 @@ def clean_sites(
resultdict[gene_id] = output_seq
print(flush=True, file=output)
return resultdict, info_sites
# return info_sites

# def set_tree_config(self, raxml_ng_version: str) -> dict: # pragma: no cover
def set_tree_config(self):
"""Define the census configuration
Expand All @@ -102,8 +99,6 @@ def set_tree_config(self):
"meteor_version": self.meteor.version,
"phylogeny": {
"phylogeny_tool": "cogent3",
# "phylogeny_tool": "raxml-ng",
# "phylogeny_version": raxml_ng_version,
"phylogeny_date": datetime.now().strftime("%Y-%m-%d"),
"tree_files": ",".join([tree.name for tree in self.tree_files]),
},
Expand All @@ -116,61 +111,98 @@ def remove_edge_labels(self, newick: str) -> str:
# Replace matched patterns with ":" (effectively removing the edge label)
return re.sub(pattern, ":", newick)

def process_msp_file(
    self, msp_file: Path, idx: int, msp_count: int, tree_dir: Path, tmp_dir
) -> Tuple[Path, bool]:
    """Clean one MSP FASTA file and, if informative enough, write its tree.

    The cleaned sequences are written to a temporary FASTA file created in
    ``tmp_dir``. Pairwise distances are then estimated with a GTR model,
    clustered with UPGMA, and the unrooted tree is written in Newick format
    (edge labels removed) to ``tree_dir / "<msp_file stem>.tree"``.

    Args:
        msp_file: FASTA file of the MSP to analyse.
        idx: Position of this MSP in the run; used only in log messages.
        msp_count: Total number of MSPs; used only in log messages.
        tree_dir: Directory that receives the ``.tree`` file.
        tmp_dir: Directory in which the temporary cleaned FASTA is created.

    Returns:
        ``(tree_file, generated)``. ``generated`` is ``False`` when the MSP
        is skipped because it has fewer than ``self.min_info_sites``
        informative sites after cleaning; otherwise it reports whether
        ``tree_file`` exists once the tree has been written.
    """
    # Label used in log messages; computed once instead of per message.
    msp_name = msp_file.name.replace(".fasta", "")
    logging.info("%d/%d %s: Start analysis", idx, msp_count, msp_name)
    tree_file = tree_dir / f"{msp_file.stem}.tree"

    with NamedTemporaryFile(mode="wt", dir=tmp_dir, suffix=".fasta") as temp_clean:
        # Clean sites
        logging.info("Clean sites for %s", msp_file.name)
        _, info_sites = self.clean_sites(msp_file, temp_clean)

        if info_sites < self.min_info_sites:
            logging.info(
                "Only %d informative sites (< %d threshold) left after cleaning, skipping %s.",
                info_sites,
                self.min_info_sites,
                msp_name,
            )
            return tree_file, False  # Return False to indicate skipping

        # The file is reopened by path below, so make sure everything that
        # was written through this handle has actually reached the file.
        temp_clean.flush()

        # Perform alignments and UPGMA
        logging.info("Running UPGMA and Distance Estimation")
        aligned_seqs = make_aligned_seqs(
            load_unaligned_seqs(temp_clean.name, moltype="dna"),
            moltype="dna",
            array_align=True,
        )
        distances = EstimateDistances(aligned_seqs, submodel=GTR())
        distances.run(show_progress=False)

        # Create UPGMA tree and convert it to an unrooted copy
        cluster = upgma(distances.get_pairwise_distances())
        cluster = cluster.unrooted_deepcopy()
        newick = self.remove_edge_labels(cluster.get_newick(with_distances=True))

    with tree_file.open("w") as handle:
        handle.write(newick)

    return tree_file, tree_file.exists()

def execute(self) -> None:
# NOTE(review): this span comes from a rendered diff view. Indentation has been
# stripped and lines from two versions of the method appear back to back, so it
# is not runnable as written. The section markers below are a best-effort
# reading of which lines belong together -- confirm against the repository.
#
# Purpose (common to both versions): build one tree per MSP FASTA file in
# self.msp_file_list, collect the paths of the written trees in
# self.tree_files, then save the tree configuration as census_stage_4.json.
logging.info("Launch phylogeny analysis")
# Start phylogenies
start = perf_counter()

self.tree_files: list[Path] = []
msp_count = len(self.msp_file_list)
# --- Presumably the earlier, sequential implementation (tree_align based) ---
for idx, msp_file in enumerate(self.msp_file_list, start=1):
logging.info(
"%d/%d %s: Start analysis",
idx,
msp_count,
msp_file.name.replace(".fasta", ""),
)
with NamedTemporaryFile(
mode="wt", dir=self.meteor.tmp_dir, suffix=".fasta"
) as temp_clean:
tree_file = self.meteor.tree_dir / f"{msp_file.name}".replace(
".fasta", ""
)
# Clean sites
logging.info("Clean sites")
_, info_sites = self.clean_sites(msp_file, temp_clean)
if info_sites < self.min_info_sites:
logging.info(
"Only %d informative sites (< %d threshold) left after cleaning, skip.",
info_sites,
self.min_info_sites,
)
else:
seqs = load_unaligned_seqs(temp_clean.name, moltype="dna")
# params = {"kappa": 4.0}
_, tree = tree_align(
"GTR",
seqs,
# param_vals=params,
show_progress=False,
)
# print(aln)
with tree_file.with_suffix(".tree").open("w") as f:
f.write(
self.remove_edge_labels(
tree.get_newick(with_distances=True)
)
# --- Presumably the newer implementation: one process_msp_file task per MSP ---
# Using ProcessPoolExecutor to parallelize the MSP file processing
with ProcessPoolExecutor(max_workers=self.meteor.threads) as executor:
# Map each submitted future back to the MSP file it was created for.
futures = {
executor.submit(
self.process_msp_file,
msp_file,
idx,
msp_count,
self.meteor.tree_dir,
self.meteor.tmp_dir,
): msp_file
for idx, msp_file in enumerate(self.msp_file_list, start=1)
}

# Results are handled in completion order, not submission order.
for future in as_completed(futures):
msp_file = futures[future]
try:
tree_file, success = future.result()
if success:
self.tree_files.append(tree_file)
logging.info(
"Completed MSP tree for MSP %s",
msp_file.name.replace(".fasta", ""),
)
else:
logging.info(
"Skipped MSP %s due to insufficient informative sites",
msp_file.name.replace(".fasta", ""),
)
# --- Presumably left over from the earlier implementation ---
if tree_file.with_suffix(".tree").exists():
self.tree_files.append(tree_file.with_suffix(".tree"))
logging.info(
"Completed MSP tree for MSP %s",
msp_file.name.replace(".fasta", ""),
# --- Presumably part of the newer implementation (handler for the try above) ---
# An exception raised for one MSP is logged; it is not re-raised here.
except Exception as exc:
logging.error(
"MSP %s generated an exception: %s", msp_file.name, exc
)
# --- Presumably left over from the earlier implementation ---
else:
logging.info("No tree file generated")

# Summary logging: elapsed time since `start` and number of trees collected.
logging.info("Completed phylogeny in %f seconds", perf_counter() - start)
logging.info(
"Trees were generated for %d/%d MSPs", len(self.tree_files), msp_count
)

# Save configuration after all trees are processed
config = self.set_tree_config()
self.save_config(config, self.meteor.tree_dir / "census_stage_4.json")
2 changes: 1 addition & 1 deletion meteor/strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def execute(self) -> None:
)
sys.exit(1)
try:
start = perf_counter()
census_json = self.get_census_stage(self.meteor.mapped_sample_dir, 1)
sample_info = census_json["sample_info"]
stage3_dir = self.meteor.strain_dir / sample_info["sample_name"]
Expand Down Expand Up @@ -315,7 +316,6 @@ def execute(self) -> None:
/ self.json_data["reference"]["reference_file"]["database_dir"]
/ self.json_data["reference"]["annotation"]["bed"]["filename"]
)
start = perf_counter()
# count_file,
self.get_msp_variant(
consensus_file, msp_file, cram_file, bed_file, reference_file
Expand Down
2 changes: 1 addition & 1 deletion meteor/tests/test_strain.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,4 @@ def test_execute(strain_builder, tmp_path: Path) -> None:
BS = tmp_path / "strain" / "test" / "BS.fasta.xz"
assert BS.exists()
with BS.open("rb") as out:
assert md5(out.read()).hexdigest() == "c4a414c7677da877a6b0a569f8950cda"
assert md5(out.read()).hexdigest() == "665997d7dc24653bc001c2789fecb8fb"
Loading

0 comments on commit 10276fa

Please sign in to comment.