Skip to content

Commit

Permalink
Add model resnext101_32x8d_wsl (#834)
Browse files Browse the repository at this point in the history
* adds model resnext101_32x8d_wsl

* addresses PR comments
  • Loading branch information
mike-ferguson authored May 16, 2024
1 parent d74b0d1 commit 4043a0d
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 0 deletions.
7 changes: 7 additions & 0 deletions brainscore_vision/models/resnext101_32x8d_wsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

model_registry['resnext101_32x8d_wsl'] = lambda: ModelCommitment(identifier='resnext101_32x8d_wsl',
activations_model=get_model('resnext101_32x8d_wsl'),
layers=get_layers('resnext101_32x8d_wsl'))
47 changes: 47 additions & 0 deletions brainscore_vision/models/resnext101_32x8d_wsl/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import functools
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
from brainscore_vision.model_helpers.check_submission import check_models
import torch.hub
import ssl


ssl._create_default_https_context = ssl._create_unverified_context


def get_model(name):
assert name == 'resnext101_32x8d_wsl'
model_identifier = "resnext101_32x8d_wsl"
model = torch.hub.load('facebookresearch/WSL-Images', model_identifier)
preprocessing = functools.partial(load_preprocess_images, image_size=224)
batch_size = {8: 32, 16: 16, 32: 8, 48: 4}
wrapper = PytorchWrapper(identifier=model_identifier, model=model, preprocessing=preprocessing,
batch_size=batch_size[8])
wrapper.image_size = 224
return wrapper


def get_layers(name):
assert name == 'resnext101_32x8d_wsl'
return (['conv1'] +
# note that while relu is used multiple times, by default the last one will overwrite all previous ones
[f"layer{block + 1}.{unit}.relu"
for block, block_units in enumerate([3, 4, 23, 3]) for unit in range(block_units)] +
['avgpool'])


def get_bibtex(model_identifier):
"""
A method returning the bibtex reference of the requested model as a string.
"""
return """@inproceedings{mahajan2018exploring,
title={Exploring the limits of weakly supervised pretraining},
author={Mahajan, Dhruv and Girshick, Ross and Ramanathan, Vignesh and He, Kaiming and Paluri, Manohar and Li, Yixuan and Bharambe, Ashwin and Van Der Maaten, Laurens},
booktitle={Proceedings of the European conference on computer vision (ECCV)},
pages={181--196},
year={2018}
}"""


if __name__ == '__main__':
check_models.check_base_models(__name__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torchvision
torch
ssl
functools
8 changes: 8 additions & 0 deletions brainscore_vision/models/resnext101_32x8d_wsl/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
import brainscore_vision


@pytest.mark.travis_slow
def test_has_identifier():
model = brainscore_vision.load_model('resnext101_32x8d_wsl')
assert model.identifier == 'resnext101_32x8d_wsl'

0 comments on commit 4043a0d

Please sign in to comment.