Comments for smtb module
Old-Shatterhand committed Sep 29, 2024
1 parent 4e2df12 commit 0406e4f
Showing 5 changed files with 101 additions and 14 deletions.
18 changes: 18 additions & 0 deletions smtb/data.py
@@ -8,6 +8,17 @@


def collate_fn(batch: list[tuple[torch.Tensor, float]]) -> tuple[torch.Tensor, torch.Tensor]:
"""
Collate function for downstream tasks.
Args:
batch (list(tuple(torch.Tensor, float))): tuples where the first element is a tensor representing the
embeddings and the second element is a float representing the label.
Returns:
tuple(torch.Tensor, torch.Tensor): The first tensor is the padded embeddings and the second tensor is the
labels.
"""
tensors = [item[0].squeeze(0) for item in batch]
floats = torch.tensor([item[1] for item in batch])
padded_sequences = pad_sequence(tensors, batch_first=True, padding_value=0)
@@ -25,20 +36,23 @@ class DownstreamDataset(Dataset):
"""

def __init__(self, data_dir: str | Path, layer_num: int):
"""Initialize the DownstreamDataset."""
self.data_dir = Path(data_dir)
self.layer_num = layer_num
assert self.data_dir.exists(), f"{self.data_dir} does not exist."
assert self.data_dir.is_dir(), f"{self.data_dir} is not a directory."
self.df = pd.read_csv(self.data_dir / "df.csv").dropna()

def __getitem__(self, idx: int) -> tuple[torch.Tensor, float]:
"""Return the embeddings and label for a given index."""
embeddings = torch.load(self.data_dir / f"prot_{idx}.pt", weights_only=False)["representations"][
self.layer_num
]
label = self.df.iloc[idx]["value"]
return embeddings, label

def __len__(self) -> int:
"""Return the number of samples in the dataset."""
return self.df.shape[0]


@@ -66,6 +80,7 @@ class DownstreamDataModule(L.LightningDataModule):
"""

def __init__(self, data_dir: str | Path, layer_num: int, batch_size: int, num_workers: int = 8):
"""Initialize the DownstreamDataModule."""
super().__init__()
self.data_dir = Path(data_dir)
self.layer_num = layer_num
@@ -85,10 +100,13 @@ def _get_dataloader(self, dataset: DownstreamDataset, shuffle: bool = False) ->
)

def train_dataloader(self) -> DataLoader:
"""Return training dataloader."""
return self._get_dataloader(self.train, shuffle=True)

def val_dataloader(self) -> DataLoader:
"""Return validation dataloader."""
return self._get_dataloader(self.valid)

def test_dataloader(self) -> DataLoader:
"""Return test dataloader."""
return self._get_dataloader(self.test)
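
As a quick sanity check of the newly documented collate_fn, here is a minimal sketch; the toy shapes and the assumption that each stored embedding carries a leading batch dimension of 1 (suggested by the squeeze(0)) are illustrative, not part of the commit.

import torch

from smtb.data import collate_fn

# Two toy "protein embeddings" of different lengths, each with a leading
# batch dimension of 1, paired with a float label.
batch = [
    (torch.randn(1, 5, 32), 0.7),
    (torch.randn(1, 9, 32), 1.3),
]

padded, labels = collate_fn(batch)
print(padded.shape)  # torch.Size([2, 9, 32]) -- padded to the longest sequence
print(labels)        # tensor([0.7000, 1.3000])
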
26 changes: 24 additions & 2 deletions smtb/model.py
@@ -18,21 +18,31 @@ class BaseModel(pl.LightningModule):
The `shared_step`, `forward` and `__init__` methods should be implemented in the subclasses.
"""

def __init__(self, config: Namespace):
def __init__(self, config: Namespace) -> None:
"""Initialize the model."""
super().__init__()
self.save_hyperparameters()
self.config = config

def training_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""Training step."""
return self.shared_step(batch, "train")

def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""Validation step."""
return self.shared_step(batch, "val")

def test_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
"""Test step."""
return self.shared_step(batch, "test")

def configure_optimizers(self):
def configure_optimizers(self) -> tuple[list[optim.Optimizer], list[dict]]:
"""
Configure the optimisers and schedulers.
Returns:
tuple[list[optim.Optimizer], list[dict]]: The optimisers and schedulers.
"""
optimisers = [optim.Adam(self.parameters(), lr=self.config.lr)]
schedulers = [
{
@@ -51,6 +61,7 @@ def configure_optimizers(self):

class RegressionModel(BaseModel):
def __init__(self, config: Namespace):
"""Regression model for downstream tasks."""
super().__init__(config)
self.model = nn.Sequential(
nn.LazyLinear(config.hidden_dim),
@@ -62,9 +73,20 @@ def __init__(self, config: Namespace):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass."""
return self.model(x).squeeze(1)

def shared_step(self, batch: tuple[torch.Tensor, torch.Tensor], name: str = "train") -> torch.Tensor:
"""
Shared step for training, validation and testing.
Args:
batch (tuple[torch.Tensor, torch.Tensor]): A tuple containing the input and output tensors.
name (str, optional): The name of the step. Defaults to "train".
Returns:
torch.Tensor: The loss value.
"""
x, y = batch
y_pred = self.forward(x).float()
y = y.float()
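
The tuple returned by configure_optimizers follows Lightning's optimiser/scheduler contract: a list of optimisers and a list of scheduler dicts. Below is a self-contained sketch of that return shape; the scheduler choice, monitored metric and interval are assumptions, since the commit truncates the scheduler dict.

from argparse import Namespace

import torch.nn as nn
import torch.optim as optim

config = Namespace(lr=1e-3)
model = nn.Linear(16, 1)  # stand-in for the LightningModule's parameters

optimisers = [optim.Adam(model.parameters(), lr=config.lr)]
schedulers = [
    {
        # Scheduler, monitored metric and interval are assumptions.
        "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimisers[0], patience=5),
        "monitor": "val/loss",
        "interval": "epoch",
    }
]
# configure_optimizers() returns exactly such a pair: (optimisers, schedulers).
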
25 changes: 21 additions & 4 deletions smtb/pooling.py
@@ -5,7 +5,7 @@


class BasePooling(nn.Module):
def __init__(self, config: Namespace):
def __init__(self, config: Namespace) -> None:
super().__init__()
self.config = config

@@ -23,14 +23,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class GlobalAttentionPooling(BasePooling):
def __init__(self, config: Namespace):
def __init__(self, config: Namespace) -> None:
"""
Global attention pooling layer.
Args:
config (Namespace): Configuration namespace
"""
super().__init__(config)

self.linear_key = nn.Linear(self.config.hidden_dim, 1) # Project to a single dimension
self.linear_query = nn.Linear(self.config.hidden_dim, 1)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Pooling operation.
Args:
x (torch.Tensor): Tensor of size (batch_size, seq_len, embedding_dim)
Returns:
torch.Tensor: Tensor of size (batch_size, embedding_dim)
"""
keys = self.linear_key(x)
queries = self.linear_query(x)
attention_scores = torch.matmul(queries, keys.transpose(-2, -1))
@@ -40,8 +55,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class MeanPooling(BasePooling):
def __init__(self, config: Namespace):
def __init__(self, config: Namespace) -> None:
"""Mean pooling layer."""
super().__init__(config)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pooling operation."""
return x.mean(dim=1)
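
A small shape check of the two pooling layers against their documented contract, (batch_size, seq_len, embedding_dim) -> (batch_size, embedding_dim); the concrete sizes are illustrative, and embedding_dim must equal config.hidden_dim for the attention variant's linear layers.

from argparse import Namespace

import torch

from smtb.pooling import GlobalAttentionPooling, MeanPooling

config = Namespace(hidden_dim=32)
x = torch.randn(4, 10, 32)  # (batch_size, seq_len, embedding_dim)

print(MeanPooling(config)(x).shape)             # torch.Size([4, 32])
print(GlobalAttentionPooling(config)(x).shape)  # torch.Size([4, 32]), per the docstring
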
41 changes: 35 additions & 6 deletions smtb/tokenization.py
@@ -3,21 +3,34 @@

import transformers
from tokenizers import Tokenizer
from tokenizers.models import BPE, Unigram, WordPiece
from tokenizers.models import BPE, Unigram, WordPiece, Model
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordPieceTrainer
from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordPieceTrainer, Trainer
from transformers import PreTrainedTokenizerFast

TOKENIZATION_TYPES = Literal["bpe", "wordpiece", "unigram", "char"]


def _get_tokenizer(
model,
trainer,
model: Model,
trainer: Trainer,
vocab_size: int,
model_kwargs: dict | None = None,
trainer_kwargs: dict | None = None,
):
) -> tuple[Tokenizer, Trainer]:
"""
Helper function to get tokenizer and trainer objects.
Args:
model (Model): model object to be initialized
trainer (Trainer): trainer object to be initialized
vocab_size (int): How many tokens to learn
model_kwargs (dict | None): Arguments to be passed to the model
trainer_kwargs (dict | None): Arguments to be passed to the trainer
Returns:
Initialized tokenizer and trainer objects
"""
if model_kwargs is None:
model_kwargs = dict(unk_token="[UNK]")
if trainer_kwargs is None:
@@ -34,6 +47,18 @@ def train_tokenizer(
output_dir: str | Path = "data/tokenization",
vocab_size: int = 5000,
) -> transformers.PreTrainedTokenizerFast:
"""
Train a tokenizer on a given dataset.
Args:
dataset (Iterable[str]): Dataset to train the tokenizer on
tokenization_type (TOKENIZATION_TYPES): Type of tokenization to use
output_dir (str | Path): Directory to save the tokenizer
vocab_size (int): How many tokens to learn
Returns:
Trained tokenizer
"""
if not isinstance(output_dir, Path):
output_dir = Path(output_dir)
if not output_dir.exists():
@@ -80,11 +105,15 @@ def train_tokenizer(
unk_token="[UNK]",
)
)
print(tokenizer)
# Train the tokenizer unless it is the char-level tokenizer, which only uses amino acids as tokens
if tokenization_type != "char":
tokenizer.train_from_iterator(iterator=dataset, trainer=trainer)

# save the tokenizer ...
tokenizer_file = str(output_dir / f"{tokenization_type}.json")
tokenizer.save(tokenizer_file)

# ... and load it into a PreTrainedTokenizerFast object (from transformers package)
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
tokenizer.add_special_tokens(
{
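
A hypothetical call to the documented train_tokenizer; the toy sequences, output directory and vocabulary size are illustrative only.

from smtb.tokenization import train_tokenizer

# Toy protein sequences; a real run would iterate over a full dataset.
sequences = [
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
    "GAVLIPFWMGSTCY",
    "MSTNPKPQRKTKRNTNRRPQDVKFPGG",
]

tokenizer = train_tokenizer(
    sequences,
    tokenization_type="bpe",
    output_dir="data/tokenization",
    vocab_size=200,
)
print(tokenizer.tokenize("MKTAYIAKQR"))
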
5 changes: 3 additions & 2 deletions smtb/train.py
@@ -1,4 +1,4 @@
import argparse
from argparse import Namespace
from pathlib import Path

import wandb
@@ -10,7 +10,8 @@
from smtb.model import RegressionModel


def train(config: argparse.Namespace):
def train(config: Namespace) -> None:
"""Train the model."""
dataset_path = Path(config.dataset_path)
seed_everything(config.seed)
logger = WandbLogger()
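
A hypothetical driver for the retyped train(); only dataset_path and seed are visible in this hunk, so the remaining fields are guesses mirroring what data.py and model.py consume, and the WandbLogger requires a wandb login.

from argparse import Namespace

from smtb.train import train

config = Namespace(
    dataset_path="data/fluorescence",  # illustrative path
    seed=42,
    layer_num=6,
    batch_size=32,
    num_workers=8,
    lr=1e-3,
    hidden_dim=128,
)
train(config)
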
