Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
BYOL improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
iseessel committed Oct 17, 2021
1 parent 6e3063d commit a4047e9
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ config:
TRANSFORMS:
- name: RandomResizedCrop
size: 224
interpolation: 3
- name: RandomHorizontalFlip
- name: ToTensor
- name: Normalize
Expand All @@ -38,6 +39,7 @@ config:
TRANSFORMS:
- name: Resize
size: 256
interpolation: 3
- name: CenterCrop
size: 224
- name: ToTensor
Expand Down Expand Up @@ -82,7 +84,7 @@ config:
PARAMS_FILE: "specify the model weights"
STATE_DICT_KEY_NAME: classy_state_dict
SYNC_BN_CONFIG:
CONVERT_BN_TO_SYNC_BN: True
CONVERT_BN_TO_SYNC_BN: False
SYNC_BN_TYPE: apex
GROUP_SIZE: 8
LOSS:
Expand All @@ -93,22 +95,29 @@ config:
name: sgd
momentum: 0.9
num_epochs: 80
weight_decay: 0
nesterov: True
regularize_bn: False
regularize_bias: True
param_schedulers:
lr:
auto_lr_scaling:
auto_scale: true
base_value: 0.4
# if set to True, learning rate will be scaled.
auto_scale: True
# base learning rate value that will be scaled.
base_value: 0.2
# batch size for which the base learning rate is specified. The current batch size
# is used to determine how to scale the base learning rate value.
# scaled_lr = ((batchsize_per_gpu * world_size) * base_value ) / base_lr_batch_size
base_lr_batch_size: 256
name: multistep
values: [0.4, 0.3, 0.2, 0.1, 0.05]
milestones: [16, 32, 48, 64]
update_interval: epoch
# scaling_type can be set to "sqrt" to reduce the impact of scaling on the base value
scaling_type: "linear"
name: constant
update_interval: "epoch"
value: 0.2
DISTRIBUTED:
BACKEND: nccl
NUM_NODES: 8
NUM_NODES: 4
NUM_PROC_PER_NODE: 8
INIT_METHOD: tcp
RUN_ID: auto
Expand Down
9 changes: 5 additions & 4 deletions configs/config/pretrain/byol/byol_8node_resnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ config:
RESNETS:
DEPTH: 50
ZERO_INIT_RESIDUAL: True
HEAD:
HEAD:
PARAMS: [
["mlp", {"dims": [2048, 4096, 256], "use_relu": True, "use_bn": True}],
["mlp", {"dims": [256, 4096, 256], "use_relu": True, "use_bn": True}]
Expand All @@ -82,15 +82,16 @@ config:
byol_loss:
embedding_dim: 256
momentum: 0.99
OPTIMIZER: # from official BYOL implementation, deepmind-research/byol/configs/byol.py
OPTIMIZER:
name: lars
trust_coefficient: 0.001
eta: 0.001
weight_decay: 1.0e-6
momentum: 0.9
nesterov: False
num_epochs: 300
regularize_bn: False
regularize_bias: True
regularize_bias: False
exclude_bias_and_norm: True
param_schedulers:
lr:
auto_lr_scaling:
Expand Down
21 changes: 15 additions & 6 deletions vissl/data/ssl_transforms/img_pil_color_distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@ class ImgPilColorDistortion(ClassyTransform):
randomly convert the image to grayscale.
"""

def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
hue=0.2, color_jitter_probability=0.8, grayscale_probability=0.2):
def __init__(
self,
strength,
brightness=0.8,
contrast=0.8,
saturation=0.8,
hue=0.2,
color_jitter_probability=0.8,
grayscale_probability=0.2,
):
"""
Args:
strength (float): A number used to quantify the strength of the
Expand All @@ -41,22 +49,23 @@ def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
grayscale_probability (float): The probability with which the
image is randomly converted to grayscale.
Default value is 0.2.
This function follows the Pytorch documentation: https://pytorch.org/vision/stable/transforms.html
"""
self.strength = strength
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
self.color_jitter_probability=color_jitter_probability
self.grayscale_probability=grayscale_probability
self.color_jitter_probability = color_jitter_probability
self.grayscale_probability = grayscale_probability
self.color_jitter = pth_transforms.ColorJitter(
self.brightness * self.strength,
self.contrast * self.strength,
self.saturation * self.strength,
self.hue * self.strength,
)
self.rnd_color_jitter = pth_transforms.RandomApply([self.color_jitter], p=self.color_jitter_probability)
self.rnd_color_jitter = pth_transforms.RandomApply(
[self.color_jitter], p=self.color_jitter_probability
)
self.rnd_gray = pth_transforms.RandomGrayscale(p=self.grayscale_probability)
self.transforms = pth_transforms.Compose([self.rnd_color_jitter, self.rnd_gray])

Expand Down
12 changes: 1 addition & 11 deletions vissl/hooks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from classy_vision.hooks.classy_hook import ClassyHook
from vissl.config import AttrDict
from vissl.hooks.byol_hooks import BYOLHook # noqa
from vissl.hooks.deepclusterv2_hooks import ClusterMemoryHook, InitMemoryHook # noqa
from vissl.hooks.dino_hooks import DINOHook
from vissl.hooks.grad_clip_hooks import GradClipHook # noqa
Expand All @@ -21,15 +22,12 @@
)
from vissl.hooks.moco_hooks import MoCoHook # noqa
from vissl.hooks.profiling_hook import ProfilingHook
from vissl.hooks.byol_hooks import BYOLHook # noqa

from vissl.hooks.state_update_hooks import ( # noqa
CheckNanLossHook,
FreezeParametersHook,
SetDataSamplerEpochHook,
SSLModelComplexityHook,
)
from vissl.hooks.byol_hooks import BYOLHook # noqa
from vissl.hooks.swav_hooks import NormalizePrototypesHook # noqa
from vissl.hooks.swav_hooks import SwAVUpdateQueueScoresHook # noqa
from vissl.hooks.swav_momentum_hooks import (
Expand Down Expand Up @@ -149,14 +147,6 @@ def default_hook_generator(cfg: AttrDict) -> List[ClassyHook]:
)
]
)
if cfg.LOSS.name == "byol_loss":
hooks.extend(
[
BYOLHook(
cfg.LOSS["byol_loss"]["momentum"],
)
]
)
if cfg.HOOKS.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY:
hooks.extend([SSLModelComplexityHook()])
if cfg.HOOKS.LOG_GPU_STATS:
Expand Down
50 changes: 27 additions & 23 deletions vissl/hooks/byol_hooks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import logging
import math

import torch
from classy_vision import tasks
from classy_vision.hooks.classy_hook import ClassyHook
from vissl.models import build_model
from vissl.utils.env import get_machine_local_and_dist_rank


class BYOLHook(ClassyHook):
"""
BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
is based on Contrastive learning. This hook
creates a target network with the same architecture
as the main online network, but without the projection head.
The target network does not participate in backpropagation,
but instead is an exponential moving average of the online network.
BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
is based on Contrastive learning. This hook
creates a target network with the same architecture
as the main online network, but without the projection head.
The target network does not participate in backpropagation,
but instead is an exponential moving average of the online network.
"""

on_start = ClassyHook._noop
Expand All @@ -28,7 +29,7 @@ class BYOLHook(ClassyHook):
on_update = ClassyHook._noop

@staticmethod
def cosine_decay(training_iter, max_iters, initial_value) -> float:
def cosine_decay(training_iter, max_iters, initial_value) -> float:
"""
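For a given starting value, this function anneals that value with a
cosine decay. In this hook it is used to schedule the target-network
EMA coefficient (see target_ema), not the learning rate.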
Expand All @@ -42,8 +43,8 @@ def target_ema(training_iter, base_ema, max_iters) -> float:
"""
Updates Exponential Moving average of the Target Network.
"""
decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.)
return 1. - (1. - base_ema) * decay
decay = BYOLHook.cosine_decay(training_iter, max_iters, 1.0)
return 1.0 - (1.0 - base_ema) * decay

def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
"""
Expand All @@ -53,27 +54,29 @@ def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
"""
# Create the encoder, which will slowly track the model
logging.info(
"BYOL: Building BYOL target network - rank %s %s", *get_machine_local_and_dist_rank()
"BYOL: Building BYOL target network - rank %s %s",
*get_machine_local_and_dist_rank(),
)

# Target model has the same architecture, but without the projector head.
target_model_config = task.config['MODEL']
target_model_config['HEAD']['PARAMS'] = target_model_config['HEAD']['PARAMS'][0:1]
# Target model has the same architecture, *without* the projector head.
target_model_config = task.config["MODEL"]
target_model_config["HEAD"]["PARAMS"] = target_model_config["HEAD"]["PARAMS"][
0:1
]
task.loss.target_network = build_model(
target_model_config, task.config["OPTIMIZER"]
)

# TESTED: Target Network and Online network are properly created.
# TODO: Check SyncBatchNorm settings (low prior)

task.loss.target_network.to(task.device)

# Restore a hypothetical checkpoint, else copy the model parameters from the
# online network.
if task.loss.checkpoint is not None:
task.loss.load_state_dict(task.loss.checkpoint)
else:
logging.info("BYOL: Copying and freezing model parameters from online to target network")
logging.info(
"BYOL: Copying and freezing model parameters from online to target network"
)
for param_q, param_k in zip(
task.base_model.parameters(), task.loss.target_network.parameters()
):
Expand All @@ -92,7 +95,9 @@ def _update_momentum_coefficient(self, task: tasks.ClassyTask) -> None:
self.total_iters = task.max_iteration
logging.info(f"{self.total_iters} total iters")
training_iteration = task.iteration
self.momentum = self.target_ema(training_iteration, self.base_momentum, self.total_iters)
self.momentum = self.target_ema(
training_iteration, self.base_momentum, self.total_iters
)

@torch.no_grad()
def _update_target_network(self, task: tasks.ClassyTask) -> None:
Expand All @@ -106,10 +111,10 @@ def _update_target_network(self, task: tasks.ClassyTask) -> None:
task.base_model.parameters(), task.loss.target_network.parameters()
):
target_params.data = (
target_params.data * self.momentum + online_params.data * (1. - self.momentum)
target_params.data * self.momentum
+ online_params.data * (1.0 - self.momentum)
)


@torch.no_grad()
def on_forward(self, task: tasks.ClassyTask) -> None:
"""
Expand All @@ -127,9 +132,8 @@ def on_forward(self, task: tasks.ClassyTask) -> None:
else:
self._update_target_network(task)


# Compute target network embeddings
batch = task.last_batch.sample['input']
batch = task.last_batch.sample["input"]
target_embs = task.loss.target_network(batch)[0]

# Save target embeddings to use them in the loss
Expand Down
23 changes: 10 additions & 13 deletions vissl/losses/byol_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import torch.nn.functional as F
from classy_vision.losses import ClassyLoss, register_loss

_BYOLLossConfig = namedtuple(
"_BYOLLossConfig", ["embedding_dim", "momentum"]
)

_BYOLLossConfig = namedtuple("_BYOLLossConfig", ["embedding_dim", "momentum"])


def regression_loss(x, y):
"""
Expand All @@ -19,17 +19,16 @@ def regression_loss(x, y):
Cosine similarity. This implementation uses Cosine similarity.
"""
normed_x, normed_y = F.normalize(x, dim=1), F.normalize(y, dim=1)
return torch.sum((normed_x - normed_y).pow(2), dim=1)
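# Squared Euclidean distance between the L2-normalized vectors, written via
# their cosine similarity: ||x - y||^2 = 2 - 2 * <x, y> when ||x|| = ||y|| = 1.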
return 2 - 2 * (normed_x * normed_y).sum(dim=1)


class BYOLLossConfig(_BYOLLossConfig):
""" Settings for the BYOL loss"""
"""Settings for the BYOL loss"""

@staticmethod
def defaults() -> "BYOLLossConfig":
return BYOLLossConfig(
embedding_dim=256, momentum=0.999
)
return BYOLLossConfig(embedding_dim=256, momentum=0.999)


@register_loss("byol_loss")
Expand Down Expand Up @@ -68,7 +67,9 @@ def from_config(cls, config: BYOLLossConfig) -> "BYOLLoss":
"""
return cls(config)

def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor:
def forward(
self, online_network_prediction: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""
In this function, the Online Network receives the tensor as input after projection
and they make predictions on the output of the target network’s projection,
Expand All @@ -79,7 +80,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t
compute the cross entropy loss for this batch.
Args:
query: output of the encoder given the current batch
online_network_prediction: online model output. this is a prediction of the
target network output.
Expand All @@ -91,8 +91,6 @@ def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> t
online_view1, online_view2 = torch.chunk(online_network_prediction, 2, 0)
target_view1, target_view2 = torch.chunk(self.target_embs.detach(), 2, 0)

# TESTED: Views are received correctly.

# Compute losses
loss1 = regression_loss(online_view1, target_view2)
loss2 = regression_loss(online_view2, target_view1)
Expand All @@ -111,7 +109,6 @@ def load_state_dict(self, state_dict, *args, **kwargs) -> None:
Args:
state_dict (serialized via torch.save)
"""

# If the encoder has been allocated, use the normal pytorch restoration
if self.target_network is None:
self.checkpoint = state_dict
Expand Down

0 comments on commit a4047e9

Please sign in to comment.