
Add SwaV self-supervision method #197

Open
wants to merge 12 commits into main
2 changes: 1 addition & 1 deletion .github/workflows/cd.yml
@@ -62,7 +62,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9", "3.10"]
steps:
#----------------------------------------------
# check-out repo and set-up python
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -59,7 +59,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.8", "3.9", "3.10"]
steps:
#----------------------------------------------
# check-out repo and set-up python
44 changes: 22 additions & 22 deletions environment.yml
@@ -3,29 +3,29 @@ channels:
- pytorch
- conda-forge
- defaults
- nvidia
dependencies:
- python=3.9
- pip=22.3.1
- numpy=1.21.*
- python=3.10
- pip==23.1.1
- faiss-cpu=1.7.2
- mlflow=1.27.0
- pandas=1.4.3
- pillow=9.1.1
- pytorch=1.12.1
- torchvision=0.13.1
- cudatoolkit=11.6
- parameterized=0.8.1
- mlflow=2.3.0
- pytorch=2.0
- torchvision
- pytorch-cuda=11.8
- parameterized=0.9.0
- pytorch-lightning=2.0.2
- torchmetrics=0.11.4
- pip:
- pytorch-lightning==1.9.0
- torchmetrics==0.11.0
- onnx==1.12.0
- onnxruntime-gpu==1.11.1
- lightly==1.4.3
- opencv-python==4.7.0.72
- onnx==1.13.1
- onnxruntime-gpu==1.14.1
- albumentations==1.3.0
- hydra-core==1.2.0
- opencv-python==4.6.0.*
- ranx==0.2.8
- timm==0.6.12
- mmdet==2.26.0
- pymysql==1.0.2
# install mmcv-full via
# `pip install mmcv-full==1.6.2 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12/index.html`
- hydra-core==1.3.2
- ranx==0.3.7
- timm==0.6.13
- pymysql==1.0.3
- imagesize
- mmdet==3.0.0
# install mmcv via
# `pip install mmcv==2.0.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.0/index.html`
4,681 changes: 2,635 additions & 2,046 deletions poetry.lock

Large diffs are not rendered by default.

31 changes: 15 additions & 16 deletions pyproject.toml
@@ -26,27 +26,26 @@ albumentations = "1.3.0"
faiss-cpu = "1.7.2"
hydra-core = "~1.2"
imagesize = "1.4.1"
lightly = "~1.4"
mmdet = "3.0.0"
numpy = [
{version = "~1.22", python = ">=3.8,<3.11"},
{version = "~1.21", python = ">=3.7.1,<3.8"}
{version = "~1.22", python = ">=3.8,<3.11"}
]
onnx = "~1.12"
onnxruntime-gpu = "~1.12"
opencv-python = "~4.6"
onnx = "~1.13"
onnxruntime-gpu = "~1.14"
opencv-python = "~4.7"
pandas = [
{version = "~1.4", python = ">=3.8,<3.11"},
{version = "~1.3", python = ">=3.7.1,<3.8"}
{version = "~1.4", python = ">=3.8,<3.11"}
]
parameterized = "~0.8"
pillow = "~9.1"
python = ">=3.7,<3.11"
pytorch-lightning = "1.9.0"
ranx = "~0.2"
torch = "1.12.1"
torchvision = "0.13.1"
torchmetrics = "~0.11"
parameterized = "~0.9"
python = ">=3.8,<3.11"
pytorch-lightning = "~2.0"
ranx = "~0.3"
seaborn = "^0.12.2"
timm = "~0.6"
mmdet = "2.26.0"
torch = "~2.0"
torchvision = "~0.15"
torchmetrics = "~0.11"

[build-system]
requires = ["poetry-core>=1.0.0"]
@@ -81,7 +81,7 @@ def run_model(metric_params: List[MetricParams]):
trainer = Trainer(
accelerator="cpu",
strategy="ddp",
num_processes=2,
num_nodes=1,
max_epochs=EPOCH,
)

3 changes: 1 addition & 2 deletions torchok/callbacks/__init__.py
@@ -1,6 +1,6 @@
from pytorch_lightning.callbacks import (BackboneFinetuning, DeviceStatsMonitor, EarlyStopping,
GradientAccumulationScheduler, LearningRateMonitor, ModelCheckpoint,
ModelPruning, ModelSummary, QuantizationAwareTraining, RichModelSummary,
ModelPruning, ModelSummary, RichModelSummary,
RichProgressBar, StochasticWeightAveraging, Timer, TQDMProgressBar)

import torchok.callbacks.finalize_logger
@@ -21,4 +21,3 @@
CALLBACKS.register_class(StochasticWeightAveraging)
CALLBACKS.register_class(Timer)
CALLBACKS.register_class(TQDMProgressBar)
CALLBACKS.register_class(QuantizationAwareTraining)
3 changes: 2 additions & 1 deletion torchok/callbacks/finalize_logger.py
@@ -10,4 +10,5 @@ class FinalizeLogger(Callback):
def on_exception(self, trainer, pl_module, outputs):
# Need to save checkpoints for every exception not only isinstance(outputs, KeyboardInterrupt)
status = 'KILLED' if type(outputs) == KeyboardInterrupt else 'FAILED'
trainer.logger.finalize(status)
if trainer.logger is not None:
trainer.logger.finalize(status)
47 changes: 21 additions & 26 deletions torchok/constructor/config_structure.py
@@ -134,17 +134,12 @@ class TaskParams:
# Trainer parameters
@dataclass
class TrainerParams:
enable_checkpointing: bool = True
default_root_dir: Optional[str] = None
gradient_clip_val: Optional[float] = None
gradient_clip_algorithm: Optional[str] = None
accelerator: str = "auto"
strategy: str = "auto"
devices: str = "auto"
num_nodes: int = 1
enable_progress_bar: bool = True
overfit_batches: Any = 0.0 # Union[int, float]
track_grad_norm: Any = -1 # Union[int, float, str]
check_val_every_n_epoch: Optional[int] = 1
precision: Any = 32 # Literal[64, 32, 16, "64", "32", "16", "bf16"]
fast_dev_run: Any = False # Union[int, bool]
accumulate_grad_batches: Optional[Any] = None # Optional[Union[int, Dict[int, int]]]
max_epochs: Optional[int] = None
min_epochs: Optional[int] = None
max_steps: int = -1
@@ -154,26 +149,26 @@ class TrainerParams:
limit_val_batches: Optional[Any] = None # Optional[Union[int, float]]
limit_test_batches: Optional[Any] = None # Optional[Union[int, float]]
limit_predict_batches: Optional[Any] = None # Optional[Union[int, float]]
overfit_batches: Any = 0.0 # Union[int, float]
val_check_interval: Optional[Any] = None # Optional[Union[int, float]]
log_every_n_steps: int = 50
accelerator: Optional[str] = None
strategy: Optional[str] = None
sync_batchnorm: bool = False
precision: Any = 32 # Literal[64, 32, 16, "64", "32", "16", "bf16"]
enable_model_summary: bool = True
num_sanity_val_steps: int = 2
resume_from_checkpoint: Optional[str] = None
profiler: Optional[str] = None
check_val_every_n_epoch: Optional[int] = 1
num_sanity_val_steps: Optional[int] = None
log_every_n_steps: Optional[int] = None
enable_checkpointing: Optional[bool] = None
enable_progress_bar: Optional[bool] = None
enable_model_summary: Optional[bool] = None
accumulate_grad_batches: int = 1
gradient_clip_val: Optional[float] = None
gradient_clip_algorithm: Optional[str] = None
deterministic: Optional[bool] = None
benchmark: Optional[bool] = None
deterministic: Optional[bool] = None # Optional[Union[bool, _LITERAL_WARN]]
reload_dataloaders_every_n_epochs: int = 0
auto_lr_find: bool = False # Union[bool, str]
replace_sampler_ddp: bool = True
detect_anomaly: bool = False
auto_scale_batch_size: bool = False # Union[str, bool]
move_metrics_to_cpu: bool = False
multiple_trainloader_mode: str = "max_size_cycle"
inference_mode: bool = True
use_distributed_sampler: bool = True
profiler: Optional[str] = None
detect_anomaly: bool = False
barebones: bool = False
sync_batchnorm: bool = False
reload_dataloaders_every_n_epochs: int = 0
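
For context, the rewritten field list mirrors the pytorch-lightning 2.0 `Trainer` signature: options dropped from the 2.0 API (e.g. `track_grad_norm`, `auto_lr_find`, `resume_from_checkpoint`, `replace_sampler_ddp`, `auto_scale_batch_size`) are removed, and new ones (`inference_mode`, `use_distributed_sampler`, `barebones`) are added. Below is a minimal sketch of how such a params dataclass could be forwarded to the Trainer, assuming TorchOk passes the fields through as keyword arguments; this is an illustration, not the project's actual constructor code.

```python
# Minimal sketch: forward TrainerParams fields to pytorch_lightning.Trainer.
# Assumption: the fields are passed through verbatim; the real TorchOk
# constructor code is not shown in this diff.
from dataclasses import asdict

import pytorch_lightning as pl

from torchok.constructor.config_structure import TrainerParams

params = TrainerParams(accelerator="gpu", devices="auto", precision=16, max_epochs=10)
trainer = pl.Trainer(**asdict(params))  # every field name matches a Trainer kwarg in PL 2.0
```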


# Logger
4 changes: 1 addition & 3 deletions torchok/constructor/constructor.py
@@ -329,10 +329,8 @@ def _prepare_transforms_recursively(transforms: ListConfig) -> List[Union[A.Comp
transform_name = transform_info.name
transform_params = transform_info.get('params', dict())

if transform_name in ['Compose', 'OneOf', 'SomeOf', 'PerChannel', 'Sequential']:
if 'transforms' in transform_params:
transform = Constructor._prepare_base_compose(transform_name, **transform_params)
elif transform_name == 'OneOrOther':
raise ValueError('OneOrOther composition is currently not supported')
else:
transform = TRANSFORMS.get(transform_name)(**transform_params)

2 changes: 1 addition & 1 deletion torchok/constructor/load.py
@@ -4,7 +4,7 @@

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.cloud_io import load
from lightning_fabric.utilities.cloud_io import _load as load


def load_state_dict(checkpoint_path: str, map_location: Optional[Union[str, Callable, torch.device]] = 'cpu'):
2 changes: 1 addition & 1 deletion torchok/data/datasets/examples/__init__.py
@@ -1,4 +1,4 @@
import torchok.data.datasets.examples.cifar10
import torchok.data.datasets.examples.cifar
import torchok.data.datasets.examples.sop
import torchok.data.datasets.examples.triplet_sop
import torchok.data.datasets.examples.sweet_pepper
@@ -171,3 +171,27 @@ def _download(self) -> None:
else:
download_and_extract_archive(self.url, self.data_folder.as_posix(),
filename=self.filename, md5=self.tgz_md5)


@DATASETS.register_class
class CIFAR100(CIFAR10):
"""`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

This is a subclass of the `CIFAR10` Dataset.
"""
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]

test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
meta = {
'filename': 'meta',
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}
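
The new class is registered in the same DATASETS registry used elsewhere in TorchOk, so it should be retrievable by name in the same way transforms are fetched with `TRANSFORMS.get(...)` in this PR. A hedged usage sketch; the constructor arguments are assumptions, since `ImageClassificationDataset`'s signature is not shown in this diff.

```python
# Hypothetical usage sketch. The registry lookup mirrors TRANSFORMS.get(...);
# the constructor arguments below are illustrative assumptions only.
from torchok.constructor import DATASETS

cifar100_cls = DATASETS.get('CIFAR100')
dataset = cifar100_cls(data_folder='data/cifar100', transform=None)  # assumed args
print(len(dataset))
```
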
1 change: 1 addition & 0 deletions torchok/data/datasets/representation/__init__.py
@@ -1,2 +1,3 @@
import torchok.data.datasets.representation.validation
import torchok.data.datasets.representation.unsupervised_contrastive_dataset
import torchok.data.datasets.representation.swav
34 changes: 34 additions & 0 deletions torchok/data/datasets/representation/swav.py
@@ -0,0 +1,34 @@
from collections import defaultdict
from typing import Any, List, Dict

import torch
from torch.utils.data._utils.collate import default_collate

from torchok.constructor import DATASETS
from torchok.data.datasets.classification.classification import ImageClassificationDataset


@DATASETS.register_class
class SwaVDataset(ImageClassificationDataset):

def __getitem__(self, idx: int) -> List[Dict[str, Any]]:
samples = self.get_raw(idx)
new_samples = []
for i in range(len(samples['image'])):
view = {k: v[i] for k, v in samples.items()}
view = self._apply_transform(self.transform, view)
view['image'] = view['image'].type(torch.__dict__[self.input_dtype])
new_samples.append(view)

return new_samples

def collate_fn(self, batch: List[List[Dict[str, Any]]]) -> Dict[str, List[torch.Tensor]]:
Contributor: Describe what is on input and on output

Collaborator (Author): Added

num_views = len(batch[0])
new_batch = defaultdict(list)

for j in range(num_views):
cross_batch_view = default_collate([sample_views[j] for sample_views in batch])
for k, v in cross_batch_view.items():
Contributor: Please rename k and v within this class to make the code easier to read

Collaborator (Author): Changed

new_batch[k].append(v)

return new_batch
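
To make the reviewer's question about input and output concrete, here is a small illustrative sketch of what `collate_fn` receives and returns. The dict keys and tensor shapes below are assumptions chosen for the example; the real keys come from `ImageClassificationDataset`.

```python
# Toy illustration of SwaVDataset.collate_fn with assumed keys and shapes.
import torch

from torchok.data.datasets.representation.swav import SwaVDataset

# A batch of 2 samples, each yielding 2 augmented views (e.g. a global and a local crop).
batch = [
    [
        {'image': torch.zeros(3, 32, 32), 'target': torch.tensor(0)},
        {'image': torch.zeros(3, 16, 16), 'target': torch.tensor(0)},
    ],
    [
        {'image': torch.ones(3, 32, 32), 'target': torch.tensor(1)},
        {'image': torch.ones(3, 16, 16), 'target': torch.tensor(1)},
    ],
]

collated = SwaVDataset.collate_fn(None, batch)  # `self` is not used in the method body
# collated['image']  -> [Tensor(2, 3, 32, 32), Tensor(2, 3, 16, 16)]  one stacked tensor per view
# collated['target'] -> [Tensor(2), Tensor(2)]
```
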
1 change: 1 addition & 0 deletions torchok/data/transforms/__init__.py
@@ -19,6 +19,7 @@
from torchok.constructor import TRANSFORMS
from torchok.data.transforms import spatial
from torchok.data.transforms import pixelwise
from torchok.data.transforms import swav


TRANSFORMS.register_class(Compose)