Skip to content

Commit

Permalink
Adds FFCV support to trainer + restores old def file
Browse files: browse the repository at this point in the history
  • Loading branch information
ArneNx committed Feb 20, 2022
1 parent f32c21e commit 1ff9dc9
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 67 deletions.
33 changes: 33 additions & 0 deletions Singularity.v0.1.def
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Container recipe: starts from the sinzlab pytorch-singularity base image
# (pulled from ghcr.io via the ORAS bootstrap agent) and layers a few extra
# system and Python packages on top of it.
Bootstrap: oras
From: ghcr.io/sinzlab/pytorch-singularity:v3.9-torch1.10.2-dj0.12.7.def

%labels
MAINTAINER Arne Nix <[email protected]>

%post
# Build-time section: the commands below run once while the image is built.
# install third-party libraries
# needed for vim extension in jupyter and tex export in matplotlib:
apt update && apt install -y \
texlive-latex-extra \
texlive-fonts-recommended \
texlive-base \
dvipng \
zsh \
python3-venv

# Additional Python packages for the python3.9 interpreter; --no-cache-dir
# stops pip from keeping its download cache.
python3.9 -m pip --no-cache-dir install \
checkout_code \
requests \
imageio \
scikit-image \
einops \
vit-pytorch

%environment
# Point SHELL at the zsh binary installed in %post.
export SHELL=/usr/bin/zsh

%startscript
# Replace the script process with whatever command line was passed in.
exec "$@"

%runscript
# Same pass-through as %startscript: run the given arguments as the command.
exec "$@"
81 changes: 16 additions & 65 deletions Singularity.v0.2.def
Original file line number Diff line number Diff line change
@@ -1,80 +1,31 @@
Bootstrap: docker

From: continuumio/miniconda3

%files
environment.yml
Bootstrap: oras
From: ghcr.io/sinzlab/pytorch-singularity:v3.9-torch1.10.2-dj0.12.7.def

%labels
MAINTAINER Arne Nix <[email protected]>
%post
apt update && apt install -y \
build-essential \
git \
wget \
vim \
curl \
zip \
zlib1g-dev \
unzip \
pkg-config \
libblas-dev \
liblapack-dev \
python3-tk \
python3-wheel \
graphviz \
libhdf5-dev \
python3.9 \
python3.9-dev \
python3.9-distutils \
python3-testresources \
software-properties-common \
swig \
ffmpeg \
texlive-latex-extra \
texlive-fonts-recommended \
texlive-base \
dvipng
#zsh
apt-get clean
# install third-party libraries
# needed for vim extension in jupyter and tex export in matplotlib:
apt update && apt install -y libturbojpeg \
libturbojpeg-dev \
libopencv-dev \
python3-venv \
zsh \
python3-opencv
# texlive-latex-extra \
# texlive-fonts-recommended \
# texlive-base \
# dvipng \

ln -s /usr/bin/python3.9 /usr/local/bin/python
ln -s /usr/bin/python3.9 /usr/local/bin/python3

curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
python3.9 get-pip.py
rm get-pip.py
python3.9 -m pip --no-cache-dir install --upgrade pip

python3.9 -m pip --no-cache-dir install \
blackcellmagic\
pytest \
pytest-cov \
numpy \
matplotlib \
scipy \
pandas \
jupyter \
scikit-learn \
scikit-image \
seaborn \
graphviz \
gpustat \
h5py \
gitpython \
Pillow==8.0.1 \
jupyterlab \
datajoint==0.12.7\
ipykernel \
requests \
imageio \
scikit-image \
einops \
vit-pytorch

python3.9 -m pip --no-cache-dir install \
torch==1.10.2+cu113 \
torchvision==0.11.3+cu113 \
torchaudio==0.10.2+cu113 \
-f https://download.pytorch.org/whl/cu113/torch_stable.html

%runscript
exec "$@"
Expand Down
5 changes: 3 additions & 2 deletions nntransfer/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys
from contextlib import nullcontext
from functools import partial

from tqdm import tqdm
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(self, dataloaders, model, seed, uid, cb, **kwargs):
self.task_keys = dataloaders["train"].keys()
self.optimizer, self.stop_closure, self.criterion = self.get_training_controls()
self.lr_scheduler = self.prepare_lr_schedule()
if self.use_ffcv:
if self.config.use_ffcv:
self.scaler = GradScaler()

# Potentially reset parts of the model (after loading pretrained parameters)
Expand Down Expand Up @@ -215,7 +216,7 @@ def main_loop(
shared_memory = {} # e.g. to remember where which noise was applied
model_ = self.model

forward_context = autocast() if self.config.use_ffcv else nullcontext
forward_context = autocast() if self.config.use_ffcv else nullcontext()
with forward_context:
for module in self.main_loop_modules:
model_, inputs = module.pre_forward(
Expand Down

0 comments on commit 1ff9dc9

Please sign in to comment.