Skip to content

Commit

Permalink
Fix orientation batching (#58)
Browse files Browse the repository at this point in the history
* Fix orientation batching

* Vectorize orientations feature computation

* Comment out debugging code

* Simplify code

* update changelog

---------

Co-authored-by: Arian Jamasb <[email protected]>
  • Loading branch information
amorehead and a-r-j authored Dec 29, 2023
1 parent 14e33c8 commit 99c6f81
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
### Features

* Improves positional encoding performance by adding a `seq_pos` attribute on `Data/Protein` objects in the base dataset getter. [#53](https://github.com/a-r-j/ProteinWorkshop/pull/53/)
* Ensure correct batched computation of orientation features. [#58](https://github.com/a-r-j/ProteinWorkshop/pull/58/)

### Models

Expand Down
3 changes: 1 addition & 2 deletions proteinworkshop/config/visualise.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# === 1. Set config parameters ===
name: "" # default name for the experiment, "" means logger (eg. wandb) will generate a unique name
seed: 52 # seed for random number generators in pytorch, numpy and python.random
seed: 52 # seed for random number generators in pytorch, numpy and python.random (as well as in UMAP)
num_workers: 16 # number of subprocesses to use for data loading.

# === 2. Specify defaults here. Defaults will be overwritten by equivalently named options in this file ===
Expand All @@ -29,7 +29,6 @@ compile: True
# simply provide checkpoint path and plot filepath to embed dataset and plot its UMAP embeddings
ckpt_path: null # path to checkpoint to load
plot_filepath: null # path to which to save embeddings plot
seed: 42 # random seed to be used by the UMAP algorithm
use_cuda_device: True # if True, use an available CUDA device for embedding generation
cuda_device_index: 0 # if CUDA devices are targeted and available, which available CUDA device to use for embedding generation

Expand Down
2 changes: 1 addition & 1 deletion proteinworkshop/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,9 @@ def create_example_batch(n: int = 4) -> ProteinBatch:
batch.pos = batch.coords[:, 1, :]
batch.x = F.one_hot(batch.residue_type, num_classes=23).float()

batch.x_vector_attr = orientations(batch.pos, batch._slice_dict["coords"])
batch.graph_y = torch.randint(0, 2, (n, 1))

batch.x_vector_attr = orientations(batch.pos)
batch.edge_attr = pos_emb(batch.edge_index, 9)
batch.edge_vector_attr = _normalize(
batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]]
Expand Down
44 changes: 39 additions & 5 deletions proteinworkshop/features/node_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def compute_vector_node_features(
vector_node_features = []
for feature in vector_features:
if feature == "orientation":
vector_node_features.append(orientations(x.coords))
vector_node_features.append(orientations(x.coords, x._slice_dict["coords"]))
elif feature == "virtual_cb_vector":
raise NotImplementedError("Virtual CB vector not implemented yet.")
else:
Expand Down Expand Up @@ -149,12 +149,46 @@ def compute_surface_feat(

@jaxtyped(typechecker=typechecker)
def orientations(
X: Union[CoordTensor, AtomTensor], ca_idx: int = 1
X: Union[CoordTensor, AtomTensor], coords_slice_index: torch.Tensor, ca_idx: int = 1
) -> OrientationTensor:
if X.ndim == 3:
X = X[:, ca_idx, :]
forward = _normalize(X[1:] - X[:-1])
backward = _normalize(X[:-1] - X[1:])

# NOTE: the first item in the coordinates slice index is always 0,
# and the last item is always the node count of the batch
batch_num_nodes = X.shape[0]
slice_index = coords_slice_index[1:] - 1
last_node_index = slice_index[:-1]
first_node_index = slice_index[:-1] + 1
slice_mask = torch.zeros(batch_num_nodes - 1, dtype=torch.bool)
last_node_forward_slice_mask = slice_mask.clone()
first_node_backward_slice_mask = slice_mask.clone()

# NOTE: all of the last (first) nodes in a subgraph have their
# forward (backward) vectors set to a padding value (i.e., 0.0)
# to mimic feature construction behavior with single input graphs
forward_slice = X[1:] - X[:-1]
backward_slice = X[:-1] - X[1:]
last_node_forward_slice_mask[last_node_index] = True
first_node_backward_slice_mask[first_node_index - 1] = True # NOTE: for the backward slices, our indexing defaults to node index `1`
forward_slice[last_node_forward_slice_mask] = 0.0 # NOTE: this handles all but the last node in the last subgraph
backward_slice[first_node_backward_slice_mask] = 0.0 # NOTE: this handles all but the first node in the first subgraph

# NOTE: padding first and last nodes with zero vectors does not impact feature normalization
forward = _normalize(forward_slice)
backward = _normalize(backward_slice)
forward = F.pad(forward, [0, 0, 0, 1])
backward = F.pad(backward, [0, 0, 1, 0])
return torch.cat((forward.unsqueeze(-2), backward.unsqueeze(-2)), dim=-2)
orientations = torch.cat((forward.unsqueeze(-2), backward.unsqueeze(-2)), dim=-2)

# optionally debug/verify the orientations
# last_node_indices = torch.cat((last_node_index, torch.tensor([batch_num_nodes - 1])), dim=0)
# first_node_indices = torch.cat((torch.tensor([0]), first_node_index), dim=0)
# intermediate_node_indices_mask = torch.ones(batch_num_nodes, device=X.device, dtype=torch.bool)
# intermediate_node_indices_mask[last_node_indices] = False
# intermediate_node_indices_mask[first_node_indices] = False
# assert not orientations[last_node_indices][:, 0].any() and orientations[last_node_indices][:, 1].any()
# assert orientations[first_node_indices][:, 0].any() and not orientations[first_node_indices][:, 1].any()
# assert orientations[intermediate_node_indices_mask][:, 0].any() and orientations[intermediate_node_indices_mask][:, 1].any()

return orientations
1 change: 0 additions & 1 deletion proteinworkshop/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import torch.nn as nn
import torch_geometric
from graphein.protein.tensor.dataloader import ProteinDataLoader
from graphein.ml.datasets.foldcomp_dataset import FoldCompLightningDataModule
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import Logger
from loguru import logger as log
Expand Down

0 comments on commit 99c6f81

Please sign in to comment.