Skip to content

Commit

Permalink
Flags for MSA search only and parallel MSA search
Browse files Browse the repository at this point in the history
  • Loading branch information
VRehnberg committed Apr 25, 2024
1 parent 2797920 commit fe59f1b
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ components = [
'AlphaFold-2.1.0_fix-scp-path.patch',
'AlphaFold-2.0.1_setup_rm_tfcpu.patch',
'AlphaFold-2.3.1_use_openmm_7.7.0.patch',
'AlphaFold-2.3.1_add-run_features_only-option.patch',
'AlphaFold-2.3.1_parallel-execution-of-MSA-tools.patch',
],
'checksums': [
'1161b2609fa896b16399b900ec2b813e5a0b363fe4e2b26bd826953ba234736a', # v2.3.1.tar.gz
Expand All @@ -72,7 +74,10 @@ components = [
'1a2e4e843bd9a4d15ee39e6c37cc63ba281311cc7a0a5610f0e43b52ef93faac', # AlphaFold-2.0.1_setup_rm_tfcpu.patch
# AlphaFold-2.3.1_use_openmm_7.7.0.patch
'd800bb085deac7edbe2d72916c1194043964aaf51b88a3b5a5016ab68a1090ec',

# AlphaFold-2.3.1_add-run_features_only-option.patch
'9c39fcec5532fabc2960f9d94ae8434f34c7da5ade4f67ae37c6906143f0132e',
# AlphaFold-2.3.1_parallel-execution-of-MSA-tools.patch
'b1db9e200e36df460cc985552e4549c726efdeb94abcfa55c522f48e78caa66d',
],
'start_dir': 'alphafold-%(version)s',
'use_pip': True,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
Adds a flag `--run_feature_only` that stops after the features have been
generated, and logic so that if `features.pkl` already exists, the MSA and
template search step is skipped and the saved features are loaded instead.

To a large degree taken from https://github.com/Zuricho/ParallelFold/

diff --git a/run_alphafold.py b/run_alphafold.py
index 0d89bfb..04016c5 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -141,6 +141,8 @@ flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
'Relax on GPU can be much faster than CPU, so it is '
'recommended to enable if possible. GPUs must be available'
' if this setting is enabled.')
+flags.DEFINE_boolean('run_feature_only', False, 'Calculate MSA and template to generate '
+ 'feature and then stop.')

FLAGS = flags.FLAGS

@@ -180,7 +182,8 @@ def predict_structure(
amber_relaxer: relax.AmberRelaxation,
benchmark: bool,
random_seed: int,
- models_to_relax: ModelsToRelax):
+ models_to_relax: ModelsToRelax,
+ run_feature_only: bool):
"""Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name)
timings = {}
@@ -193,16 +196,27 @@ def predict_structure(

# Get features.
t_0 = time.time()
- feature_dict = data_pipeline.process(
- input_fasta_path=fasta_path,
- msa_output_dir=msa_output_dir)
- timings['features'] = time.time() - t_0
+ features_output_path = os.path.join(output_dir, 'features.pkl')
+
+ # If a features.pkl file already exists, load it and skip the MSA and template finding step
+ if os.path.exists(features_output_path):
+ feature_dict = pickle.load(open(features_output_path, 'rb'))
+
+ else:
+ feature_dict = data_pipeline.process(
+ input_fasta_path=fasta_path,
+ msa_output_dir=msa_output_dir)

# Write out features as a pickled dictionary.
features_output_path = os.path.join(output_dir, 'features.pkl')
with open(features_output_path, 'wb') as f:
pickle.dump(feature_dict, f, protocol=4)

+ timings['features'] = time.time() - t_0
+
+ if run_feature_only: # features-only run: skip the rest of the function
+ return 0
+
unrelaxed_pdbs = {}
unrelaxed_proteins = {}
relaxed_pdbs = {}
@@ -449,7 +463,8 @@ def main(argv):
amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark,
random_seed=random_seed,
- models_to_relax=FLAGS.models_to_relax)
+ models_to_relax=FLAGS.models_to_relax,
+ run_feature_only=FLAGS.run_feature_only)


if __name__ == '__main__':
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
MSA search is done by three tools; this adds an option (`--n_parallel_msa`)
to run them in parallel. From https://github.com/google-deepmind/alphafold/pull/399

From 5e18b3ab5cf9b6e20ec3bd8f0f2a0c02ba9db837 Mon Sep 17 00:00:00 2001
From: fuji <[email protected]>
Date: Tue, 22 Feb 2022 10:16:10 +0900
Subject: [PATCH 1/4] Add parallel runs of MSA tools

---
alphafold/data/pipeline.py | 110 ++++++++++++++++++++++---------------
run_alphafold.py | 5 +-
2 files changed, 70 insertions(+), 45 deletions(-)

diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py
index c92944fe3..3a3a18d4a 100644
--- a/alphafold/data/pipeline.py
+++ b/alphafold/data/pipeline.py
@@ -27,6 +27,8 @@
from alphafold.data.tools import jackhmmer
import numpy as np

+import concurrent.futures
+
# Internal import (7716).

FeatureDict = MutableMapping[str, np.ndarray]
@@ -124,7 +126,8 @@ def __init__(self,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
- use_precomputed_msas: bool = False):
+ use_precomputed_msas: bool = False,
+ n_parallel_msa: int = 1):
"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
@@ -146,35 +149,13 @@ def __init__(self,
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas
+ self.n_parallel_msa = n_parallel_msa

- def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
- """Runs alignment tools on the input sequence and creates features."""
- with open(input_fasta_path) as f:
- input_fasta_str = f.read()
- input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
- if len(input_seqs) != 1:
- raise ValueError(
- f'More than one input sequence found in {input_fasta_path}.')
- input_sequence = input_seqs[0]
- input_description = input_descs[0]
- num_res = len(input_sequence)
-
+ def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence):
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
- msa_runner=self.jackhmmer_uniref90_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=uniref90_out_path,
- msa_format='sto',
- use_precomputed_msas=self.use_precomputed_msas,
- max_sto_sequences=self.uniref_max_hits)
- mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
- jackhmmer_mgnify_result = run_msa_tool(
- msa_runner=self.jackhmmer_mgnify_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=mgnify_out_path,
- msa_format='sto',
- use_precomputed_msas=self.use_precomputed_msas,
- max_sto_sequences=self.mgnify_max_hits)
+ self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
+ 'sto', self.use_precomputed_msas)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
@@ -196,29 +177,70 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
- mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
+ uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)
+ return uniref90_msa, pdb_template_hits
+
+ def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir):
+ mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
+ jackhmmer_mgnify_result = run_msa_tool(
+ self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
+ self.use_precomputed_msas)
+
+ mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
+ mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
+ return mgnify_msa
+
+ def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir):
+ bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
+ hhblits_bfd_uniclust_result = run_msa_tool(
+ self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
+ 'a3m', self.use_precomputed_msas)
+ bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+ return bfd_msa
+
+ def jackhmmer_small_bfd_caller(self, input_fasta_path, msa_output_dir):
+ bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
+ jackhmmer_small_bfd_result = run_msa_tool(
+ msa_runner=self.jackhmmer_small_bfd_runner,
+ input_fasta_path=input_fasta_path,
+ msa_out_path=bfd_out_path,
+ msa_format='sto',
+ use_precomputed_msas=self.use_precomputed_msas)
+ bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
+ return bfd_msa
+
+
+ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
+ """Runs alignment tools on the input sequence and creates features."""
+ with open(input_fasta_path) as f:
+ input_fasta_str = f.read()
+ input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
+ if len(input_seqs) != 1:
+ raise ValueError(
+ f'More than one input sequence found in {input_fasta_path}.')
+ input_sequence = input_seqs[0]
+ input_description = input_descs[0]
+ num_res = len(input_sequence)

+
+ futures = []
if self._use_small_bfd:
- bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
- jackhmmer_small_bfd_result = run_msa_tool(
- msa_runner=self.jackhmmer_small_bfd_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=bfd_out_path,
- msa_format='sto',
- use_precomputed_msas=self.use_precomputed_msas)
- bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
+ futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
+ futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
+ futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir))
else:
- bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
- hhblits_bfd_uniclust_result = run_msa_tool(
- msa_runner=self.hhblits_bfd_uniclust_runner,
- input_fasta_path=input_fasta_path,
- msa_out_path=bfd_out_path,
- msa_format='a3m',
- use_precomputed_msas=self.use_precomputed_msas)
- bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
+ futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
+ futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
+ futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir))
+
+ uniref90_msa, pdb_template_hits = futures[0].result()
+ mgnify_msa = futures[1].result()
+ bfd_msa = futures[2].result()

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
diff --git a/run_alphafold.py b/run_alphafold.py
index d87e00e4b..6967d0c9c 100644
--- a/run_alphafold.py
+++ b/run_alphafold.py
@@ -129,6 +129,8 @@
'recommended to enable if possible. GPUs must be available'
' if this setting is enabled.')

+flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.')
+
FLAGS = flags.FLAGS

MAX_TEMPLATE_HITS = 20
@@ -346,7 +348,8 @@ def main(argv):
template_searcher=template_searcher,
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
- use_precomputed_msas=FLAGS.use_precomputed_msas)
+ use_precomputed_msas=FLAGS.use_precomputed_msas,
+ n_parallel_msa=FLAGS.n_parallel_msa)

if run_multimer_system:
num_predictions_per_model = FLAGS.num_multimer_predictions_per_model

From 86c77759e10a1e1716be9af90627d1027b576502 Mon Sep 17 00:00:00 2001
From: fuji8 <[email protected]>
Date: Sun, 12 Feb 2023 03:19:55 +0900
Subject: [PATCH 2/4] fix for v2.3.1

---
alphafold/data/pipeline.py | 19 +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py
index eeb37216a..eaea5fed2 100644
--- a/alphafold/data/pipeline.py
+++ b/alphafold/data/pipeline.py
@@ -154,8 +154,12 @@ def __init__(self,
def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_output_dir, input_sequence):
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
- self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
- 'sto', self.use_precomputed_msas)
+ msa_runner=self.jackhmmer_uniref90_runner,
+ input_fasta_path=input_fasta_path,
+ msa_out_path=uniref90_out_path,
+ msa_format='sto',
+ use_precomputed_msas=self.use_precomputed_msas,
+ max_sto_sequences=self.uniref_max_hits)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
@@ -186,11 +190,14 @@ def jackhmmer_uniref90_and_pdb_templates_caller(self, input_fasta_path, msa_outp
def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir):
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
- self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
- self.use_precomputed_msas)
+ msa_runner=self.jackhmmer_mgnify_runner,
+ input_fasta_path=input_fasta_path,
+ msa_out_path=mgnify_out_path,
+ msa_format='sto',
+ use_precomputed_msas=self.use_precomputed_msas,
+ max_sto_sequences=self.mgnify_max_hits)

mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
- mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)
return mgnify_msa

def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir):
@@ -236,7 +243,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
futures.append(executor.submit(self.jackhmmer_small_bfd_caller, input_fasta_path, msa_output_dir))
else:
- with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir))

From 6acf53caa74ca1b4900c2ccbc36ea33909965f38 Mon Sep 17 00:00:00 2001
From: fuji8 <[email protected]>
Date: Sun, 12 Feb 2023 03:28:08 +0900
Subject: [PATCH 3/4] rename func name

---
alphafold/data/pipeline.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/alphafold/data/pipeline.py b/alphafold/data/pipeline.py
index eaea5fed2..ca8292a53 100644
--- a/alphafold/data/pipeline.py
+++ b/alphafold/data/pipeline.py
@@ -200,7 +200,7 @@ def jackhmmer_mgnify_caller(self, input_fasta_path, msa_output_dir):
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
return mgnify_msa

- def hhblits_bfd_uniclust_caller(self, input_fasta_path, msa_output_dir):
+ def hhblits_bfd_uniref_caller(self, input_fasta_path, msa_output_dir):
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniref_hits.a3m')
hhblits_bfd_uniref_result = run_msa_tool(
msa_runner=self.hhblits_bfd_uniref_runner,
@@ -246,7 +246,7 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_parallel_msa) as executor:
futures.append(executor.submit(self.jackhmmer_uniref90_and_pdb_templates_caller, input_fasta_path, msa_output_dir, input_sequence))
futures.append(executor.submit(self.jackhmmer_mgnify_caller, input_fasta_path, msa_output_dir))
- futures.append(executor.submit(self.hhblits_bfd_uniclust_caller, input_fasta_path, msa_output_dir))
+ futures.append(executor.submit(self.hhblits_bfd_uniref_caller, input_fasta_path, msa_output_dir))

uniref90_msa, pdb_template_hits = futures[0].result()
mgnify_msa = futures[1].result()

From 36ea35491b40b1104e043dcba7eca7463566b69a Mon Sep 17 00:00:00 2001
From: fuji8 <[email protected]>
Date: Sun, 12 Feb 2023 18:28:28 +0000
Subject: [PATCH 4/4] fix for docker

---
docker/run_docker.py | 2 ++
1 file changed, 2 insertions(+)

diff --git a/docker/run_docker.py b/docker/run_docker.py
index 155d8fe2c..e7e2024ef 100644
--- a/docker/run_docker.py
+++ b/docker/run_docker.py
@@ -93,6 +93,7 @@
'will be owned by this user:group. By default, this is the current user. '
'Valid options are: uid or uid:gid, non-numeric values are not recognised '
'by Docker unless that user has been created within the container.')
+flags.DEFINE_integer('n_parallel_msa', 1, 'Number of parallel runs of MSA tools.')

FLAGS = flags.FLAGS

@@ -227,6 +228,7 @@ def main(argv):
f'--models_to_relax={FLAGS.models_to_relax}',
f'--use_gpu_relax={use_gpu_relax}',
'--logtostderr',
+ f'--n_parallel_msa={FLAGS.n_parallel_msa}'
])

client = docker.from_env()

0 comments on commit fe59f1b

Please sign in to comment.