Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
Included feedback on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pranavsinghps1 authored and iseessel committed Oct 17, 2021
1 parent 99dfae9 commit 6e3063d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ config:
WEIGHTS_INIT:
PARAMS_FILE: "specify the model weights"
STATE_DICT_KEY_NAME: classy_state_dict
# STATE_DICT_KEY_NAME: model_state_dict
SYNC_BN_CONFIG:
CONVERT_BN_TO_SYNC_BN: True
SYNC_BN_TYPE: apex
Expand All @@ -92,7 +91,6 @@ config:
ignore_index: -1
OPTIMIZER:
name: sgd
# In the OSS Caffe2 benchmark, RN50 models use 1e-4 and AlexNet models 5e-4
momentum: 0.9
num_epochs: 80
nesterov: True
Expand Down
4 changes: 2 additions & 2 deletions vissl/data/ssl_transforms/img_pil_color_distortion.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class ImgPilColorDistortion(ClassyTransform):
randomly convert the image to grayscale.
"""

def __init__(self, strength , brightness=0.8 , contrast=0.8 , saturation=0.8,
hue=0.2,color_jitter_probability=0.8,grayscale_probability=0.2):
def __init__(self, strength, brightness=0.8, contrast=0.8, saturation=0.8,
hue=0.2, color_jitter_probability=0.8, grayscale_probability=0.2):
"""
Args:
strength (float): A number used to quantify the strength of the
Expand Down
17 changes: 8 additions & 9 deletions vissl/hooks/byol_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

class BYOLHook(ClassyHook):
"""
BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
is based on Contrastive learning, this hook
creates a target network with architecture similar to
Online network but without the projector head and parameters
an exponential moving average of the online network's parameters,
these two networks interact and learn from each other.
BYOL - Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
is based on contrastive learning. This hook
creates a target network with the same architecture
as the main online network, but without the projection head.
The target network does not participate in backpropagation;
instead, its parameters are an exponential moving average of the
online network's parameters.
"""

on_start = ClassyHook._noop
Expand All @@ -30,7 +30,7 @@ class BYOLHook(ClassyHook):
@staticmethod
def cosine_decay(training_iter, max_iters, initial_value) -> float:
"""
For a given starting value, this function anneals the learning
rate.
rate.
"""
training_iter = min(training_iter, max_iters)
Expand All @@ -48,9 +48,8 @@ def target_ema(training_iter, base_ema, max_iters) -> float:
def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
"""
Creates a "Target Network" which has the same architecture as the
Online Network but without the projector head and its network parameters
Online Network but without the projection head. Its network parameters
are a lagging exponential moving average of the online model's parameters.
"""
# Create the encoder, which will slowly track the model
logging.info(
Expand Down
12 changes: 10 additions & 2 deletions vissl/losses/byol_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ class BYOLLoss(ClassyLoss):
This is the loss proposed in BYOL
- Bootstrap your own latent: (https://arxiv.org/abs/2006.07733)
This class wraps functions which computes
- loss
- loss : BYOL uses a contrastive loss, defined as the difference
between the l2-normalized prediction of the Online network and
the projection of the Target network — equivalently, the cosine
similarity between the two. This implementation uses cosine similarity.
- restores loss from checkpoints.
Config params:
Expand Down Expand Up @@ -67,8 +70,13 @@ def from_config(cls, config: BYOLLossConfig) -> "BYOLLoss":

def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor:
"""
In this function, the Online Network receives the projected tensor as input
and makes a prediction of the Target Network's projection output.
The similarity between the two is computed, and its mean is used as the
loss that drives the parameter updates which reduce it.
Given the encoder queries, the key and the queue of the previous queries,
compute the cross entropy loss for this batch
compute the cross entropy loss for this batch.
Args:
query: output of the encoder given the current batch
Expand Down

0 comments on commit 6e3063d

Please sign in to comment.