From 4e2df12166e98e5cee7eec7748dd260634ee3d60 Mon Sep 17 00:00:00 2001 From: Roman Joeres Date: Sun, 29 Sep 2024 20:41:07 +0200 Subject: [PATCH] Refactoring of tests and restructuring of requirements --- .github/workflows/run_test.yaml | 12 +++++++++++- requirements.txt | 12 ++++++------ {smtb/tests => tests}/__init__.py | 0 {smtb/tests => tests}/fixtures.py | 0 {smtb/tests => tests}/test_data.py | 2 +- {smtb/tests => tests}/test_finetune.py | 2 +- {smtb/tests => tests}/test_model.py | 2 +- {smtb/tests => tests}/test_pooling.py | 2 +- {smtb/tests => tests}/test_tokenizers.py | 2 +- 9 files changed, 22 insertions(+), 12 deletions(-) rename {smtb/tests => tests}/__init__.py (100%) rename {smtb/tests => tests}/fixtures.py (100%) rename {smtb/tests => tests}/test_data.py (94%) rename {smtb/tests => tests}/test_finetune.py (89%) rename {smtb/tests => tests}/test_model.py (89%) rename {smtb/tests => tests}/test_pooling.py (94%) rename {smtb/tests => tests}/test_tokenizers.py (96%) diff --git a/.github/workflows/run_test.yaml b/.github/workflows/run_test.yaml index a0ee7bb..142bc14 100644 --- a/.github/workflows/run_test.yaml +++ b/.github/workflows/run_test.yaml @@ -1,6 +1,16 @@ name: Pytest -on: [push] +on: + # Triggers the workflow on push or pull request events but only for the main branch + push: + branches: + - main + - dev + pull_request: + branches: + - main + - dev + workflow_dispatch: # make is manually start-able jobs: build: diff --git a/requirements.txt b/requirements.txt index 0072685..0257484 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ -torch -lightning -torchmetrics +pandas rich fair-esm -wandb -tokenizers -transformers beartype +wandb +torch +lightning +torchmetrics datasets +tokenizers transformers[torch] pytest pytest-cov diff --git a/smtb/tests/__init__.py b/tests/__init__.py similarity index 100% rename from smtb/tests/__init__.py rename to tests/__init__.py diff --git a/smtb/tests/fixtures.py b/tests/fixtures.py similarity index 100% rename from smtb/tests/fixtures.py rename to tests/fixtures.py diff --git a/smtb/tests/test_data.py b/tests/test_data.py similarity index 94% rename from smtb/tests/test_data.py rename to tests/test_data.py index ed021aa..bb1a37a 100644 --- a/smtb/tests/test_data.py +++ b/tests/test_data.py @@ -3,7 +3,7 @@ import pytest import torch -from ..data import DownstreamDataModule, DownstreamDataset +from smtb.data import DownstreamDataModule, DownstreamDataset from .fixtures import mock_data_dir diff --git a/smtb/tests/test_finetune.py b/tests/test_finetune.py similarity index 89% rename from smtb/tests/test_finetune.py rename to tests/test_finetune.py index 67b3fed..b4d0517 100644 --- a/smtb/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -1,6 +1,6 @@ import os -from ..train import train +from smtb.train import train from .fixtures import mock_data_dir, sample_config diff --git a/smtb/tests/test_model.py b/tests/test_model.py similarity index 89% rename from smtb/tests/test_model.py rename to tests/test_model.py index 905a41b..d6b09fe 100644 --- a/smtb/tests/test_model.py +++ b/tests/test_model.py @@ -3,7 +3,7 @@ import pytest import torch -from ..model import RegressionModel +from smtb.model import RegressionModel from .fixtures import sample_batch_x, sample_config diff --git a/smtb/tests/test_pooling.py b/tests/test_pooling.py similarity index 94% rename from smtb/tests/test_pooling.py rename to tests/test_pooling.py index 3b6ae75..8454506 100644 --- a/smtb/tests/test_pooling.py +++ b/tests/test_pooling.py @@ -3,7 +3,7 @@ import pytest import torch -from ..model import poolings +from smtb.model import poolings from .fixtures import sample_batch_x, sample_config diff --git a/smtb/tests/test_tokenizers.py b/tests/test_tokenizers.py similarity index 96% rename from smtb/tests/test_tokenizers.py rename to tests/test_tokenizers.py index 415df67..d0d142b 100644 --- a/smtb/tests/test_tokenizers.py +++ b/tests/test_tokenizers.py @@ -4,7 +4,7 @@ from datasets import Dataset, DatasetDict from transformers import PreTrainedTokenizerFast -from ..tokenization import TOKENIZATION_TYPES, train_tokenizer +from smtb.tokenization import TOKENIZATION_TYPES, train_tokenizer def tokenization_types():