Skip to content

Commit

Permalink
Add resnet50_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Ethan Pellegrini committed Jul 30, 2024
1 parent a726eb1 commit 1102c40
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 0 deletions.
8 changes: 8 additions & 0 deletions brainscore_vision/models/resnet50_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers


model_registry['resnet50_v2'] = lambda: ModelCommitment(identifier='resnet50_v2',
activations_model=get_model(),
layers=get_layers())
35 changes: 35 additions & 0 deletions brainscore_vision/models/resnet50_v2/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import functools
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
import torchvision
import ssl


ssl._create_default_https_context = ssl._create_unverified_context

'''
This is a Pytorch implementation of resnet50.
The model template can be found at the following URL:
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
'''

MODEL = torchvision.models.resnet50(weights='ResNet50_Weights.IMAGENET1K_V2') # use V2 weights


def get_model():
model_identifier = "resnet50_v2"
preprocessing = functools.partial(load_preprocess_images, image_size=224)
wrapper = PytorchWrapper(identifier=model_identifier, model=MODEL, preprocessing=preprocessing)
wrapper.image_size = 224
return wrapper


def get_layers():
layer_names = []

for name, module in MODEL.named_modules():
layer_names.append(name)

return layer_names[2:]
2 changes: 2 additions & 0 deletions brainscore_vision/models/resnet50_v2/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
torchvision
8 changes: 8 additions & 0 deletions brainscore_vision/models/resnet50_v2/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
import brainscore_vision


@pytest.mark.travis_slow
def test_has_identifier():
model = brainscore_vision.load_model('resnet50_v2')
assert model.identifier == 'resnet50_v2'

0 comments on commit 1102c40

Please sign in to comment.