Skip to content

Commit

Permalink
add regnet to models
Browse files Browse the repository at this point in the history
  • Loading branch information
AutoJenkins committed Jan 8, 2024
1 parent bcf2ab1 commit 438071b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
14 changes: 14 additions & 0 deletions brainscore_vision/models/regnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, LAYERS

BIBTEX = """@inproceedings{radosavovic2020designing,
title={Designing network design spaces},
author={Radosavovic, Ilija and Kosaraju, Raj Prateek and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr},
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
pages={10428--10436},
year={2020}
}"""

model_registry['regnet_y_400mf'] = lambda: ModelCommitment(
identifier='regnet_y_400mf', activations_model=get_model(), layers=LAYERS)
17 changes: 17 additions & 0 deletions brainscore_vision/models/regnet/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import functools

import torchvision.models

from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images

# these layer choices were not investigated in any depth, we blindly picked all high-level blocks
LAYERS = ['trunk_output.block1', 'trunk_output.block2', 'trunk_output.block3', 'trunk_output.block4']


def get_model():
model = torchvision.models.regnet_y_400mf(pretrained=True)
preprocessing = functools.partial(load_preprocess_images, image_size=224)
wrapper = PytorchWrapper(identifier='regnet_y_400mf', model=model, preprocessing=preprocessing)
wrapper.image_size = 224
return wrapper
17 changes: 17 additions & 0 deletions brainscore_vision/models/regnet/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import logging
import sys

import pytest
from pytest import approx

from brainscore_vision import score

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)


@pytest.mark.travis_slow
@pytest.mark.memory_intense
def test_score():
actual_score = score(model_identifier="regnet_y_400mf", benchmark_identifier="MajajHong2015public.IT-pls",
conda_active=True)
assert actual_score == approx(0.532, abs=0.0005)

0 comments on commit 438071b

Please sign in to comment.