diff --git a/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml b/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml index 011aaf920..1bdc74c1c 100644 --- a/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml +++ b/configs/config/benchmark/linear_image_classification/imagenet1k/byol_transfer_in1k_linear.yaml @@ -22,6 +22,7 @@ config: TRANSFORMS: - name: RandomResizedCrop size: 224 + interpolation: 3 - name: RandomHorizontalFlip - name: ToTensor - name: Normalize @@ -38,6 +39,7 @@ config: TRANSFORMS: - name: Resize size: 256 + interpolation: 3 - name: CenterCrop size: 224 - name: ToTensor @@ -82,7 +84,7 @@ config: PARAMS_FILE: "specify the model weights" STATE_DICT_KEY_NAME: classy_state_dict SYNC_BN_CONFIG: - CONVERT_BN_TO_SYNC_BN: True + CONVERT_BN_TO_SYNC_BN: False SYNC_BN_TYPE: apex GROUP_SIZE: 8 LOSS: @@ -93,22 +95,29 @@ config: name: sgd momentum: 0.9 num_epochs: 80 + weight_decay: 0 nesterov: True regularize_bn: False regularize_bias: True param_schedulers: lr: auto_lr_scaling: - auto_scale: true - base_value: 0.4 + # if set to True, learning rate will be scaled. + auto_scale: True + # base learning rate value that will be scaled. + base_value: 0.2 + # batch size for which the base learning rate is specified. The current batch size + # is used to determine how to scale the base learning rate value. + # scaled_lr = ((batchsize_per_gpu * world_size) * base_value ) / base_lr_batch_size base_lr_batch_size: 256 - name: multistep - values: [0.4, 0.3, 0.2, 0.1, 0.05] - milestones: [16, 32, 48, 64] - update_interval: epoch + # scaling_type can be set to "sqrt" to reduce the impact of scaling on the base value + scaling_type: "linear" + name: constant + update_interval: "epoch" + value: 0.2 DISTRIBUTED: BACKEND: nccl - NUM_NODES: 8 + NUM_NODES: 4 NUM_PROC_PER_NODE: 8 INIT_METHOD: tcp RUN_ID: auto diff --git a/configs/config/pretrain/byol/byol_1node_resnet.yaml b/configs/config/pretrain/byol/byol_1node_resnet.yaml deleted file mode 100644 index 5da2d114d..000000000 --- a/configs/config/pretrain/byol/byol_1node_resnet.yaml +++ /dev/null @@ -1,113 +0,0 @@ -# @package _global_ -config: - VERBOSE: False - LOG_FREQUENCY: 10 - TEST_ONLY: False - TEST_MODEL: False - SEED_VALUE: 0 - MULTI_PROCESSING_METHOD: forkserver - HOOKS: - PERF_STATS: - MONITOR_PERF_STATS: True - ROLLING_BTIME_FREQ: 313 - TENSORBOARD_SETUP: - USE_TENSORBOARD: True - EXPERIMENT_LOG_DIR: "byol_reference" - LOG_PARAMS: False - FLUSH_EVERY_N_MIN: 20 - DATA: - NUM_DATALOADER_WORKERS: 5 - TRAIN: - DATA_SOURCES: [disk_folder] - DATASET_NAMES: [imagenet1k_folder] - BATCHSIZE_PER_REPLICA: 32 - LABEL_TYPE: sample_index # just an implementation detail. 
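Editor's note: the `auto_lr_scaling` comments added in this linear-eval config describe linear scaling of the base learning rate with the global batch size. A minimal sketch of that arithmetic, assuming the 4-node x 8-GPU layout from this hunk; the per-GPU batch size of 32 is a hypothetical value, not taken from the config:

```python
def scaled_lr(base_value, batchsize_per_gpu, world_size,
              base_lr_batch_size=256, scaling_type="linear"):
    """Reproduces the formula quoted in the config comment:
    scaled_lr = ((batchsize_per_gpu * world_size) * base_value) / base_lr_batch_size.
    With scaling_type="sqrt" the scale factor is square-rooted to soften the effect.
    """
    scale = (batchsize_per_gpu * world_size) / base_lr_batch_size
    if scaling_type == "sqrt":
        scale = scale ** 0.5
    return base_value * scale

# Hypothetical numbers: 32 images/GPU on 4 nodes x 8 GPUs (world_size 32).
print(scaled_lr(0.2, 32, 4 * 8))  # 0.2 * (1024 / 256) = 0.8
```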
Label isn't used - TRANSFORMS: - - name: ImgReplicatePil - num_times: 2 - - name: RandomResizedCrop - size: 128 - - name: RandomHorizontalFlip - p: 0.5 - - name: ImgPilColorDistortion - strength: 0.5 - - name: ImgPilMultiCropRandomApply - transforms: - - name: ImgPilGaussianBlur - p: 1.0 - radius_min: 0.1 - radius_max: 2.0 - prob: [ 1.0, 0.1 ] - - name: ImgPilMultiCropRandomApply - transforms: - - name: ImgPilRandomSolarize - p: 1.0 - prob: [ 0.0, 0.2 ] - - name: ToTensor - - name: Normalize - mean: [ 0.485, 0.456, 0.406 ] - std: [ 0.229, 0.224, 0.225 ] - COLLATE_FUNCTION: simclr_collator - MMAP_MODE: True - COPY_TO_LOCAL_DISK: False - COPY_DESTINATION_DIR: /tmp/imagenet1k/ - DROP_LAST: True - TRAINER: - TRAIN_STEP_NAME: standard_train_step - METERS: - name: "" - MODEL: - TRUNK: - NAME: resnet - TRUNK_PARAMS: - RESNETS: - DEPTH: 50 - ZERO_INIT_RESIDUAL: True - HEAD: - PARAMS: [ - ["mlp", {"dims": [2048, 4096], "use_relu": True, "use_bn": True}], - ["mlp", {"dims": [4096, 256]}], - ["mlp", {"dims": [256, 4096], "use_relu": True, "use_bn": True}], - ["mlp", {"dims": [4096, 256]}], - ] - SYNC_BN_CONFIG: - CONVERT_BN_TO_SYNC_BN: True - SYNC_BN_TYPE: pytorch - AMP_PARAMS: - USE_AMP: False - LOSS: - name: byol_loss - byol_loss: - embedding_dim: 256 - momentum: 0.999 - OPTIMIZER: - name: sgd - use_larc: True - larc_config: - clip: False - trust_coefficient: 0.001 - eps: 0.00000001 - weight_decay: 0.0001 - momentum: 0.9 - nesterov: False - num_epochs: 200 - regularize_bn: False - regularize_bias: False - param_schedulers: - lr: - name: multistep - values: [0.03, 0.003, 0.0003] - milestones: [120, 160] - update_interval: epoch - DISTRIBUTED: - BACKEND: nccl - NUM_NODES: 1 - NUM_PROC_PER_NODE: 8 - INIT_METHOD: tcp - RUN_ID: auto - MACHINE: - DEVICE: gpu - CHECKPOINT: - AUTO_RESUME: True - CHECKPOINT_FREQUENCY: 5 - CHECKPOINT_ITER_FREQUENCY: -1 # set this variable to checkpoint every few iterations diff --git a/configs/config/pretrain/byol/byol_8node_resnet.yaml b/configs/config/pretrain/byol/byol_8node_resnet.yaml index e23822b91..ab4665b79 100644 --- a/configs/config/pretrain/byol/byol_8node_resnet.yaml +++ b/configs/config/pretrain/byol/byol_8node_resnet.yaml @@ -67,7 +67,7 @@ config: RESNETS: DEPTH: 50 ZERO_INIT_RESIDUAL: True - HEAD: + HEAD: PARAMS: [ ["mlp", {"dims": [2048, 4096, 256], "use_relu": True, "use_bn": True}], ["mlp", {"dims": [256, 4096, 256], "use_relu": True, "use_bn": True}] @@ -82,15 +82,16 @@ config: byol_loss: embedding_dim: 256 momentum: 0.99 - OPTIMIZER: # from official BYOL implementation, deepmind-research/byol/configs/byol.py + OPTIMIZER: name: lars - trust_coefficient: 0.001 + eta: 0.001 weight_decay: 1.0e-6 momentum: 0.9 nesterov: False num_epochs: 300 regularize_bn: False - regularize_bias: True + regularize_bias: False + exclude_bias_and_norm: True param_schedulers: lr: auto_lr_scaling: diff --git a/configs/config/quick_1gpu_resnet50_byol.yaml b/configs/config/quick_1gpu_resnet50_byol.yaml deleted file mode 100644 index 61da27a34..000000000 --- a/configs/config/quick_1gpu_resnet50_byol.yaml +++ /dev/null @@ -1,135 +0,0 @@ -# @package _global_ -config: - VERBOSE: False - LOG_FREQUENCY: 1 - TEST_ONLY: False - TEST_MODEL: False - SEED_VALUE: 0 - MULTI_PROCESSING_METHOD: forkserver - MONITOR_PERF_STATS: True - PERF_STAT_FREQUENCY: 10 - ROLLING_BTIME_FREQ: 5 - HOOKS: - TENSORBOARD_SETUP: - USE_TENSORBOARD: True - EXPERIMENT_LOG_DIR: "byol_quick" - LOG_PARAMS: False - FLUSH_EVERY_N_MIN: 20 - DATA: - NUM_DATALOADER_WORKERS: 5 - TRAIN: - DATA_SOURCES: [disk_folder] - 
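Editor's note: the 8-node pretrain config above switches from a LARC-style `trust_coefficient` to the LARS `eta` argument and sets `exclude_bias_and_norm: True`, so biases and norm parameters skip both weight decay and the layer-wise adaptation. A rough sketch of the per-layer LARS trust ratio, not the ClassyVision optimizer itself; the helper name is illustrative:

```python
import torch

def lars_trust_ratio(param: torch.Tensor, grad: torch.Tensor,
                     eta: float = 0.001, weight_decay: float = 1.0e-6) -> float:
    """Layer-wise LR multiplier used by LARS: eta * ||w|| / (||g|| + wd * ||w||).
    Parameters covered by exclude_bias_and_norm (biases, BN weights) would
    simply use a multiplier of 1.0 instead.
    """
    w_norm = param.norm()
    g_norm = grad.norm()
    if w_norm == 0 or g_norm == 0:
        return 1.0
    return float(eta * w_norm / (g_norm + weight_decay * w_norm))

w = torch.randn(256, 128)
g = torch.randn_like(w) * 0.01
print(lars_trust_ratio(w, g))  # scale applied to this layer's update
```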
DATASET_NAMES: [imagenet1k_folder] - BATCHSIZE_PER_REPLICA: 128 - LABEL_TYPE: sample_index # just an implementation detail. Label isn't used - TRANSFORMS: - - name: ImgReplicatePil - num_times: 2 - - name: RandomResizedCrop - size: 128 - - name: RandomHorizontalFlip - p: 0.5 - - name: ImgPilColorDistortion - strength: 0.5 - - name: ImgPilMultiCropRandomApply - transforms: - - name: ImgPilGaussianBlur - p: 1.0 - radius_min: 0.1 - radius_max: 2.0 - prob: [ 1.0, 0.1 ] - - name: ImgPilMultiCropRandomApply - transforms: - - name: ImgPilRandomSolarize - p: 1.0 - prob: [ 0.0, 0.2 ] - - name: ToTensor - - name: Normalize - mean: [ 0.485, 0.456, 0.406 ] - std: [ 0.229, 0.224, 0.225 ] - COLLATE_FUNCTION: simclr_collator - MMAP_MODE: True - COPY_TO_LOCAL_DISK: False - DROP_LAST: True - COPY_DESTINATION_DIR: "/tmp/imagenet1k" - TRAINER: - TRAIN_STEP_NAME: standard_train_step - METERS: - name: "" - MODEL: - TRUNK: - NAME: resnet - RESNETS: - DEPTH: 50 - ZERO_INIT_RESIDUAL: True - HEAD: - PARAMS: [ - ["mlp", {"dims": [2048, 4096], "use_relu": True, "use_bn": True}], - ["mlp", {"dims": [4096, 256]}], - ["mlp", {"dims": [256, 4096], "use_relu": True, "use_bn": True}], - ["mlp", {"dims": [4096, 256]}], - ] - SYNC_BN_CONFIG: - CONVERT_BN_TO_SYNC_BN: True - SYNC_BN_TYPE: pytorch - AMP_PARAMS: - USE_AMP: False - LOSS: - name: byol_loss - byol_loss: - embedding_dim: 256 - momentum: 0.999 - OPTIMIZER: - name: sgd - use_larc: True - larc_config: - clip: False - trust_coefficient: 0.001 - eps: 0.00000001 - weight_decay: 0.000001 - momentum: 0.9 - nesterov: False - num_epochs: 500 - regularize_bn: False - regularize_bias: True - head_optimizer_params: - use_different_lr: False - use_different_wd: False - param_schedulers: - lr: - auto_lr_scaling: - auto_scale: false - base_value: 0.3 - base_lr_batch_size: 256 - name: composite - schedulers: - - name: linear - start_value: 0.6 - end_value: 4.8 - - name: cosine_warm_restart - start_value: 4.8 - end_value: 0.0048 - # wave_type: half - # restart_interval_length: 0.5 - wave_type: full - is_adaptive: True - restart_interval_length: 0.334 - interval_scaling: [rescaled, rescaled] - update_interval: step - lengths: [0.1, 0.9] # 100ep - DISTRIBUTED: - BACKEND: nccl - NUM_NODES: 1 - NUM_PROC_PER_NODE: 1 - INIT_METHOD: tcp - RUN_ID: auto - MACHINE: - DEVICE: gpu - CHECKPOINT: - DIR: "." - AUTO_RESUME: False - CHECKPOINT_FREQUENCY: 1 - OVERWRITE_EXISTING: true - - TENSORBOARD_SETUP: - USE_TENSORBOARD: true diff --git a/configs/config/quick_1gpu_resnet50_simclr.yaml b/configs/config/quick_1gpu_resnet50_simclr.yaml deleted file mode 100644 index 1761064ad..000000000 --- a/configs/config/quick_1gpu_resnet50_simclr.yaml +++ /dev/null @@ -1,121 +0,0 @@ -# @package _global_ -config: - VERBOSE: False - LOG_FREQUENCY: 1 - TEST_ONLY: False - TEST_MODEL: False - SEED_VALUE: 0 - MULTI_PROCESSING_METHOD: forkserver - MONITOR_PERF_STATS: True - PERF_STAT_FREQUENCY: 10 - ROLLING_BTIME_FREQ: 5 - DATA: - NUM_DATALOADER_WORKERS: 5 - TRAIN: - DATA_SOURCES: [disk_filelist] - DATASET_NAMES: [imagenet1k_filelist] - BATCHSIZE_PER_REPLICA: 32 - LABEL_TYPE: sample_index # just an implementation detail. 
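Editor's note: the deleted quick BYOL config above uses a composite schedule, a linear warmup over the first 10% of steps (`lengths: [0.1, 0.9]`) followed by a cosine decay from 4.8 to 0.0048. A plain-Python sketch of that shape, ignoring the adaptive warm-restart options; the constants mirror the config but the helper itself is illustrative:

```python
import math

def composite_lr(progress: float, start=0.6, peak=4.8, end=0.0048,
                 warmup_frac=0.1) -> float:
    """progress is training progress in [0, 1]: linear warmup for the first
    warmup_frac, then a single cosine decay from peak to end (no warm restarts)."""
    if progress < warmup_frac:
        return start + (peak - start) * (progress / warmup_frac)
    p = (progress - warmup_frac) / (1.0 - warmup_frac)
    return end + 0.5 * (peak - end) * (1.0 + math.cos(math.pi * p))

for step, total in [(0, 1000), (100, 1000), (500, 1000), (999, 1000)]:
    print(step, round(composite_lr(step / total), 4))
```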
Label isn't used - TRANSFORMS: - - name: ImgReplicatePil - num_times: 2 - - name: RandomResizedCrop - size: 224 - - name: RandomHorizontalFlip - p: 0.5 - - name: ImgPilColorDistortion - strength: 1.0 - - name: ImgPilGaussianBlur - p: 0.5 - radius_min: 0.1 - radius_max: 2.0 - - name: ToTensor - - name: Normalize - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - COLLATE_FUNCTION: simclr_collator - MMAP_MODE: True - COPY_TO_LOCAL_DISK: False - DATA_LIMIT: 500 - DROP_LAST: True - COPY_DESTINATION_DIR: "/tmp/imagenet1k" - TRAINER: - TRAIN_STEP_NAME: standard_train_step - METERS: - name: "" - MODEL: - TRUNK: - NAME: resnet - RESNETS: - DEPTH: 50 - HEAD: - PARAMS: [ - ["mlp", {"dims": [2048, 2048], "use_relu": True}], - ["mlp", {"dims": [2048, 128]}], - ] - SYNC_BN_CONFIG: - CONVERT_BN_TO_SYNC_BN: True - SYNC_BN_TYPE: pytorch - AMP_PARAMS: - USE_AMP: False - AMP_ARGS: {"opt_level": "O3", "keep_batchnorm_fp32": True, "master_weights": True, "loss_scale": "dynamic"} - LOSS: - name: simclr_info_nce_loss - simclr_info_nce_loss: - temperature: 0.1 - buffer_params: - embedding_dim: 128 - OPTIMIZER: - name: sgd - use_larc: True - larc_config: - clip: False - trust_coefficient: 0.001 - eps: 0.00000001 - weight_decay: 0.000001 - momentum: 0.9 - nesterov: False - num_epochs: 2 - regularize_bn: False - regularize_bias: True - head_optimizer_params: - use_different_lr: False - use_different_wd: False - param_schedulers: - lr: - auto_lr_scaling: - auto_scale: false - base_value: 0.3 - base_lr_batch_size: 256 - name: composite - schedulers: - - name: linear - start_value: 0.6 - end_value: 4.8 - - name: cosine_warm_restart - start_value: 4.8 - end_value: 0.0048 - # wave_type: half - # restart_interval_length: 0.5 - wave_type: full - is_adaptive: True - restart_interval_length: 0.334 - interval_scaling: [rescaled, rescaled] - update_interval: step - lengths: [0.1, 0.9] # 100ep - DISTRIBUTED: - BACKEND: nccl - NUM_NODES: 1 - NUM_PROC_PER_NODE: 1 - INIT_METHOD: tcp - RUN_ID: auto - MACHINE: - DEVICE: gpu - CHECKPOINT: - DIR: "." - AUTO_RESUME: True - CHECKPOINT_FREQUENCY: 1 - OVERWRITE_EXISTING: true - - TENSORBOARD_SETUP: - USE_TENSORBOARD: true diff --git a/launch_byol_1node.sh b/launch_byol_1node.sh deleted file mode 100644 index 4d1af4a3e..000000000 --- a/launch_byol_1node.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -./dev/launch_slurm.sh \ - config=pretrain/byol/byol_1node_resnet \ - config.SLURM.NAME=byol_test \ - config.SLURM.COMMENT="BYOL FOR VISSL" \ - config.SLURM.PARTITION=learnfair \ diff --git a/run_distributed_engines.py b/run_distributed_engines.py deleted file mode 100644 index d4b9c8b70..000000000 --- a/run_distributed_engines.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved - -""" -Wrapper to call torch.distributed.launch to run multi-gpu trainings. 
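Editor's note: the deleted quick SimCLR config above relies on `simclr_info_nce_loss` with temperature 0.1. For orientation, a toy InfoNCE over two views; this is a simplification of the actual NT-Xent loss in VISSL, which also treats same-view pairs as negatives and gathers embeddings across GPUs:

```python
import torch
import torch.nn.functional as F

def info_nce(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """z1[i] and z2[i] are the two augmented views of image i; every other row in
    the batch is treated as a negative for the contrastive cross-entropy."""
    z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
    logits = z1 @ z2.t() / temperature   # (N, N) scaled cosine similarities
    labels = torch.arange(z1.shape[0])   # positives sit on the diagonal
    return F.cross_entropy(logits, labels)

z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
print(info_nce(z1, z2, temperature=0.1))
```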
-Supports two engines: train and extract_features -""" - -import logging -import sys -import tempfile -from typing import Any, Callable, List - -import torch -from fvcore.common.file_io import PathManager -from hydra.experimental import compose, initialize_config_module -from vissl.data.dataset_catalog import get_data_files -from vissl.engines.extract_features import extract_main -from vissl.engines.train import train_main -from vissl.hooks import ClassyHook, default_hook_generator -from vissl.utils.checkpoint import ( - get_checkpoint_folder, - get_resume_checkpoint, - is_training_finished, -) -from vissl.utils.env import set_env_vars -from vissl.utils.hydra_config import AttrDict, convert_to_attrdict, is_hydra_available -from vissl.utils.io import cleanup_dir, copy_data_to_local -from vissl.utils.logger import setup_logging, shutdown_logging -from vissl.utils.misc import get_dist_run_id -from vissl.utils.slurm import get_node_id - - -def get_available_splits(cfg: AttrDict): - return [key for key in cfg.DATA if key.lower() in ["train", "test"]] - - -def copy_to_local(cfg: AttrDict): - available_splits = get_available_splits(cfg) - for split in available_splits: - if cfg.DATA[split].COPY_TO_LOCAL_DISK: - dest_dir = cfg.DATA[split]["COPY_DESTINATION_DIR"] - tmp_dest_dir = tempfile.mkdtemp() - data_files, label_files = get_data_files(split, cfg.DATA) - data_files.extend(label_files) - _, output_dir = copy_data_to_local( - data_files, dest_dir, tmp_destination_dir=tmp_dest_dir - ) - cfg.DATA[split]["COPY_DESTINATION_DIR"] = output_dir - - -def cleanup_local_dir(cfg: AttrDict): - available_splits = get_available_splits(cfg) - for split in available_splits: - if cfg.DATA[split].COPY_TO_LOCAL_DISK: - dest_dir = cfg.DATA[split]["COPY_DESTINATION_DIR"] - cleanup_dir(dest_dir) - - -def launch_distributed( - cfg: AttrDict, - node_id: int, - engine_name: str, - hook_generator: Callable[[Any], List[ClassyHook]], -): - """ - Launch the distributed training across gpus, according to the cfg - - Args: - cfg -- VISSL yaml configuration - node_id -- node_id for this node - engine_name -- what engine to run: train or extract_features - hook_generator -- Callback to generate all the ClassyVision hooks for this engine - """ - node_id = get_node_id(node_id) - dist_run_id = get_dist_run_id(cfg, cfg.DISTRIBUTED.NUM_NODES) - world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE - set_env_vars(local_rank=0, node_id=node_id, cfg=cfg) - copy_to_local(cfg) - - # given the checkpoint folder, we check that there's not already a final checkpoint - checkpoint_folder = get_checkpoint_folder(cfg) - if is_training_finished(cfg, checkpoint_folder=checkpoint_folder): - logging.info(f"Training already succeeded on node: {node_id}, exiting.") - return - - # Get the checkpoint where to load from. The load_checkpoints function will - # automatically take care of detecting whether it's a resume or not. 
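Editor's note: `launch_distributed` above derives the global world size from `NUM_NODES * NUM_PROC_PER_NODE`, and each spawned worker later computes its global rank from the node id and local rank. A tiny standalone sketch of that bookkeeping, using the 8-node x 8-GPU layout of the BYOL pretrain config:

```python
def distributed_layout(num_nodes: int, num_proc_per_node: int):
    """Mirror the launcher's arithmetic: world_size = nodes * procs_per_node,
    dist_rank = procs_per_node * node_id + local_rank."""
    world_size = num_nodes * num_proc_per_node
    ranks = {
        (node_id, local_rank): num_proc_per_node * node_id + local_rank
        for node_id in range(num_nodes)
        for local_rank in range(num_proc_per_node)
    }
    return world_size, ranks

world_size, ranks = distributed_layout(num_nodes=8, num_proc_per_node=8)
print(world_size, ranks[(3, 5)])  # 64 workers total; node 3 / gpu 5 -> dist_rank 29
```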
- symlink_checkpoint_path = f"{checkpoint_folder}/checkpoint.torch" - if cfg.CHECKPOINT.USE_SYMLINK_CHECKPOINT_FOR_RESUME and PathManager.exists( - symlink_checkpoint_path - ): - checkpoint_path = f"{checkpoint_folder}/checkpoint.torch" - else: - checkpoint_path = get_resume_checkpoint( - cfg, checkpoint_folder=checkpoint_folder - ) - - try: - if world_size > 1: - torch.multiprocessing.spawn( - _distributed_worker, - nprocs=cfg.DISTRIBUTED.NUM_PROC_PER_NODE, - args=( - cfg, - node_id, - dist_run_id, - engine_name, - checkpoint_path, - checkpoint_folder, - hook_generator, - ), - daemon=False, - ) - else: - _distributed_worker( - local_rank=0, - cfg=cfg, - node_id=node_id, - dist_run_id=dist_run_id, - engine_name=engine_name, - checkpoint_path=checkpoint_path, - checkpoint_folder=checkpoint_folder, - hook_generator=hook_generator, - ) - - except (KeyboardInterrupt, RuntimeError) as e: - logging.error("Wrapping up, caught exception: ", e) - if isinstance(e, RuntimeError): - raise e - finally: - cleanup_local_dir(cfg) - - logging.info("All Done!") - - -def _distributed_worker( - local_rank: int, - cfg: AttrDict, - node_id: int, - dist_run_id: str, - engine_name: str, - checkpoint_path: str, - checkpoint_folder: str, - hook_generator: Callable[[Any], List[ClassyHook]], -): - dist_rank = cfg.DISTRIBUTED.NUM_PROC_PER_NODE * node_id + local_rank - if engine_name == "extract_features": - process_main = extract_main - else: - - def process_main(cfg, dist_run_id, local_rank, node_id): - train_main( - cfg, - dist_run_id, - checkpoint_path, - checkpoint_folder, - local_rank=local_rank, - node_id=node_id, - hook_generator=hook_generator, - ) - - logging.info( - f"Spawning process for node_id: {node_id}, local_rank: {local_rank}, " - f"dist_rank: {dist_rank}, dist_run_id: {dist_run_id}" - ) - process_main(cfg, dist_run_id, local_rank=local_rank, node_id=node_id) - - -def hydra_main(overrides: List[Any]): - print(f"####### overrides: {overrides}") - with initialize_config_module(config_module="vissl.config"): - cfg = compose("defaults", overrides=overrides) - setup_logging(__name__) - args, config = convert_to_attrdict(cfg) - launch_distributed( - config, - node_id=args.node_id, - engine_name=args.engine_name, - hook_generator=default_hook_generator, - ) - # close the logging streams including the filehandlers - shutdown_logging() - - -if __name__ == "__main__": - """ - Example usage: - - `python tools/run_distributed_engines.py config=test/integration_test/quick_simclr` - """ - overrides = sys.argv[1:] - assert is_hydra_available(), "Make sure to install hydra" - overrides.append("hydra.verbose=true") - hydra_main(overrides=overrides) diff --git a/vissl/data/ssl_transforms/img_pil_color_distortion.py b/vissl/data/ssl_transforms/img_pil_color_distortion.py index e3f79e4ca..df8f953b6 100644 --- a/vissl/data/ssl_transforms/img_pil_color_distortion.py +++ b/vissl/data/ssl_transforms/img_pil_color_distortion.py @@ -21,8 +21,16 @@ class ImgPilColorDistortion(ClassyTransform): randomly convert the image to grayscale. 
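Editor's note: the docstring above describes the SimCLR-style colour distortion this transform implements: colour jitter applied with some probability, then random grayscale. A torchvision-only sketch of the same composition for reference; the 0.8/0.2 probabilities and the strength scaling follow the defaults in this class:

```python
from torchvision import transforms as pth_transforms

def color_distortion(strength: float = 0.5):
    """ColorJitter scaled by `strength`, applied with p=0.8, then RandomGrayscale p=0.2."""
    jitter = pth_transforms.ColorJitter(
        0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
    )
    return pth_transforms.Compose([
        pth_transforms.RandomApply([jitter], p=0.8),
        pth_transforms.RandomGrayscale(p=0.2),
    ])

# Usage on a PIL image:
# distorted = color_distortion(strength=0.5)(pil_image)
```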
""" - def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8, - hue=0.2, color_jitter_probability=0.8, grayscale_probability=0.2): + def __init__( + self, + strength, + brightness=0.8, + contrast=0.8, + saturation=0.8, + hue=0.2, + color_jitter_probability=0.8, + grayscale_probability=0.2, + ): """ Args: strength (float): A number used to quantify the strength of the @@ -41,22 +49,23 @@ def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8, grayscale_probability (float): A floating point number used to quantify to apply randomly convert image to grayscale with the assigned probability. Default value is 0.2. - This function follows the Pytorch documentation: https://pytorch.org/vision/stable/transforms.html """ self.strength = strength self.brightness = brightness self.contrast = contrast self.saturation = saturation self.hue = hue - self.color_jitter_probability=color_jitter_probability - self.grayscale_probability=grayscale_probability + self.color_jitter_probability = color_jitter_probability + self.grayscale_probability = grayscale_probability self.color_jitter = pth_transforms.ColorJitter( self.brightness * self.strength, self.contrast * self.strength, self.saturation * self.strength, self.hue * self.strength, ) - self.rnd_color_jitter = pth_transforms.RandomApply([self.color_jitter], p=self.color_jitter_probability) + self.rnd_color_jitter = pth_transforms.RandomApply( + [self.color_jitter], p=self.color_jitter_probability + ) self.rnd_gray = pth_transforms.RandomGrayscale(p=self.grayscale_probability) self.transforms = pth_transforms.Compose([self.rnd_color_jitter, self.rnd_gray]) diff --git a/vissl/hooks/__init__.py b/vissl/hooks/__init__.py index 41e3276c6..9f19cbc8a 100644 --- a/vissl/hooks/__init__.py +++ b/vissl/hooks/__init__.py @@ -8,6 +8,7 @@ from classy_vision.hooks.classy_hook import ClassyHook from vissl.config import AttrDict +from vissl.hooks.byol_hooks import BYOLHook # noqa from vissl.hooks.deepclusterv2_hooks import ClusterMemoryHook, InitMemoryHook # noqa from vissl.hooks.dino_hooks import DINOHook from vissl.hooks.grad_clip_hooks import GradClipHook # noqa @@ -21,15 +22,12 @@ ) from vissl.hooks.moco_hooks import MoCoHook # noqa from vissl.hooks.profiling_hook import ProfilingHook -from vissl.hooks.byol_hooks import BYOLHook # noqa - from vissl.hooks.state_update_hooks import ( # noqa CheckNanLossHook, FreezeParametersHook, SetDataSamplerEpochHook, SSLModelComplexityHook, ) -from vissl.hooks.byol_hooks import BYOLHook # noqa from vissl.hooks.swav_hooks import NormalizePrototypesHook # noqa from vissl.hooks.swav_hooks import SwAVUpdateQueueScoresHook # noqa from vissl.hooks.swav_momentum_hooks import ( @@ -149,14 +147,6 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]: ) ] ) - if cfg.LOSS.name == "byol_loss": - hooks.extend( - [ - BYOLHook( - cfg.LOSS["byol_loss"]["momentum"], - ) - ] - ) if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY: hooks.extend([SSLModelComplexityHook()]) if cfg.HOOKS.LOG_GPU_STATS: diff --git a/vissl/hooks/byol_hooks.py b/vissl/hooks/byol_hooks.py index 12c266184..c1536f928 100644 --- a/vissl/hooks/byol_hooks.py +++ b/vissl/hooks/byol_hooks.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved -import math import logging +import math import torch from classy_vision import tasks @@ -8,14 +8,15 @@ from vissl.models import build_model from vissl.utils.env import get_machine_local_and_dist_rank + class BYOLHook(ClassyHook): """ - BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733) - is based on Contrastive learning. This hook - creates a target network with the same architecture - as the main online network, but without the projection head. - The online network does not participate in backpropogation, - but instead is an exponential moving average of the online network. + BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733) + is based on Contrastive learning. This hook + creates a target network with the same architecture + as the main online network, but without the projection head. + The online network does not participate in backpropogation, + but instead is an exponential moving average of the online network. """ on_start = ClassyHook._noop @@ -28,7 +29,7 @@ class BYOLHook(ClassyHook): on_update = ClassyHook._noop @staticmethod - def cosine_decay(training_iter, max_iters, initial_value) -> float: + def cosine_decay(training_iter, max_iters, initial_value) -> float: """ For a given starting value, this function anneals the learning rate. @@ -42,8 +43,8 @@ def target_ema(training_iter, base_ema, max_iters) -> float: """ Updates Exponential Moving average of the Target Network. """ - decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.) - return 1. - (1. - base_ema) * decay + decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.0) + return 1.0 - (1.0 - base_ema) * decay def _build_byol_target_network(self, task: tasks.ClassyTask) -> None: """ @@ -53,19 +54,19 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None: """ # Create the encoder, which will slowly track the model logging.info( - "BYOL: Building BYOL target network - rank %s %s", *get_machine_local_and_dist_rank() + "BYOL: Building BYOL target network - rank %s %s", + *get_machine_local_and_dist_rank(), ) - # Target model has the same architecture, but without the projector head. - target_model_config = task.config['MODEL'] - target_model_config['HEAD']['PARAMS'] = target_model_config['HEAD']['PARAMS'][0:1] + # Target model has the same architecture, *without* the projector head. + target_model_config = task.config["MODEL"] + target_model_config["HEAD"]["PARAMS"] = target_model_config["HEAD"]["PARAMS"][ + 0:1 + ] task.loss.target_network = build_model( target_model_config, task.config["OPTIMIZER"] ) - # TESTED: Target Network and Online network are properly created. 
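Editor's note: `target_ema` above anneals the EMA coefficient from `base_ema` toward 1.0 over training via a cosine-decayed factor. A standalone numeric sketch of that schedule; the body of `cosine_decay` is outside this hunk, so the 0.5 * (1 + cos(pi * t / T)) form used here is an assumption:

```python
import math

def cosine_decay(training_iter: int, max_iters: int, initial_value: float) -> float:
    # Assumed form of the helper (its body is not shown in the diff).
    training_iter = min(training_iter, max_iters)
    return 0.5 * (1.0 + math.cos(math.pi * training_iter / max_iters)) * initial_value

def target_ema(training_iter: int, base_ema: float, max_iters: int) -> float:
    # Matches the hook: tau starts at base_ema and is annealed toward 1.0.
    decay = cosine_decay(training_iter, max_iters, 1.0)
    return 1.0 - (1.0 - base_ema) * decay

for it in (0, 75_000, 150_000):
    print(it, round(target_ema(it, base_ema=0.99, max_iters=150_000), 5))
# 0 -> 0.99, halfway -> 0.995, end -> 1.0
```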
- # TODO: Check SyncBatchNorm settings (low prior) - task.loss.target_network.to(task.device) # Restore an hypothetical checkpoint, else copy the model parameters from the @@ -73,7 +74,9 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None: if task.loss.checkpoint is not None: task.loss.load_state_dict(task.loss.checkpoint) else: - logging.info("BYOL: Copying and freezing model parameters from online to target network") + logging.info( + "BYOL: Copying and freezing model parameters from online to target network" + ) for param_q, param_k in zip( task.base_model.parameters(), task.loss.target_network.parameters() ): @@ -92,7 +95,9 @@ def _update_momentum_coefficient(self, task: tasks.ClassyTask) -> None: self.total_iters = task.max_iteration logging.info(f"{self.total_iters} total iters") training_iteration = task.iteration - self.momentum = self.target_ema(training_iteration, self.base_momentum, self.total_iters) + self.momentum = self.target_ema( + training_iteration, self.base_momentum, self.total_iters + ) @torch.no_grad() def _update_target_network(self, task: tasks.ClassyTask) -> None: @@ -106,10 +111,10 @@ def _update_target_network(self, task: tasks.ClassyTask) -> None: task.base_model.parameters(), task.loss.target_network.parameters() ): target_params.data = ( - target_params.data * self.momentum + online_params.data * (1. - self.momentum) + target_params.data * self.momentum + + online_params.data * (1.0 - self.momentum) ) - @torch.no_grad() def on_forward(self, task: tasks.ClassyTask) -> None: """ @@ -127,9 +132,8 @@ def on_forward(self, task: tasks.ClassyTask) -> None: else: self._update_target_network(task) - # Compute target network embeddings - batch = task.last_batch.sample['input'] + batch = task.last_batch.sample["input"] target_embs = task.loss.target_network(batch)[0] # Save target embeddings to use them in the loss diff --git a/vissl/losses/byol_loss.py b/vissl/losses/byol_loss.py index b0fbdfcca..03581356b 100644 --- a/vissl/losses/byol_loss.py +++ b/vissl/losses/byol_loss.py @@ -7,9 +7,9 @@ import torch.nn.functional as F from classy_vision.losses import ClassyLoss, register_loss -_BYOLLossConfig = namedtuple( - "_BYOLLossConfig", ["embedding_dim", "momentum"] -) + +_BYOLLossConfig = namedtuple("_BYOLLossConfig", ["embedding_dim", "momentum"]) + def regression_loss(x, y): """ @@ -19,17 +19,16 @@ def regression_loss(x, y): Cosine similarity. This implementation uses Cosine similarity. """ normed_x, normed_y = F.normalize(x, dim=1), F.normalize(y, dim=1) - return torch.sum((normed_x - normed_y).pow(2), dim=1) + # Euclidean Distance squared. 
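Editor's note: `_update_target_network` above blends the online weights into the target network with the current momentum. A minimal sketch of that exponential moving average over two tiny modules; the module names are placeholders, and the real hook additionally freezes the target parameters when it first copies them:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, momentum: float) -> None:
    """target <- momentum * target + (1 - momentum) * online, parameter by parameter."""
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        target_p.data.mul_(momentum).add_(online_p.data, alpha=1.0 - momentum)

online = nn.Linear(4, 4)
target = nn.Linear(4, 4)
target.load_state_dict(online.state_dict())  # start from a copy, as the hook does
ema_update(online, target, momentum=0.99)
```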
+ return 2 - 2 * (normed_x * normed_y).sum(dim=1) class BYOLLossConfig(_BYOLLossConfig): - """ Settings for the BYOL loss""" + """Settings for the BYOL loss""" @staticmethod def defaults() -> "BYOLLossConfig": - return BYOLLossConfig( - embedding_dim=256, momentum=0.999 - ) + return BYOLLossConfig(embedding_dim=256, momentum=0.999) @register_loss("byol_loss") @@ -68,7 +67,9 @@ def from_config(cls, config: BYOLLossConfig) -> "BYOLLoss": """ return cls(config) - def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor: + def forward( + self, online_network_prediction: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: """ In this function, the Online Network receives the tensor as input after projection and they make predictions on the output of the target network’s projection, @@ -79,7 +80,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t compute the cross entropy loss for this batch. Args: - query: output of the encoder given the current batch online_network_prediction: online model output. this is a prediction of the target network output. @@ -91,8 +91,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t online_view1, online_view2 = torch.chunk(online_network_prediction, 2, 0) target_view1, target_view2 = torch.chunk(self.target_embs.detach(), 2, 0) - # TESTED: Views are received correctly. - # Compute losses loss1 = regression_loss(online_view1, target_view2) loss2 = regression_loss(online_view2, target_view1) @@ -111,7 +109,6 @@ def load_state_dict(self, state_dict, *args, **kwargs) -> None: Args: state_dict (serialized via torch.save) """ - # If the encoder has been allocated, use the normal pytorch restoration if self.target_network is None: self.checkpoint = state_dict diff --git a/vissl/trainer/train_sdp_task.py b/vissl/trainer/train_sdp_task.py new file mode 100644 index 000000000..cadccbf77 --- /dev/null +++ b/vissl/trainer/train_sdp_task.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from classy_vision.optim.zero import ZeRO +from classy_vision.tasks import register_task +from classy_vision.tasks.classification_task import BroadcastBuffersMode +from fairscale.nn.data_parallel import ShardedDataParallel +from vissl.config import AttrDict +from vissl.trainer.train_task import SelfSupervisionTask + + +# More information on ShardedDDP can be found in the Fairscale repository +# https://github.com/facebookresearch/fairscale + + +@register_task("self_supervision_sdp_task") +class SelfSupervisionSDPTask(SelfSupervisionTask): + def __init__(self, config: AttrDict): + super().__init__(config) + + def init_distributed_data_parallel_model(self): + """ + Initialize ShardedDataParallel, needed for sharded distributed training. + This is where a model should be wrapped by DDP. + """ + broadcast_buffers = ( + self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS + ) + + # Replace the original DDP wrap by the shard-aware ShardedDDP + # we use the fairscale reduce_buffer_size by default however, if user sets it to + # some different value, we use the different value. 
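Editor's note: the rewritten `regression_loss` returns `2 - 2 * <x_hat, y_hat>`, which for L2-normalized vectors is exactly the squared Euclidean distance the previous implementation computed (and the form used in the BYOL paper). A quick numeric check of that identity:

```python
import torch
import torch.nn.functional as F

x, y = torch.randn(8, 256), torch.randn(8, 256)
nx, ny = F.normalize(x, dim=1), F.normalize(y, dim=1)

squared_l2 = (nx - ny).pow(2).sum(dim=1)      # ||x_hat - y_hat||^2
cosine_form = 2 - 2 * (nx * ny).sum(dim=1)    # 2 - 2 * cos(x, y)

print(torch.allclose(squared_l2, cosine_form, atol=1e-6))  # True
```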
+        reduce_buffer_size = 2 ** 23
+        if self.config.MODEL.SHARDED_DDP_SETUP.reduce_buffer_size >= 0:
+            reduce_buffer_size = self.config.MODEL.SHARDED_DDP_SETUP.reduce_buffer_size
+        logging.info(f"Setting reduce_buffer_size: {reduce_buffer_size}")
+        if isinstance(self.optimizer, ZeRO):
+            logging.info("Using ShardedDDP")
+            self.distributed_model = ShardedDataParallel(
+                module=self.base_model,
+                sharded_optimizer=self.optimizer.optimizer,
+                broadcast_buffers=broadcast_buffers,
+                reduce_buffer_size=reduce_buffer_size,
+            )
+        else:
+            raise NotImplementedError(
+                "This DataParallel engine should only be used in conjunction with ZeRO"
+            )
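Editor's note: the new SDP task wraps the base model in fairscale's ShardedDataParallel and requires a ZeRO-style sharded optimizer. A condensed sketch of that pairing outside VISSL/ClassyVision, assuming the process group is already initialized; using fairscale's OSS wrapper here is my choice of example, not something this diff prescribes:

```python
import torch
import torch.nn as nn
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

def build_sharded_model(model: nn.Module):
    # Shard optimizer state across ranks, then hand the sharded optimizer to SDP
    # so gradient reduction only targets the shards each rank owns.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD,
                    lr=0.1, momentum=0.9)
    ddp_model = ShardedDataParallel(
        model,
        sharded_optimizer=optimizer,
        reduce_buffer_size=2 ** 23,  # same default the task uses for gradient bucketing
    )
    return ddp_model, optimizer
```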