Commit

updated documentation and tests
ilsenatorov committed Aug 16, 2024
1 parent 8830bdb commit 9167693
Showing 12 changed files with 145 additions and 97 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_test.yaml
@@ -1,6 +1,6 @@
name: Pytest

on: [push, pull_request]
on: [push]

jobs:
build:
1 change: 0 additions & 1 deletion requirements.txt
@@ -3,7 +3,6 @@ lightning
torchmetrics
rich
fair-esm
jsonargparse
wandb
tokenizers
transformers
39 changes: 36 additions & 3 deletions smtb/data.py
@@ -7,14 +7,23 @@
from torch.utils.data import DataLoader, Dataset


def collate_fn(batch: list[tuple[torch.Tensor, float]]):
def collate_fn(batch: list[tuple[torch.Tensor, float]]) -> tuple[torch.Tensor, torch.Tensor]:
tensors = [item[0].squeeze(0) for item in batch]
floats = torch.tensor([item[1] for item in batch])
padded_sequences = pad_sequence(tensors, batch_first=True, padding_value=0)
return padded_sequences, floats


class DownstreamDataset(Dataset):
"""Dataset for downstream tasks. The `data_dir` is expected to have the following structure:
data_dir
├── df.csv
├── prot_0.pt
├── prot_1.pt
├── ...
└── prot_n.pt
"""

def __init__(self, data_dir: str | Path, layer_num: int):
self.data_dir = Path(data_dir)
self.layer_num = layer_num
@@ -29,24 +38,48 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, float]:
label = self.df.iloc[idx]["value"]
return embeddings, label

def __len__(self):
def __len__(self) -> int:
return self.df.shape[0]


class DownstreamDataModule(L.LightningDataModule):
"""DataModule for downstream tasks. The `data_dir` is expected to have the following structure:
data_dir
├── train
│ ├── df.csv
│ ├── prot_0.pt
│ ├── prot_1.pt
│ ├── ...
│ └── prot_n.pt
├── valid
│ ├── df.csv
│ ├── prot_0.pt
│ ├── prot_1.pt
│ ├── ...
│ └── prot_n.pt
└── test
├── df.csv
├── prot_0.pt
├── prot_1.pt
├── ...
└── prot_n.pt
"""

def __init__(self, data_dir: str | Path, layer_num: int, batch_size: int, num_workers: int = 8):
super().__init__()
self.data_dir = Path(data_dir)
self.layer_num = layer_num
self.batch_size = batch_size
self.num_workers = num_workers

def setup(self, stage=None):
def setup(self, stage: str | None = None):
"""Create train, val, test datasets."""
self.train = DownstreamDataset(self.data_dir / "train", self.layer_num)
self.valid = DownstreamDataset(self.data_dir / "valid", self.layer_num)
self.test = DownstreamDataset(self.data_dir / "test", self.layer_num)

def _get_dataloader(self, dataset: DownstreamDataset, shuffle: bool = False) -> torch.utils.data.DataLoader:
"""Create a DataLoader for a given dataset."""
return DataLoader(
dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=shuffle, collate_fn=collate_fn
)
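To make the documented layout concrete, here is a minimal usage sketch (not part of this diff). It assumes the ESM-style `{"representations": {layer: tensor}}` file format used in the test fixtures, that `df.csv` only needs the `value` column read by `__getitem__`, and that the usual Lightning `train_dataloader` hook wraps `_get_dataloader`:

```python
from pathlib import Path

import pandas as pd
import torch

from smtb.data import DownstreamDataModule

# Build the documented directory layout with two tiny "proteins" per split.
root = Path("demo_data")
for split in ("train", "valid", "test"):
    d = root / split
    d.mkdir(parents=True, exist_ok=True)
    pd.DataFrame({"value": [0.1, 0.7]}).to_csv(d / "df.csv", index=False)
    for i, seq_len in enumerate((5, 8)):
        # Same layout as the test fixtures: embeddings keyed by layer number.
        torch.save({"representations": {0: torch.rand(seq_len, 320)}}, d / f"prot_{i}.pt")

dm = DownstreamDataModule(root, layer_num=0, batch_size=2, num_workers=0)
dm.setup()
embeddings, labels = next(iter(dm.train_dataloader()))
print(embeddings.shape)  # torch.Size([2, 8, 320]): collate_fn pads to the longest sequence
print(labels.shape)      # torch.Size([2])
```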
6 changes: 5 additions & 1 deletion smtb/model.py
@@ -14,6 +14,10 @@


class BaseModel(pl.LightningModule):
"""Base model for downstream tasks. This class should be subclassed by specific models.
The `shared_step`, `forward` and `__init__` methods should be implemented in the subclasses.
"""

def __init__(self, config: Namespace):
super().__init__()
self.save_hyperparameters()
@@ -50,7 +54,7 @@ def __init__(self, config: Namespace):
super().__init__(config)
self.model = nn.Sequential(
nn.LazyLinear(config.hidden_dim),
poolings[config.pooling](config.hidden_dim),
poolings[config.pooling](config),
nn.LazyLinear(config.hidden_dim),
nn.ReLU(),
nn.Dropout(p=config.dropout),
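The new docstring names the subclassing contract; the sketch below (not part of this diff) shows one way to satisfy it. The `shared_step` signature and body are assumptions for illustration only:

```python
from argparse import Namespace

import torch
import torch.nn as nn
import torch.nn.functional as F

from smtb.model import BaseModel


class TinyRegressor(BaseModel):
    """Illustrative subclass implementing the three methods the docstring names."""

    def __init__(self, config: Namespace):
        super().__init__(config)
        self.net = nn.Sequential(nn.LazyLinear(config.hidden_dim), nn.ReLU(), nn.LazyLinear(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, seq_len, embedding_dim) -> (batch_size, 1)
        return self.net(x.mean(dim=1))

    def shared_step(self, batch, name: str = "train"):  # signature is an assumption
        x, y = batch
        loss = F.mse_loss(self.forward(x).squeeze(-1), y)
        self.log(f"{name}/loss", loss)
        return loss
```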
17 changes: 10 additions & 7 deletions smtb/pooling.py
@@ -1,10 +1,13 @@
from argparse import Namespace

import torch
import torch.nn as nn


class BasePooling(nn.Module):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, config: Namespace):
super().__init__()
self.config = config

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pools the input tensor.
@@ -20,11 +23,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class GlobalAttentionPooling(BasePooling):
def __init__(self, input_dim: int):
super(GlobalAttentionPooling, self).__init__()
def __init__(self, config: Namespace):
super().__init__(config)

self.linear_key = nn.Linear(input_dim, 1) # Project to a single dimension
self.linear_query = nn.Linear(input_dim, 1)
self.linear_key = nn.Linear(self.config.hidden_dim, 1) # Project to a single dimension
self.linear_query = nn.Linear(self.config.hidden_dim, 1)
self.softmax = nn.Softmax(dim=-1)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -37,8 +40,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class MeanPooling(BasePooling):
def __init__(self, *args, **kwargs) -> None:
super().__init__()
def __init__(self, config: Namespace):
super().__init__(config)

def forward(self, x):
return x.mean(dim=1)
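Every pooling now takes the full config, which is what makes `poolings[config.pooling](config)` in `model.py` work. A quick sketch of the unified constructor, mirroring the shape assertions in `test_pooling.py` below; note the input's last dimension must already equal `config.hidden_dim`, which is why `model.py` puts an `nn.LazyLinear(config.hidden_dim)` in front of the pooling:

```python
from argparse import Namespace

import torch

from smtb.pooling import GlobalAttentionPooling, MeanPooling

config = Namespace(hidden_dim=32)
x = torch.randn(4, 10, 32)  # (batch_size, seq_len, hidden_dim)

for cls in (MeanPooling, GlobalAttentionPooling):
    pooling = cls(config)  # unified signature: every pooling takes the config
    out = pooling(x)
    print(cls.__name__, tuple(out.shape))  # (4, 32): sequence dimension pooled away
```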
37 changes: 31 additions & 6 deletions smtb/tests/fixtures.py
@@ -1,12 +1,14 @@
import argparse
import random
from pathlib import Path

import pandas as pd
import pytest
import torch


@pytest.fixture
def mock_data_dir(tmp_path):
def mock_data_dir(tmp_path: Path):
# Create a temporary directory with mock data
data_dir = tmp_path / "data"
for ds in ["train", "valid", "test"]:
@@ -19,9 +21,32 @@ def mock_data_dir(tmp_path):

# Create mock .pt files
for i in range(len(df)):
torch.save(
{"representations": {0: torch.rand((random.randint(50, 60), 320))}}, ds_data_dir / f"prot_{i}.pt"
)
print(data_dir)

# seq_len between 50 and 60, embedding_dim=320
data = {"representations": {0: torch.rand((random.randint(50, 60), 320))}}
torch.save(data, ds_data_dir / f"prot_{i}.pt")
return data_dir


@pytest.fixture
def sample_batch_x():
"""Create a sample input tensor of shape (batch_size, seq_len, embedding_dim)"""
return torch.randn(4, 10, 32)  # Example: batch_size=4, seq_len=10, embedding_dim=32


@pytest.fixture
def sample_config():
"""Create a sample config object for training."""
return argparse.Namespace(
layer_num=0,
pooling="mean",
hidden_dim=32,
batch_size=2,
num_workers=0,
max_epoch=2,
dropout=0.2,
early_stopping_patience=20,
lr=0.001,
reduce_lr_patience=10,
reduce_lr_factor=0.1,
seed=42,
)
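For reference, a sketch (not part of this diff) of how one mock file round-trips; `DownstreamDataset` presumably selects the tensor by `layer_num` the same way:

```python
import torch

# Same layout as the fixture: embeddings keyed by layer number under "representations".
data = {"representations": {0: torch.rand(54, 320)}}  # seq_len=54, embedding_dim=320
torch.save(data, "prot_0.pt")

emb = torch.load("prot_0.pt")["representations"][0]  # index by layer_num
print(emb.shape)  # torch.Size([54, 320])
```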
10 changes: 4 additions & 6 deletions smtb/tests/test_data.py
@@ -1,11 +1,13 @@
from pathlib import Path

import pytest
import torch

from ..data import DownstreamDataModule, DownstreamDataset
from .fixtures import mock_data_dir


def test_downstream_dataset(mock_data_dir):
def test_downstream_dataset(mock_data_dir: Path):
dataset = DownstreamDataset(mock_data_dir / "train", layer_num=0)

# Check the length of the dataset
@@ -18,7 +20,7 @@ def test_downstream_dataset(mock_data_dir):
assert embeddings.size(1) == 320


def test_downstream_data_module(mock_data_dir):
def test_downstream_data_module(mock_data_dir: Path):
data_module = DownstreamDataModule(mock_data_dir, layer_num=0, batch_size=2, num_workers=0)
data_module.setup()

@@ -35,7 +37,3 @@ def test_downstream_data_module(mock_data_dir):
assert labels.size(0) == 2
assert embeddings.size(2) == 320
assert 50 <= embeddings.size(1) <= 60


if __name__ == "__main__":
pytest.main()
38 changes: 1 addition & 37 deletions smtb/tests/test_finetune.py
@@ -1,43 +1,7 @@
# parser = argparse.ArgumentParser()
# parser.add_argument("--dataset_path", type=str, required=True)
# parser.add_argument("--layer_num", type=int, required=True)
# parser.add_argument("--pooling", type=str, default="mean")
# parser.add_argument("--hidden_dim", type=int, default=512)
# parser.add_argument("--batch_size", type=int, default=1024)
# parser.add_argument("--num_workers", type=int, default=12)
# parser.add_argument("--max_epoch", type=int, default=1000)
# parser.add_argument("--dropout", type=float, default=0.2)
# parser.add_argument("--early_stopping_patience", type=int, default=20)
# parser.add_argument("--lr", type=float, default=0.001)
# parser.add_argument("--reduce_lr_patience", type=int, default=10)
# parser.add_argument("--reduce_lr_factor", type=float, default=0.1)
# parser.add_argument("--seed", type=int, default=42)

import argparse
import os

import pytest

from ..train import train
from .fixtures import mock_data_dir


@pytest.fixture
def sample_config():
return argparse.Namespace(
layer_num=0,
pooling="mean",
hidden_dim=32,
batch_size=2,
num_workers=0,
max_epoch=2,
dropout=0.2,
early_stopping_patience=20,
lr=0.001,
reduce_lr_patience=10,
reduce_lr_factor=0.1,
seed=42,
)
from .fixtures import mock_data_dir, sample_config


def test_train(sample_config, mock_data_dir):
22 changes: 4 additions & 18 deletions smtb/tests/test_model.py
@@ -4,24 +4,10 @@
import torch

from ..model import RegressionModel
from .fixtures import sample_batch_x, sample_config


@pytest.fixture
def sample_input():
# Create a sample input tensor of shape (batch_size, seq_len, embedding_dim)
return torch.randn(4, 10, 8) # Example: batch_size=4, seq_len=10, embedding_dim=8


@pytest.fixture
def sample_config():
return Namespace(hidden_dim=256, pooling="mean", dropout=0.3, lr=0.01, reduce_lr_patience=5)


def test_regression_model_forward(sample_config, sample_input):
def test_regression_model_forward(sample_config, sample_batch_x):
model = RegressionModel(sample_config)
output = model.forward(sample_input)
assert output.size(0) == 4


if __name__ == "__main__":
pytest.main()
output = model.forward(sample_batch_x)
assert output.size(0) == sample_batch_x.size(0)
22 changes: 8 additions & 14 deletions smtb/tests/test_pooling.py
@@ -1,24 +1,18 @@
from argparse import Namespace

import pytest
import torch

from ..model import poolings


@pytest.fixture
def sample_input():
# Create a sample input tensor of shape (batch_size, seq_len, embedding_dim)
return torch.randn(4, 10, 8) # Example: batch_size=4, seq_len=10, embedding_dim=8
from .fixtures import sample_batch_x, sample_config


@pytest.mark.parametrize("pooling_layer", [x for x in poolings.values()])
def test_pooling(sample_input, pooling_layer):
def test_pooling(sample_batch_x, sample_config, pooling_layer):
"""Test the pooling layer."""
# Create an instance of the pooling layer
pooling = pooling_layer(input_dim=sample_input.shape[-1])
pooling = pooling_layer(sample_config)

# Test the forward method
output = pooling(sample_input)
assert output.shape == (sample_input.shape[0], sample_input.shape[-1])


if __name__ == "__main__":
pytest.main()
output = pooling(sample_batch_x)
assert output.shape == (sample_batch_x.shape[0], sample_batch_x.shape[-1])
8 changes: 8 additions & 0 deletions smtb/tests/test_tokenizers.py
@@ -36,3 +36,11 @@ def test_train_tokenizer(tokenization_type, sample_dataset, tmp_path):
assert tokenizer.cls_token == "[CLS]"
assert tokenizer.sep_token == "[SEP]"
assert tokenizer.unk_token == "[UNK]"


def test_char_tokenizer(sample_dataset, tmp_path):
output_dir = tmp_path / "tokenization"
tokenizer = train_tokenizer(dataset=sample_dataset, tokenization_type="char", output_dir=output_dir)
for token in tokenizer.vocab.keys():
if token[0] != "[":
assert len(token) == 1, f"Token {token} is not a single character"