diff --git a/smtb/data.py b/smtb/data.py index d80f604..743c199 100644 --- a/smtb/data.py +++ b/smtb/data.py @@ -8,6 +8,17 @@ def collate_fn(batch: list[tuple[torch.Tensor, float]]) -> tuple[torch.Tensor, torch.Tensor]: + """ + Collate function for downstream tasks. + + Args: + batch (list(tuple(torch.Tensor, float))): tuples where the first element is a tensor representing the + embeddings and the second element is a float representing the label. + + Returns: + tuple(torch.Tensor, torch.Tensor): The first tensor is the padded embeddings and the second tensor is the + labels. + """ tensors = [item[0].squeeze(0) for item in batch] floats = torch.tensor([item[1] for item in batch]) padded_sequences = pad_sequence(tensors, batch_first=True, padding_value=0) @@ -25,6 +36,7 @@ class DownstreamDataset(Dataset): """ def __init__(self, data_dir: str | Path, layer_num: int): + """Initialize the DownstreamDataset.""" self.data_dir = Path(data_dir) self.layer_num = layer_num assert self.data_dir.exists(), f"{self.data_dir} does not exist." @@ -32,6 +44,7 @@ def __init__(self, data_dir: str | Path, layer_num: int): self.df = pd.read_csv(data_dir / "df.csv").dropna() def __getitem__(self, idx: int) -> tuple[torch.Tensor, float]: + """Return the embeddings and label for a given index.""" embeddings = torch.load(self.data_dir / f"prot_{idx}.pt", weights_only=False)["representations"][ self.layer_num ] @@ -39,6 +52,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, float]: return embeddings, label def __len__(self) -> int: + """Return the number of samples in the dataset.""" return self.df.shape[0] @@ -66,6 +80,7 @@ class DownstreamDataModule(L.LightningDataModule): """ def __init__(self, data_dir: str | Path, layer_num: int, batch_size: int, num_workers: int = 8): + """Initialize the DownstreamDataModule.""" super().__init__() self.data_dir = Path(data_dir) self.layer_num = layer_num @@ -85,10 +100,13 @@ def _get_dataloader(self, dataset: DownstreamDataset, shuffle: bool = False) -> ) def train_dataloader(self) -> DataLoader: + """Return training dataloader.""" return self._get_dataloader(self.train, shuffle=True) def val_dataloader(self) -> DataLoader: + """Return validation dataloader.""" return self._get_dataloader(self.valid) def test_dataloader(self) -> DataLoader: + """Return test dataloader.""" return self._get_dataloader(self.test) diff --git a/smtb/model.py b/smtb/model.py index d7483f9..7d0c3ef 100644 --- a/smtb/model.py +++ b/smtb/model.py @@ -18,21 +18,31 @@ class BaseModel(pl.LightningModule): The `shared_step`, `forward` and `__init__` methods should be implemented in the subclasses. """ - def __init__(self, config: Namespace): + def __init__(self, config: Namespace) -> None: + """Initialize the model.""" super().__init__() self.save_hyperparameters() self.config = config def training_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Training step.""" return self.shared_step(batch, "train") def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Validation step.""" return self.shared_step(batch, "val") def test_step(self, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Test step.""" return self.shared_step(batch, "test") - def configure_optimizers(self): + def configure_optimizers(self) -> tuple[list[optim.Optimizer], list[dict]]: + """ + Configure the optimisers and schedulers. + + Returns: + tuple[list[optim.Optimizer], list[dict]]: The optimisers and schedulers. + """ optimisers = [optim.Adam(self.parameters(), lr=self.config.lr)] schedulers = [ { @@ -51,6 +61,7 @@ def configure_optimizers(self): class RegressionModel(BaseModel): def __init__(self, config: Namespace): + """Regression model for downstream tasks.""" super().__init__(config) self.model = nn.Sequential( nn.LazyLinear(config.hidden_dim), @@ -62,9 +73,20 @@ def __init__(self, config: Namespace): ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" return self.model(x).squeeze(1) def shared_step(self, batch: tuple[torch.Tensor, torch.Tensor], name: str = "train") -> torch.Tensor: + """ + Shared step for training, validation and testing. + + Args: + batch (tuple[torch.Tensor, torch.Tensor]): A tuple containing the input and output tensors. + name (str, optional): The name of the step. Defaults to "train". + + Returns: + torch.Tensor: The loss value. + """ x, y = batch y_pred = self.forward(x).float() y = y.float() diff --git a/smtb/pooling.py b/smtb/pooling.py index a462290..c557bca 100644 --- a/smtb/pooling.py +++ b/smtb/pooling.py @@ -5,7 +5,7 @@ class BasePooling(nn.Module): - def __init__(self, config: Namespace): + def __init__(self, config: Namespace) -> None: super().__init__() self.config = config @@ -23,7 +23,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GlobalAttentionPooling(BasePooling): - def __init__(self, config: Namespace): + def __init__(self, config: Namespace) -> None: + """ + Global attention pooling layer. + + Args: + config (Namespace): Configuration namespace + """ super().__init__(config) self.linear_key = nn.Linear(self.config.hidden_dim, 1) # Project to a single dimension @@ -31,6 +37,15 @@ def __init__(self, config: Namespace): self.softmax = nn.Softmax(dim=-1) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Pooling operation. + + Args: + x (torch.Tensor): Tensor of size (batch_size, seq_len, embedding_dim) + + Returns: + torch.Tensor: Tensor of size (batch_size, embedding_dim) + """ keys = self.linear_key(x) queries = self.linear_query(x) attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) @@ -40,8 +55,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MeanPooling(BasePooling): - def __init__(self, config: Namespace): + def __init__(self, config: Namespace) -> None: + """Mean pooling layer.""" super().__init__(config) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Pooling operation.""" return x.mean(dim=1) diff --git a/smtb/tokenization.py b/smtb/tokenization.py index a4af5b0..39af700 100644 --- a/smtb/tokenization.py +++ b/smtb/tokenization.py @@ -3,21 +3,34 @@ import transformers from tokenizers import Tokenizer -from tokenizers.models import BPE, Unigram, WordPiece +from tokenizers.models import BPE, Unigram, WordPiece, Model from tokenizers.pre_tokenizers import Whitespace -from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordPieceTrainer +from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordPieceTrainer, Trainer from transformers import PreTrainedTokenizerFast TOKENIZATION_TYPES = Literal["bpe", "wordpiece", "unigram", "char"] def _get_tokenizer( - model, - trainer, + model: Model, + trainer: Trainer, vocab_size: int, model_kwargs: dict | None = None, trainer_kwargs: dict | None = None, -): +) -> tuple[Tokenizer, Trainer]: + """ + Helper function to get tokenizer and trainer objects. + + Args: + model (Model): model object to be initialized + trainer (Trainer): trainer object to be initialized + vocab_size (int): How many tokens to learn + model_kwargs (dict | None): Arguments to be passed to the model + trainer_kwargs (dict | None): Arguments to be passed to the trainer + + Return: + Initialized tokenizer and trainer objects + """ if model_kwargs is None: model_kwargs = dict(unk_token="[UNK]") if trainer_kwargs is None: @@ -34,6 +47,18 @@ def train_tokenizer( output_dir: str | Path = "data/tokenization", vocab_size: int = 5000, ) -> transformers.PreTrainedTokenizerFast: + """ + Train a tokenizer on a given dataset. + + Args: + dataset (Iterable[str]): Dataset to train the tokenizer on + tokenization_type (TOKENIZER_TYPES): Type of tokenization to use + output_dir (str | Path): Directory to save the tokenizer + vocab_size (int): How many tokens to learn + + Return: + Trained tokenizer + """ if not isinstance(output_dir, Path): output_dir = Path(output_dir) if not output_dir.exists(): @@ -80,11 +105,15 @@ def train_tokenizer( unk_token="[UNK]", ) ) - print(tokenizer) + # Train the tokenizer if it's not a char-level tokenizer that should only use aminoacids as tokens if tokenization_type != "char": tokenizer.train_from_iterator(iterator=dataset, trainer=trainer) + + # save the tokenizer ... tokenizer_file = str(output_dir / f"{tokenization_type}.json") tokenizer.save(tokenizer_file) + + # ... and load it into a PreTrainedTokenizerFast object (from transformers package) tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) tokenizer.add_special_tokens( { diff --git a/smtb/train.py b/smtb/train.py index 40aa885..66eda76 100644 --- a/smtb/train.py +++ b/smtb/train.py @@ -1,4 +1,4 @@ -import argparse +from argparse import Namespace from pathlib import Path import wandb @@ -10,7 +10,8 @@ from smtb.model import RegressionModel -def train(config: argparse.Namespace): +def train(config: Namespace) -> None: + """Train the model.""" dataset_path = Path(config.dataset_path) seed_everything(config.seed) logger = WandbLogger()