Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ep/add fixres resnext101 32x48d wsl #1103

Merged
merged 4 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

model_registry['fixres_resnext101_32x48d_wsl'] = lambda: ModelCommitment(identifier='fixres_resnext101_32x48d_wsl',
activations_model=get_model(),
layers=get_layers('fixres_resnext101_32x48d_wsl'))
57 changes: 57 additions & 0 deletions brainscore_vision/models/fixres_resnext101_32x48d_wsl/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from fixres.hubconf import load_state_dict_from_url
from fixres.transforms_v2 import get_transforms
from model_helpers.activations.pytorch import load_images
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from model_helpers.activations.pytorch import load_images
from brainscore_vision.model_helpers.activations.pytorch import load_images

import numpy as np
from importlib import import_module
import ssl


ssl._create_default_https_context = ssl._create_unverified_context


def get_model():
module = import_module('fixres.imnet_evaluate.resnext_wsl')
model_ctr = getattr(module, 'resnext101_32x48d_wsl')
model = model_ctr(pretrained=False) # the pretrained flag here corresponds to standard resnext weights
pretrained_dict = load_state_dict_from_url('https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNeXt_101_32x48d.pth',
map_location=lambda storage, loc: storage)['model']
model_dict = model.state_dict()
for k in model_dict.keys():
assert ('module.' + k) in pretrained_dict.keys()
model_dict[k] = pretrained_dict.get(('module.' + k))
model.load_state_dict(model_dict)

# preprocessing
# 320 for ResNeXt:
# https://github.com/mschrimpf/FixRes/tree/4ddcf11b29c118dfb8a48686f75f572450f67e5d#example-evaluation-procedure
input_size = 320
# https://github.com/mschrimpf/FixRes/blob/0dc15ab509b9cb9d7002ca47826dab4d66033668/fixres/imnet_evaluate/train.py#L159-L160
transformation = get_transforms(input_size=input_size, test_size=input_size,
kind='full', need=('val',),
# this is different from standard ImageNet evaluation to show the whole image
crop=False,
# no backbone parameter for ResNeXt following
# https://github.com/mschrimpf/FixRes/blob/0dc15ab509b9cb9d7002ca47826dab4d66033668/fixres/imnet_evaluate/train.py#L154-L156
backbone=None)
transform = transformation['val']

def load_preprocess_images(image_filepaths):
images = load_images(image_filepaths)
images = [transform(image) for image in images]
images = [image.unsqueeze(0) for image in images]
images = np.concatenate(images)
return images

wrapper = PytorchWrapper(identifier='resnext101_32x48d_wsl', model=model, preprocessing=load_preprocess_images,
batch_size=4) # doesn't fit into 12 GB GPU memory otherwise
wrapper.image_size = input_size
return wrapper


def get_layers(name):
return (['conv1'] +
# note that while relu is used multiple times, by default the last one will overwrite all previous ones
[f"layer{block + 1}.{unit}.relu"
for block, block_units in enumerate([3, 4, 23, 3]) for unit in range(block_units)] +
['avgpool'])
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
torchvision
torch
numpy
importlib
Fixing-the-train-test-resolution-discrepancy-scripts@ git+https://github.com/mschrimpf/FixRes.git
7 changes: 7 additions & 0 deletions brainscore_vision/models/fixres_resnext101_32x48d_wsl/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest
import brainscore_vision

@pytest.mark.travis_slow
def test_has_identifier():
model = brainscore_vision.load_model('fixres_resnext101_32x48d_wsl')
assert model.identifier == 'fixres_resnext101_32x48d_wsl'