-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SwaV self-supervision method #197
base: main
Are you sure you want to change the base?
Conversation
multi_crop_features = [self.forward(crop) for crop in crops] | ||
high_resolution = multi_crop_features[:self.num_highres_crops] | ||
low_resolution = multi_crop_features[self.num_highres_crops:] | ||
output = {"high_resolution": high_resolution, "low_resolution": low_resolution} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's rename to high_resolution_embs
, low_resolution_logits
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
torchok/data/transforms/swav.py
Outdated
cj_contrast: How much to jitter contrast. | ||
cj_sat: How much to jitter saturation. | ||
cj_hue: How much to jitter hue. | ||
random_gray_scale: Probability of conversion to grayscale. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May be to rename in the same fashion, i.e. gray_prob
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed
torchok/data/transforms/swav.py
Outdated
cj_sat: How much to jitter saturation. | ||
cj_hue: How much to jitter hue. | ||
random_gray_scale: Probability of conversion to grayscale. | ||
gaussian_blur: Probability of Gaussian blur. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
gb_prob
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
torchok/data/transforms/utils.py
Outdated
"Length of crop_sizes and crop_min_scales must be equal but are" | ||
f" {len(crop_sizes)} and {len(crop_min_scales)}." | ||
) | ||
if len(crop_sizes) != len(crop_min_scales): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The same twice
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There must be crop_max_scales
, fixed
|
||
for j in range(num_views): | ||
cross_batch_view = default_collate([sample_views[j] for sample_views in batch]) | ||
for k, v in cross_batch_view.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please, rename k
and v
within this class to easier read the code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed
|
||
return new_samples | ||
|
||
def collate_fn(self, batch: List[List[Dict[str, Any]]]) -> Dict[str, List[torch.Tensor]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Describe what is on input and on output
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
p = self.prototypes(x) | ||
return p | ||
|
||
def forward_with_gt(self, batch: Dict[str, Tensor]) -> Dict[str, List[Tensor]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the batch
typing correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
|
||
|
||
@TASKS.register_class | ||
class SwaVTask(ClassificationTask): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add an example config for this task?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will add later when i get the same results as lightly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added
No description provided.