Skip to content

Commit

Permalink
Merge pull request #45 from atong01/torchdyn-upgrade
Browse files Browse the repository at this point in the history
Update torchdyn requirement and actions versions
  • Loading branch information
kilianFatras authored Sep 10, 2023
2 parents a0b5277 + 8a40d17 commit 3fd8d15
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 71 deletions.
5 changes: 1 addition & 4 deletions .github/workflows/code-quality-main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

name: Code Quality Main

env:
SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL: True

on:
push:
branches: [main]
Expand All @@ -19,7 +16,7 @@ jobs:
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: "3.10"

Expand Down
19 changes: 7 additions & 12 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
name: Tests

env:
SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL: True

on:
push:
branches: [main]
Expand All @@ -16,24 +13,22 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
os: [ubuntu-latest, ubuntu-20.04, macos-latest, windows-latest]
python-version: ["3.8", "3.9", "3.10"]

timeout-minutes: 10

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r runner-requirements.txt
pip install pytest
pip install sh
pip install -e .
Expand All @@ -44,7 +39,7 @@ jobs:
- name: Run pytest
run: |
pytest -v
pytest -v runner
# upload code coverage report
code-coverage:
Expand All @@ -55,21 +50,21 @@ jobs:
uses: actions/checkout@v3

- name: Set up Python 3.10
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r runner-requirements.txt
pip install pytest
pip install pytest-cov[toml]
pip install sh
pip install -e .
- name: Run tests and collect coverage
run: pytest --cov src # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER
run: pytest runner --cov runner # NEEDS TO BE UPDATED WHEN CHANGING THE NAME OF "src" FOLDER

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
10 changes: 7 additions & 3 deletions examples/cifar10/train_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from timm import scheduler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.conditional_flow_matching import (
ExactOptimalTransportConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper

savedir = "weights/reproduced/"
Expand Down Expand Up @@ -99,7 +100,10 @@
t_span=torch.linspace(0, 1, 100).to(device),
)
grid = make_grid(
traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10
traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1),
value_range=(-1, 1),
padding=0,
nrow=10,
)

img = grid.detach().cpu() / 2 + 0.5 # unnormalize
Expand Down
48 changes: 3 additions & 45 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,55 +1,13 @@
# Note if using Conda it is recommended to install torch separately.
# For most of testing the following commands were run to set up the environment
# This was tested with torch==1.12.1
# conda create -n ti-env python=3.10
# conda activate ti-env
# pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# pip install -r requirements.txt
# --------- pytorch --------- #
torch>=1.11.0
torchvision>=0.11.0
pytorch-lightning==1.8.3
torchmetrics==0.11.0

# --------- hydra --------- #
hydra-core==1.2.0
hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0
hydra-submitit-launcher

# --------- loggers --------- #
wandb
# neptune-client
# mlflow
# comet-ml

# --------- others --------- #
black
isort
flake8
Flake8-pyproject # for configuration via pyproject
pyrootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)


# --------- pkg reqs -------- #
lightning-bolts
matplotlib
numpy
# sdeint
torchdyn==1.0.3 # version 1.0.4 is broken
scipy
scikit-learn
pot
scprep
scanpy
lightning-bolts
timm

# --------- notebook reqs -------- #
seaborn>=0.12.2
pandas

git+https://github.com/patrick-kidger/torchcubicspline.git
torchdyn>=1.0.5 # 1.0.4 is broken on pypi
pot
52 changes: 52 additions & 0 deletions runner-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Note if using Conda it is recommended to install torch separately.
# For most of testing the following commands were run to set up the environment
# This was tested with torch==1.12.1
# conda create -n ti-env python=3.10
# conda activate ti-env
# pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# pip install -r requirements.txt
# --------- pytorch --------- #
torch>=1.11.0,<2.0.0
torchvision>=0.11.0
pytorch-lightning==1.8.3
torchmetrics==0.11.0

# --------- hydra --------- #
hydra-core==1.2.0
hydra-colorlog==1.2.0
hydra-optuna-sweeper==1.2.0
# hydra-submitit-launcher

# --------- loggers --------- #
wandb
# neptune-client
# mlflow
# comet-ml

# --------- others --------- #
black
isort
flake8
Flake8-pyproject # for configuration via pyproject
pyrootutils # standardizing the project root setup
pre-commit # hooks for applying linters on commit
rich # beautiful text formatting in terminal
pytest # tests
# sh # for running bash commands in some tests (linux/macos only)


# --------- pkg reqs -------- #
lightning-bolts
matplotlib
numpy
scipy
scikit-learn
scprep
scanpy
timm
torchdyn>=1.0.5 # 1.0.4 is broken on pypi
pot

# --------- notebook reqs -------- #
seaborn>=0.12.2
pandas
2 changes: 1 addition & 1 deletion runner/src/models/components/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def __init__(self, net, augmentation_list: nn.ModuleList, dim):
self.dim = dim
self.augmentation_list = augmentation_list

def forward(self, t, state, augmented_input=True):
def forward(self, t, state, augmented_input=True, *args, **kwargs):
n_aug = len(self.augmentation_list)

class SharedContext:
Expand Down
6 changes: 3 additions & 3 deletions runner/src/models/components/simple_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, dim: int, *args, **kwargs):
def energy(self, x):
return self.model(x)

def forward(self, t, x):
def forward(self, t, x, *args, **kwargs):
"""Ignore t run model."""
if t.dim() < 2:
t = t.repeat(x.shape[0])[:, None]
Expand All @@ -64,7 +64,7 @@ class TimeInvariantVelocityNet(SimpleDenseNet):
def __init__(self, dim: int, *args, **kwargs):
super().__init__(input_size=dim, target_size=dim, *args, **kwargs)

def forward(self, t, x):
def forward(self, t, x, *args, **kwargs):
"""Ignore t run model."""
del t
return self.model(x)
Expand All @@ -74,7 +74,7 @@ class VelocityNet(SimpleDenseNet):
def __init__(self, dim: int, *args, **kwargs):
super().__init__(input_size=dim + 1, target_size=dim, *args, **kwargs)

def forward(self, t, x):
def forward(self, t, x, *args, **kwargs):
"""Ignore t run model."""
if t.dim() < 1 or t.shape[0] != x.shape[0]:
t = t.repeat(x.shape[0])[:, None]
Expand Down
1 change: 0 additions & 1 deletion runner/src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from src.utils.pylogger import get_pylogger
from src.utils.rich_utils import enforce_tags, print_config_tree

from src.utils.utils import (
close_loggers,
extras,
Expand Down
1 change: 0 additions & 1 deletion runner/src/utils/rich_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.utilities import rank_zero_only
from rich.prompt import Prompt

from src.utils import pylogger

log = pylogger.get_pylogger(__name__)
Expand Down
1 change: 0 additions & 1 deletion runner/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pytorch_lightning import Callback
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only

from src.utils import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)
Expand Down

0 comments on commit 3fd8d15

Please sign in to comment.