
PST user guide

Summary

The Protein Set Transformer (PST) is a protein-based genome language model for contextualizing protein language model embeddings with genome context and subsequently producing genome embeddings from these protein embeddings.

Installation

We plan to create a pip-installable package in the future but are having issues with a custom fork dependency.

For now, you can install the software dependencies of PST using a combination of mamba and pip, which should take no more than 5 minutes.

Note: you will likely need to link your git command line interface with an online GitHub account. See GitHub's documentation for help setting up git at the command line.

Without GPUs

# setup torch first -- conda does this so much better than pip
mamba create -n pst -c pytorch -c pyg -c conda-forge 'python<3.12' 'pytorch>=2.0' cpuonly pyg pytorch-scatter

mamba activate pst

# install latest updates from this repository
# best to clone the repo since you may want to run the test demo
git clone https://github.com/cody-mar10/protein_set_transformer.git

cd protein_set_transformer

pip install . #<- notice the [dot]

With GPUs

# setup torch first -- conda does this so much better than pip
mamba create -n pst -c pytorch -c nvidia -c pyg -c conda-forge 'python<3.12' 'pytorch>=2.0' pytorch-cuda=11.8 pyg pytorch-scatter

mamba activate pst

# install latest updates from this repository
# best to clone the repo since you may want to run the test demo
git clone https://github.com/cody-mar10/protein_set_transformer.git

cd protein_set_transformer

pip install . #<- notice the [dot]

Installing for training a new PST

We implemented a cross-validation workflow for hyperparameter tuning using Lightning Fabric in a base library called lightning-crossval. Part of our specific implementation for hyperparameter tuning is also implemented in the PST library.

If you want to include the optional dependencies for training a new PST, you can follow the corresponding installation steps above with the following change:

pip install .[tune]

Test run

Upon successful installation, you will have the pst executable to train, tune, and predict. There are also other utility modules included that you can see using pst -h.

You will need to first download a trained vPST model:

pst download --trained-models

This will download both vPST models into ./pstdata, but you can change the download location using --outdir.

You can use the test data for a test prediction run:

# test/test_data.graphfmt.h5 is included in the git repo
pst predict \
    --file test/test_data.graphfmt.h5 \
    --checkpoint pstdata/pst-small_trained_model.ckpt \
    --outdir test_run

The results from the above command are available at test/test_run/predictions.h5. This test run takes less than 1 minute using a single CPU.

If you are unfamiliar with .h5 files, you can use pytables (installed with PST as a dependency) to inspect .h5 files in Python, or you can install hdf5 and use h5ls to inspect the fields in the output file.

There should be 3 fields in the prediction file:

  1. attn which contains the per-protein attention values (shape: $N_{prot} \times N_{heads}$)
  2. ctx_ptn which contains the contextualized PST protein embeddings (shape: $N_{prot} \times D$)
  3. genome which contains the PST genome embeddings (shape: $N_{genome} \times D$)
    • Prior to version 1.2.0, this was called data.
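
For example, here is a minimal sketch (using the pytables dependency mentioned above; the output path follows the --outdir used in the test command) for inspecting the prediction file; the field names follow the list above:

import tables

with tables.open_file("test_run/predictions.h5") as fp:
    # list every field and its shape
    for node in fp.root:
        print(node.name, node.shape)

    attn = fp.root.attn[:]        # per-protein attention (N_prot x N_heads)
    ctx_ptn = fp.root.ctx_ptn[:]  # contextualized protein embeddings (N_prot x D)
    genome = fp.root.genome[:]    # genome embeddings (N_genome x D)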

Data availability

All data associated with the initial model training can be found here: https://doi.org/10.5061/dryad.d7wm37q8w

We have also provided the README from the DRYAD data repository for reference. Additionally, we have provided a programmatic way to access the data from the command line using pst download:

usage: pst download [-h] [--all] [--outdir PATH] [--esm-large] [--esm-small] [--vpst-large] [--vpst-small] [--genome] [--genslm]
                    [--trained-models] [--genome-clusters] [--protein-clusters] [--aai] [--fasta] [--host-prediction] [--no-readme]
                    [--supplementary-data] [--supplementary-tables]

help:
  -h, --help            show this help message and exit

DOWNLOAD:
  --all                 download all files from the DRYAD repository (default: False)
  --outdir PATH         output directory to save files (default: ./pstdata)

EMBEDDINGS:
  --esm-large           download ESM2 large [t33_150M] PROTEIN embeddings for training and test viruses (esm-large_protein_embeddings.tar.gz)
                        (default: False)
  --esm-small           download ESM2 small [t6_8M] PROTEIN embeddings for training and test viruses (esm-small_protein_embeddings.tar.gz)
                        (default: False)
  --vpst-large          download vPST large PROTEIN embeddings for training and test viruses (pst-large_protein_embeddings.tar.gz) (default:
                        False)
  --vpst-small          download vPST small PROTEIN embeddings for training and test viruses (pst-small_protein_embeddings.tar.gz) (default:
                        False)
  --genome              download all genome embeddings for training and test viruses (genome_embeddings.tar.gz) (default: False)
  --genslm              download GenSLM ORF embeddings (genslm_protein_embeddings.tar.gz) (default: False)

TRAINED_MODELS:
  --trained-models      download trained vPST models (trained_models.tar.gz) (default: False)

CLUSTERS:
  --genome-clusters     download genome cluster labels (genome_clusters.tar.gz) (default: False)
  --protein-clusters    download protein cluster labels (protein_clusters.tar.gz) (default: False)

MANUSCRIPT_DATA:
  --aai                 download intermediate files for AAI calculations in the manuscript (aai.tar.gz) (default: False)
  --fasta               download protein fasta files for training and test viruses (fasta.tar.gz) (default: False)
  --host-prediction     download all data associated with the host prediction proof of concept (host_prediction.tar.gz) (default: False)
  --no-readme           download the DRYAD README (README.md) (default: True)
  --supplementary-data  download supplementary data directly used to make the figures in the manuscript (supplementary_data.tar.gz) (default:
                        False)
  --supplementary-tables
                        download supplementary tables (supplementary_tables.zip) (default: False)

For flags relating to the download of specific files, you can add as many flags as you like. For example, if you want the trained models and the raw FASTA files used to train the vPSTs downloaded into a directory called pst_models, then you'd run this:

pst download --trained-models --fasta --outdir pst_models

Usage

Embedding new proteins and genomes

1. ESM2 protein embeddings

The minimum input to the PST framework is a protein FASTA file, which we prefer to generate for microbes and viruses using pyrodigal. We have provided the pst embed command to embed protein sequences using the ESM2 models.

Here is what ESM2 models are used for each vPST model:

vPST       ESM2
pst-small  esm2_t6_8M_UR50D
pst-large  esm2_t30_150M_UR50D

To embed the protein sequences from a FASTA file, use the following command depending on which vPST model you are using:

### for pst-small
pst embed --input FASTAFILE.faa --esm esm2_t6_8M 

### for pst-large 
pst embed --input FASTAFILE.faa --esm esm2_t30_150M

pst embed has other options to change the output directory (--outdir), change the ESM2 model download directory (--torch-hub), and set the number of CPU threads or GPU devices (--devices).

The output of pst embed is a single .h5 file with the field data that stores the protein embeddings.
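
As a hedged sanity check (not part of the PST CLI; the output file name below is illustrative), you can verify that the number of embeddings in the pst embed output matches the number of sequences in the input FASTA file:

import tables

def count_fasta_records(path: str) -> int:
    with open(path) as fh:
        return sum(line.startswith(">") for line in fh)

with tables.open_file("FASTAFILE.h5") as fp:  # illustrative output file name
    n_embeddings = fp.root.data.shape[0]

assert n_embeddings == count_fasta_records("FASTAFILE.faa")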

FASTA File requirements

The protein embeddings from pst embed are produced IN THE SAME ORDER as the sequences in the FASTA file. Thus, the following are required of the input FASTA file (a sketch for checking these requirements follows the list):

  1. The file must be sorted to group all proteins from the same genome together.
  2. Within the block of proteins from each genome, the proteins must be in order of their appearance in the genome.
  3. The FASTA headers must look like this: scaffold_#, where scaffold is the genome scaffold name and # is the protein numerical ID relative to each scaffold. (This is the typical output from prodigal/pyrodigal -- in fact, the additional information in the prodigal-style headers is needed for the next step.)
    • In the event that you have multi-scaffold viruses (vMAGs, etc.), you can manually orient the scaffolds and renumber the proteins so they count contiguously from the first scaffold to the last. This is what was done with the test dataset in the manuscript.
      • We provide a utility script pst graphify to do this if an input mapping from scaffolds to genomes is provided. See the next section.
    • TODO: We are exploring a more native solution for multi-scaffold viruses that does not require an arbitrary arrangement of scaffolds and should not require changes to the model.
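
Here is a minimal sketch (not part of the PST CLI) that checks the first two requirements above: proteins grouped by scaffold and protein IDs increasing within each scaffold block:

def check_fasta_order(path: str) -> bool:
    """Return True if proteins are grouped by scaffold and numbered in order."""
    last_scaffold = None
    last_num = 0
    finished = set()  # scaffolds whose protein block has already ended
    with open(path) as fh:
        for line in fh:
            if not line.startswith(">"):
                continue
            name = line[1:].split()[0]           # e.g. "scaffold_12"
            scaffold, num = name.rsplit("_", 1)  # prodigal-style header
            num = int(num)
            if scaffold != last_scaffold:
                if scaffold in finished:
                    return False  # scaffold block interrupted by another scaffold
                if last_scaffold is not None:
                    finished.add(last_scaffold)
                last_scaffold = scaffold
            elif num <= last_num:
                return False      # protein IDs not increasing within the scaffold
            last_num = num
    return True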

2. Convert protein embeddings to graph format

Use the pst graphify command to convert the ESM2 protein embeddings into graph format. You will need the protein FASTA file used to generate the embeddings, since the embeddings should be in the same order as the FASTA file. The FASTA file should be in prodigal format: >scaffold_ptnid # start # stop # strand ....

If your FASTA headers have the above format, you can use the following command:

pst graphify --file EMBEDDINGSFILE.h5 --fasta-file FASTAFILE.faa

If you did not keep the extra metadata on the headers, you can alternatively provide a simple tab-delimited mapping file (--strand-file) that maps each protein name to its strand (-1 or 1 only):

genome1_1   1
genome1_2   1
genome1_3   1
genome1_4   -1
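
If your FASTA headers do still have prodigal-style metadata, you would normally just use --fasta-file; but as a hedged sketch of producing the --strand-file format shown above, note that the strand is the third "#"-separated value in each prodigal header:

def write_strand_file(fasta_path: str, out_path: str) -> None:
    with open(fasta_path) as fasta, open(out_path, "w") as out:
        for line in fasta:
            if not line.startswith(">"):
                continue
            fields = [field.strip() for field in line[1:].split("#")]
            name = fields[0]    # e.g. "genome1_1"
            strand = fields[3]  # "1" or "-1"
            out.write(f"{name}\t{strand}\n")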

Further, if you have multi-scaffold viruses, you can provide a tab-delimited file (--scaffold-map-file) that maps each scaffold name to its genome name, so that proteins are counted over the entire genome instead of per scaffold:

scaffoldA   genome1
scaffoldB   genome1
scaffoldC   genome2
scaffoldD   genome2

API

When installing the ptn-set-transformer library from this repository, the model and datamodule classes are available from the pst namespace.

Model API

The primary class needed is ProteinSetTransformer, which is a subclass of the PyTorch Lightning lightning.LightningModule. Thus, ProteinSetTransformer has the following methods common to LightningModule that can be overridden:

  • training_step
  • predict_step
  • configure_optimizers
  • forward

If dramatic changes to the training setup or objective are desired, you will want to subclass the BaseProteinSetTransformer to define these changes. See the finetuning section for more information.

ProteinSetTransformer is a wrapper around the SetTransformer class (defined in pst.nn.models), which contains the encoder-decoder layers. SetTransformer internally uses PyTorch Geometric graph formatting. Graph-formatted data accounts for the fact that each genome encodes a different number of proteins, just as different graphs can contain different numbers of nodes.

The forward method of the SetTransformer first encodes the protein embeddings using self-attention along edges of adjacent proteins (based on the order of encoding in the genome). Each encoding layer MultiheadAttentionConv is defined in pst.nn.layers and uses the standard scaled dot product attention with residual connections.

Then, these contextualized protein embeddings are decoded using a MultiheadAttentionPooling layer. This layer has a learnable d-dimensional seed vector that is used as the query, with the protein embeddings as the key and value for scaled dot product attention. The resulting attention weights are softmax-normalized and used to compute a weighted average of the protein embeddings, producing a genome embedding.
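
As a simplified, hedged sketch (single head, no key/query/value projections, unlike the actual MultiheadAttentionPooling layer), the decoding step described above amounts to:

import torch

class SimpleAttentionPooling(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.seed = torch.nn.Parameter(torch.randn(1, dim))  # learnable d-dimensional seed query

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N_prot, dim] contextualized protein embeddings for ONE genome
        scores = (self.seed @ x.T) / x.shape[-1] ** 0.5  # scaled dot product, [1, N_prot]
        weights = torch.softmax(scores, dim=-1)          # attention weights
        return weights @ x                               # weighted average -> [1, dim] genome embedding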

Model Config

To instantiate a ProteinSetTransformer, the Pydantic config ModelConfig is needed which has the following schema:

### Note: these are defined in pst.nn.config

from pydantic import BaseModel

class ModelConfig(BaseModel):
    in_dim: int
    out_dim: int
    num_heads: int
    n_enc_layers: int
    embed_scale: int
    dropout: float
    layer_dropout: float
    proj_cat: bool
    compile: bool
    optimizer: OptimizerConfig
    loss: LossConfig
    augmentation: AugmentationConfig

class OptimizerConfig(BaseModel):
    lr: float
    weight_decay: float
    betas: tuple[float, float]
    warmup_steps: int
    use_scheduler: bool

class AugmentationConfig(BaseModel):
    sample_rate: float

class LossConfig(BaseModel):
    margin: float
    sample_scale: float
    no_negatives_mode: NO_NEGATIVES_MODES

Instantiating a PST

With a ModelConfig instance, you can create a PST model like this:

from pst import ModelConfig
from pst import ProteinSetTransformer as PST

# ideally this is read from some external info like command line values
config = ModelConfig.default()
model = PST(config)

Alternatively, if you are starting from a pretrained model checkpoint, then you can instantiate a model that uses these pretrained weights like this:

from pst import ProteinSetTransformer as PST

checkpoint_file = "" # should be a real file path
model = PST.from_pretrained(checkpoint_file)

The model config is internally created by the .from_pretrained class method since the attributes are stored in the trained model's checkpoint.

Datamodule API and graph-formatted data

The forward method (and any other data-facing method) of ProteinSetTransformer requires graph-formatted data. We model our pst.GenomeGraph after the standards set by PyTorch Geometric. In this setup, each genome is viewed as a graph (pst.GenomeGraph), and proteins are nodes in this graph. The protein nodes are connected if they are adjacently encoded in the genome (subject to hyperparameters). Each protein node is represented by its corresponding protein embedding (x).

The attributes of a GenomeGraph object look like this:

class GenomeGraph:
    x: torch.Tensor # shape: [N, d] <- protein embeddings
    edge_index: torch.Tensor # shape: [2, E] <- define protein-protein connections
    num_proteins: int # <- number of proteins encoded by this genome, ie number of nodes
    class_id: int # <- optional class ID for genome/graph level class
    strand: torch.Tensor # shape: [N] <- strand of each protein
    pos: torch.Tensor # shape: [N, 1] <- integer tensor that counts from 0 to N-1
    y: torch.Tensor | None # <- optional label tensor for this genome graph

For a collection of GenomeGraphs, such as in a minibatch, the protein embedding tensors from each genome are stacked. We use an index pointer (ptr) to keep track of the start and stop positions for the proteins belonging to each genome. This allows efficient random access.

The pst.GenomeDataset handles creating batches of GenomeGraphs. However, the specific batch object is a GenomeGraphBatch with the following fields:

class GenomeGraphBatch:
    x: torch.Tensor # shape: [N, d] <- stacked protein embeddings
    edge_index: torch.Tensor # shape: [2, E] <- define protein-protein connections
    num_proteins: torch.Tensor # <- number of proteins encoded by each genome
    class_id: torch.Tensor # <- optional class ID for genome/graph level class for each genome
    strand: torch.Tensor # shape: [N] <- strand of each protein in all genomes
    pos: torch.Tensor # shape: [N, 1] <- integer tensor that counts from 0 to N-1 for each genome
    y: torch.Tensor | None # <- optional label tensor for each genome
    ### new fields
    ptr: torch.Tensor # shape: [num genomes + 1] <- index pointer to compute offsets
    batch: torch.Tensor # shape: [N] <- assigns each protein node to a unique genome ID
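
As an illustrative sketch (values made up), here is how the ptr and batch fields relate a stacked protein embedding tensor back to individual genomes:

import torch

# three genomes with 3, 2, and 4 proteins; embedding dimension of 8
x = torch.randn(3 + 2 + 4, 8)      # stacked protein embeddings
ptr = torch.tensor([0, 3, 5, 9])   # x[ptr[i]:ptr[i + 1]] are the proteins of genome i

genome_id = 1
proteins_of_genome_1 = x[ptr[genome_id] : ptr[genome_id + 1]]  # shape: [2, 8]

# the batch vector assigns each protein node to its genome ID
batch = torch.repeat_interleave(torch.arange(len(ptr) - 1), ptr[1:] - ptr[:-1])
# tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])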

The protein-protein edges (edge_index) for each genome graph are computed on the fly upon instantiating a GenomeDataset. This is because there are 2 hyperparameters that control the connectivity of the genome graphs: chunk_size and threshold. chunk_size determines the total number of nodes that can belong to each subgraph, while threshold determines the maximum distance, in number of proteins, at which two proteins can be connected. A threshold value of -1 indicates no maximum distance and leads to each subgraph being fully connected. The value of chunk_size determines the number of subgraphs for each genome.

[Figure: How genome graph chunking works]

This chunking is used for memory efficiency but may also reflect the real evolutionary pressures on genes that are encoded near each other. There is also support for sparsifying these subgraphs by changing the value of the --threshold command line option, which gets sent to the DataConfig.threshold parameter. This was not used for pretraining the original vPSTs, but the option is available.
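
The following is a hedged sketch (not the actual implementation in pst) of how a distance threshold could define the protein-protein edges within a single genome graph: protein i and protein j are connected if 0 < |i - j| <= threshold, and a threshold of -1 yields a fully connected graph:

import torch

def edges_within_threshold(num_proteins: int, threshold: int) -> torch.Tensor:
    if threshold == -1:  # -1 means no maximum distance, i.e. fully connected
        threshold = num_proteins
    src, dst = [], []
    for i in range(num_proteins):
        for j in range(max(0, i - threshold), min(num_proteins, i + threshold + 1)):
            if i != j:
                src.append(i)
                dst.append(j)
    return torch.tensor([src, dst])  # edge_index with shape [2, E]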

To store data that fits the above data models, .h5 files are required with the following fields:

dataset.h5:
  - data: stores the stacked protein embeddings (maps to x)
  - ptr: offsets needed to randomly access all proteins from each genome from the data field
  - sizes: number of proteins for each genome
  - strand: protein encoding strand for each protein

Note that class_id is optional. If not provided in the .h5 file, all genomes will default to the same class (which will probably not be used, depending on the model's training loop).

If you follow the information provided above for generating protein embeddings from FASTA files and converting these embeddings to the required graph format, your file format should be taken care of.
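
For completeness, here is a hedged sketch (pst graphify already produces this file for you; dtypes and values are illustrative) of writing a graph-format .h5 file with the required fields using pytables:

import numpy as np
import tables

proteins_per_genome = [3, 2]  # two genomes with 3 and 2 proteins (illustrative)
dim = 320                     # protein embedding dimension (illustrative)

data = np.random.rand(sum(proteins_per_genome), dim).astype(np.float32)  # stacked embeddings
ptr = np.cumsum([0] + proteins_per_genome)                               # [0, 3, 5]
sizes = np.asarray(proteins_per_genome)                                  # proteins per genome
strand = np.asarray([1, 1, -1, 1, -1])                                   # one value per protein

with tables.open_file("dataset.graphfmt.h5", "w") as fp:
    fp.create_array(fp.root, "data", data)
    fp.create_array(fp.root, "ptr", ptr)
    fp.create_array(fp.root, "sizes", sizes)
    fp.create_array(fp.root, "strand", strand)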

Finetuning vPST

Our model class ProteinSetTransformer is a lightning.LightningModule subclass from PyTorch Lightning.

We make use of Pydantic schema models as the config for our model, but you can load a pretrained model from a PyTorch checkpoint like this:

from pst.nn.modules import ProteinSetTransformer as PST

ckptfile = "pst-small_trained_model.ckpt"
model = PST.from_pretrained(ckptfile)

Similarly, to load a pretrained GenomeDatamodule, you can do the following. Note: the chunk_size for each genome graph was tuned in the pretrained vPSTs.

from pst.data.modules import GenomeDatamodule

ckptfile = "pst-small_trained_model.ckpt"
new_data_file = "new_data.graphfmt.h5"
datamodule = GenomeDatamodule.from_pretrained(ckptfile, new_data_file)

You still need to give the file location for the graph-formatted .h5 file.

TODO: describe the pst.GenomeDataModule class. <- this might be better in a pytorch lightning integration section

With the same genome-level triplet loss objective

TODO: We are adding a finetuning command line mode for this since this is basically the same as training a model but starting from a pretrained model.

With a new loss objective

We have provided a code example with this repository at examples/finetuning.ipynb. This covers the case of a new objective that focuses on either genome- or protein-level tasks.

In brief, you need to subclass either pst.BaseProteinSetTransformer for genome tasks (or dual genome/protein tasks) OR pst.BaseProteinSetTransformerEncoder for protein-only tasks.

The subclassed models must define the following methods:

  1. setup_objective - should return a callable that can be used to compute the loss
    • If the loss function requires tunable state, such as the margin and scaling factor of triplet loss, then a custom loss config and model config can be defined using pst.BaseLossConfig and pst.BaseModelConfig, respectively.
    • The custom loss config must be used to override the loss field in the custom model config.
  2. forward - which defines the model's forward pass, including the data handling and loss computation

Although optional, it is recommended to update the custom model's __init__ method to define any new trainable layers needed for the new model objective.
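
Below is a hedged skeleton of this subclassing pattern (method signatures and any field names beyond those mentioned above are illustrative; see examples/finetuning.ipynb for the working example):

import torch

from pst import BaseLossConfig, BaseModelConfig, BaseProteinSetTransformer

class MyLossConfig(BaseLossConfig):
    weight: float = 1.0  # example of tunable loss state (illustrative field)

class MyModelConfig(BaseModelConfig):
    loss: MyLossConfig   # override the loss field with the custom loss config

class MyGenomePST(BaseProteinSetTransformer):
    def __init__(self, config: MyModelConfig):
        super().__init__(config)
        # define any new trainable layers needed for the new objective here

    def setup_objective(self, **kwargs):
        # return a callable that computes the loss (illustrative choice)
        return torch.nn.MSELoss()

    def forward(self, batch):
        # handle the batch, call the encoder/decoder, and compute the loss
        ...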

Training a new (genomic) PST

All functionality to do this is embedded in the pst train command line mode.

To tune hyperparameters, there is also the pst tune command line mode that leverages a custom library we built on Lightning Fabric called Lightning CrossVal. This library enables epoch synchronized hyperparameter tuning through cross validation. The cross validation strategy can be defined using the lightning-cv framework.

TODO: better description of the training and hyper parameter tuning process.