-
Notifications
You must be signed in to change notification settings - Fork 712
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Flags for MSA search only and parallel MSA search
- Loading branch information
Showing
3 changed files
with
381 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
easybuild/easyconfigs/a/AlphaFold/AlphaFold-2.3.1_add-run_features_only-option.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
Adds a flag `--run_feature_only` that stops after feature generation, plus logic so that if | ||
`features.pkl` already exists, the MSA and template search step is skipped. | ||
|
||
To a large degree taken from https://github.com/Zuricho/ParallelFold/ | ||
|
||
diff --git a/run_alphafold.py b/run_alphafold.py | ||
index 0d89bfb..04016c5 100644 | ||
--- a/run_alphafold.py | ||
+++ b/run_alphafold.py | ||
@@ -141,6 +141,8 @@ flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. ' | ||
'Relax on GPU can be much faster than CPU, so it is ' | ||
'recommended to enable if possible. GPUs must be available' | ||
' if this setting is enabled.') | ||
+flags.DEFINE_boolean('run_feature_only', False, 'Calculate MSA and template to generate ' | ||
+ 'feature and then stop.') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
@@ -180,7 +182,8 @@ def predict_structure( | ||
amber_relaxer: relax.AmberRelaxation, | ||
benchmark: bool, | ||
random_seed: int, | ||
- models_to_relax: ModelsToRelax): | ||
+ models_to_relax: ModelsToRelax, | ||
+ run_feature_only: bool): | ||
"""Predicts structure using AlphaFold for the given sequence.""" | ||
logging.info('Predicting %s', fasta_name) | ||
timings = {} | ||
@@ -193,16 +196,27 @@ def predict_structure( | ||
|
||
# Get features. | ||
t_0 = time.time() | ||
- feature_dict = data_pipeline.process( | ||
- input_fasta_path=fasta_path, | ||
- msa_output_dir=msa_output_dir) | ||
- timings['features'] = time.time() - t_0 | ||
+ features_output_path = os.path.join(output_dir, 'features.pkl') | ||
+ | ||
+ # If features.pkl already exists, skip the MSA and template search step | ||
+ if os.path.exists(features_output_path): | ||
+ feature_dict = pickle.load(open(features_output_path, 'rb')) | ||
+ | ||
+ else: | ||
+ feature_dict = data_pipeline.process( | ||
+ input_fasta_path=fasta_path, | ||
+ msa_output_dir=msa_output_dir) | ||
|
||
# Write out features as a pickled dictionary. | ||
features_output_path = os.path.join(output_dir, 'features.pkl') | ||
with open(features_output_path, 'wb') as f: | ||
pickle.dump(feature_dict, f, protocol=4) | ||
|
||
+ timings['features'] = time.time() - t_0 | ||
+ | ||
+ if run_feature_only: # features-only run requested: skip the rest of the function | ||
+ return 0 | ||
+ | ||
unrelaxed_pdbs = {} | ||
unrelaxed_proteins = {} | ||
relaxed_pdbs = {} | ||
@@ -449,7 +463,8 @@ def main(argv): | ||
amber_relaxer=amber_relaxer, | ||
benchmark=FLAGS.benchmark, | ||
random_seed=random_seed, | ||
- models_to_relax=FLAGS.models_to_relax) | ||
+ models_to_relax=FLAGS.models_to_relax, | ||
+ run_feature_only=FLAGS.run_feature_only) | ||
|
||
|
||
if __name__ == '__main__': |
305 changes: 305 additions & 0 deletions
305
easybuild/easyconfigs/a/AlphaFold/AlphaFold-2.3.1_parallel-execution-of-MSA-tools.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,305 @@ | ||
MSA search is done by three tools; this adds an option to run them in | ||
parallel. Taken from https://github.com/google-deepmind/alphafold/pull/399 | ||
|
||
From 5e18b3ab5cf9b6e20ec3bd8f0f2a0c02ba9db837 Mon Sep 17 00:00:00 2001 | ||
From: fuji <[email protected]> | ||
Date: Tue, 22 Feb 2022 10:16:10 +0900 | ||
Subject: [PATCH 1/4] Add parallel runs of MSA tools | ||
|
||
--- | ||
alphafold/data/pipeline.py | 110 ++++++++++++++++++++++--------------- | ||
run_alphafold.py | 5 +- | ||
2 files changed, 70 insertions(+), 45 deletions(-) | ||
|
||
diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py | ||
index c92944fe3..3a3a18d4a 100644 | ||
--- a/alphafold/data/pipeline.py | ||
+++ b/alphafold/data/pipeline.py | ||
@@ -27,6 +27,8 @@ | ||
from alphafold.data.tools import jackhmmer | ||
import numpy as np | ||
|
||
+import concurrent.futures | ||
+ | ||
# Internal import (7716). | ||
|
||
FeatureDict = MutableMapping[str, np.ndarray] | ||
@@ -124,7 +126,8 @@ def __init__(self, | ||
use_small_bfd: bool, | ||
mgnify_max_hits: int = 501, | ||
uniref_max_hits: int = 10000, | ||
- use_precomputed_msas: bool = False): | ||
+ use_precomputed_msas: bool = False, | ||
+ n_parallel_msa: int = 1): | ||
"""Initializes the data pipeline.""" | ||
self._use_small_bfd = use_small_bfd | ||
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( | ||
@@ -146,35 +149,13 @@ def __init__(self, | ||
self.mgnify_max_hits = mgnify_max_hits | ||
self.uniref_max_hits = uniref_max_hits | ||
self.use_precomputed_msas = use_precomputed_msas | ||
+ self.n_parallel_msa = n_parallel_msa | ||
|
||
- def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: | ||
- """Runs alignment tools on the input sequence and creates features.""" | ||
- with open(input_fasta_path) as f: | ||
- input_fasta_str = f.read() | ||
- input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) | ||
- if len(input_seqs) != 1: | ||
- raise ValueError( | ||
- f'More than one input sequence found in {input_fasta_path}.') | ||
- input_sequence = input_seqs[0] | ||
- input_description = input_descs[0] | ||
- num_res = len(input_sequence) | ||
- | ||
+ def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence): | ||
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') | ||
jackhmmer_uniref90_result = run_msa_tool( | ||
- msa_runner=self.jackhmmer_uniref90_runner, | ||
- input_fasta_path=input_fasta_path, | ||
- msa_out_path=uniref90_out_path, | ||
- msa_format='sto', | ||
- use_precomputed_msas=self.use_precomputed_msas, | ||
- max_sto_sequences=self.uniref_max_hits) | ||
- mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') | ||
- jackhmmer_mgnify_result = run_msa_tool( | ||
- msa_runner=self.jackhmmer_mgnify_runner, | ||
- input_fasta_path=input_fasta_path, | ||
- msa_out_path=mgnify_out_path, | ||
- msa_format='sto', | ||
- use_precomputed_msas=self.use_precomputed_msas, | ||
- max_sto_sequences=self.mgnify_max_hits) | ||
+ self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path, | ||
+ 'sto', self.use_precomputed_msas) | ||
|
||
msa_for_templates = jackhmmer_uniref90_result['sto'] | ||
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates) | ||
@@ -196,29 +177,70 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: | ||
f.write(pdb_templates_result) | ||
|
||
uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto']) | ||
- mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) | ||
+ uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) | ||
|
||
pdb_template_hits = self.template_searcher.get_template_hits( | ||
output_string=pdb_templates_result, input_sequence=input_sequence) | ||
+ return uniref90_msa, pdb_template_hits | ||
+ | ||
+ def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir): | ||
+ mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') | ||
+ jackhmmer_mgnify_result = run_msa_tool( | ||
+ self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto', | ||
+ self.use_precomputed_msas) | ||
+ | ||
+ mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) | ||
+ mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) | ||
+ return mgnify_msa | ||
+ | ||
+ def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir): | ||
+ bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') | ||
+ hhblits_bfd_uniclust_result = run_msa_tool( | ||
+ self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path, | ||
+ 'a3m', self.use_precomputed_msas) | ||
+ bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) | ||
+ return bfd_msa | ||
+ | ||
+ def jackhmmer_small_bfd_caller(self, input_fasta_path, msa_output_dir): | ||
+ bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') | ||
+ jackhmmer_small_bfd_result = run_msa_tool( | ||
+ msa_runner=self.jackhmmer_small_bfd_runner, | ||
+ input_fasta_path=input_fasta_path, | ||
+ msa_out_path=bfd_out_path, | ||
+ msa_format='sto', | ||
+ use_precomputed_msas=self.use_precomputed_msas) | ||
+ bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) | ||
+ return bfd_msa | ||
+ | ||
+ | ||
+ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: | ||
+ """Runs alignment tools on the input sequence and creates features.""" | ||
+ with open(input_fasta_path) as f: | ||
+ input_fasta_str = f.read() | ||
+ input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) | ||
+ if len(input_seqs) != 1: | ||
+ raise ValueError( | ||
+ f'More than one input sequence found in {input_fasta_path}.') | ||
+ input_sequence = input_seqs[0] | ||
+ input_description = input_descs[0] | ||
+ num_res = len(input_sequence) | ||
|
||
+ | ||
+ futures = [] | ||
if self._use_small_bfd: | ||
- bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') | ||
- jackhmmer_small_bfd_result = run_msa_tool( | ||
- msa_runner=self.jackhmmer_small_bfd_runner, | ||
- input_fasta_path=input_fasta_path, | ||
- msa_out_path=bfd_out_path, | ||
- msa_format='sto', | ||
- use_precomputed_msas=self.use_precomputed_msas) | ||
- bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto']) | ||
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: | ||
+ futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) | ||
+ futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) | ||
+ futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir)) | ||
else: | ||
- bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m') | ||
- hhblits_bfd_uniclust_result = run_msa_tool( | ||
- msa_runner=self.hhblits_bfd_uniclust_runner, | ||
- input_fasta_path=input_fasta_path, | ||
- msa_out_path=bfd_out_path, | ||
- msa_format='a3m', | ||
- use_precomputed_msas=self.use_precomputed_msas) | ||
- bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) | ||
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: | ||
+ futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) | ||
+ futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) | ||
+ futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir)) | ||
+ | ||
+ uniref90_msa, pdb_template_hits = futures[0].result() | ||
+ mgnify_msa = futures[1].result() | ||
+ bfd_msa = futures[2].result() | ||
|
||
templates_result = self.template_featurizer.get_templates( | ||
query_sequence=input_sequence, | ||
diff --git a/run_alphafold.py b/run_alphafold.py | ||
index d87e00e4b..6967d0c9c 100644 | ||
--- a/run_alphafold.py | ||
+++ b/run_alphafold.py | ||
@@ -129,6 +129,8 @@ | ||
'recommended to enable if possible. GPUs must be available' | ||
' if this setting is enabled.') | ||
|
||
+flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.') | ||
+ | ||
FLAGS = flags.FLAGS | ||
|
||
MAX_TEMPLATE_HITS = 20 | ||
@@ -346,7 +348,8 @@ def main(argv): | ||
template_searcher=template_searcher, | ||
template_featurizer=template_featurizer, | ||
use_small_bfd=use_small_bfd, | ||
- use_precomputed_msas=FLAGS.use_precomputed_msas) | ||
+ use_precomputed_msas=FLAGS.use_precomputed_msas, | ||
+ n_parallel_msa=FLAGS.n_parallel_msa) | ||
|
||
if run_multimer_system: | ||
num_predictions_per_model = FLAGS.num_multimer_predictions_per_model | ||
|
||
From 86c77759e10a1e1716be9af90627d1027b576502 Mon Sep 17 00:00:00 2001 | ||
From: fuji8 <[email protected]> | ||
Date: Sun, 12 Feb 2023 03:19:55 +0900 | ||
Subject: [PATCH 2/4] fix for v2.3.1 | ||
|
||
--- | ||
alphafold/data/pipeline.py | 19 +++++++++++++------ | ||
1 file changed, 13 insertions(+), 6 deletions(-) | ||
|
||
diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py | ||
index eeb37216a..eaea5fed2 100644 | ||
--- a/alphafold/data/pipeline.py | ||
+++ b/alphafold/data/pipeline.py | ||
@@ -154,8 +154,12 @@ def __init__(self, | ||
def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence): | ||
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') | ||
jackhmmer_uniref90_result = run_msa_tool( | ||
- self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path, | ||
- 'sto', self.use_precomputed_msas) | ||
+ msa_runner=self.jackhmmer_uniref90_runner, | ||
+ input_fasta_path=input_fasta_path, | ||
+ msa_out_path=uniref90_out_path, | ||
+ msa_format='sto', | ||
+ use_precomputed_msas=self.use_precomputed_msas, | ||
+ max_sto_sequences=self.uniref_max_hits) | ||
|
||
msa_for_templates = jackhmmer_uniref90_result['sto'] | ||
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates) | ||
@@ -186,11 +190,14 @@ def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_outp | ||
def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir): | ||
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') | ||
jackhmmer_mgnify_result = run_msa_tool( | ||
- self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto', | ||
- self.use_precomputed_msas) | ||
+ msa_runner=self.jackhmmer_mgnify_runner, | ||
+ input_fasta_path=input_fasta_path, | ||
+ msa_out_path=mgnify_out_path, | ||
+ msa_format='sto', | ||
+ use_precomputed_msas=self.use_precomputed_msas, | ||
+ max_sto_sequences=self.mgnify_max_hits) | ||
|
||
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) | ||
- mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) | ||
return mgnify_msa | ||
|
||
def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir): | ||
@@ -236,7 +243,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: | ||
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) | ||
futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir)) | ||
else: | ||
- with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: | ||
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: | ||
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) | ||
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) | ||
futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir)) | ||
|
||
From 6acf53caa74ca1b4900c2ccbc36ea33909965f38 Mon Sep 17 00:00:00 2001 | ||
From: fuji8 <[email protected]> | ||
Date: Sun, 12 Feb 2023 03:28:08 +0900 | ||
Subject: [PATCH 3/4] rename func name | ||
|
||
--- | ||
alphafold/data/pipeline.py | 4 ++-- | ||
1 file changed, 2 insertions(+), 2 deletions(-) | ||
|
||
diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py | ||
index eaea5fed2..ca8292a53 100644 | ||
--- a/alphafold/data/pipeline.py | ||
+++ b/alphafold/data/pipeline.py | ||
@@ -200,7 +200,7 @@ def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir): | ||
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) | ||
return mgnify_msa | ||
|
||
- def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir): | ||
+ def hhblits_bfd_uniref_caller(self, input_fasta_path, msa_output_dir): | ||
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m') | ||
hhblits_bfd_uniref_result = run_msa_tool( | ||
msa_runner=self.hhblits_bfd_uniref_runner, | ||
@@ -246,7 +246,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict: | ||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor: | ||
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence)) | ||
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir)) | ||
- futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir)) | ||
+ futures.append(executor.submit(self.hhblits_bfd_uniref_caller, input_fasta_path, msa_output_dir)) | ||
|
||
uniref90_msa, pdb_template_hits = futures[0].result() | ||
mgnify_msa = futures[1].result() | ||
|
||
From 36ea35491b40b1104e043dcba7eca7463566b69a Mon Sep 17 00:00:00 2001 | ||
From: fuji8 <[email protected]> | ||
Date: Sun, 12 Feb 2023 18:28:28 +0000 | ||
Subject: [PATCH 4/4] fix for docker | ||
|
||
--- | ||
docker/run_docker.py | 2 ++ | ||
1 file changed, 2 insertions(+) | ||
|
||
diff --git a/docker/run_docker.py b/docker/run_docker.py | ||
index 155d8fe2c..e7e2024ef 100644 | ||
--- a/docker/run_docker.py | ||
+++ b/docker/run_docker.py | ||
@@ -93,6 +93,7 @@ | ||
'will be owned by this user:group. By default, this is the current user. ' | ||
'Valid options are: uid or uid:gid, non-numeric values are not recognised ' | ||
'by Docker unless that user has been created within the container.') | ||
+flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
@@ -227,6 +228,7 @@ def main(argv): | ||
f'--models_to_relax={FLAGS.models_to_relax}', | ||
f'--use_gpu_relax={use_gpu_relax}', | ||
'--logtostderr', | ||
+ f'--n_parallel_msa={FLAGS.n_parallel_msa}' | ||
]) | ||
|
||
client = docker.from_env() |