From ac82e1be9129275e6076e9023dfb4b63fe4a1d7e Mon Sep 17 00:00:00 2001 From: Jenkins Date: Sat, 30 Nov 2024 19:37:40 -0500 Subject: [PATCH] add barlow_twins_custom to models --- .../models/barlow_twins_custom/__init__.py | 5 +++ .../models/barlow_twins_custom/model.py | 44 +++++++++++++++++++ .../barlow_twins_custom/requirements.txt | 5 +++ .../models/barlow_twins_custom/test.py | 12 +++++ 4 files changed, 66 insertions(+) create mode 100644 brainscore_vision/models/barlow_twins_custom/__init__.py create mode 100644 brainscore_vision/models/barlow_twins_custom/model.py create mode 100644 brainscore_vision/models/barlow_twins_custom/requirements.txt create mode 100644 brainscore_vision/models/barlow_twins_custom/test.py diff --git a/brainscore_vision/models/barlow_twins_custom/__init__.py b/brainscore_vision/models/barlow_twins_custom/__init__.py new file mode 100644 index 000000000..e47ab5797 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/__init__.py @@ -0,0 +1,5 @@ +from brainscore_vision import model_registry +from .model import get_model + +# Register the Barlow Twins model with custom weights +model_registry['barlow_twins_custom'] = lambda: get_model('barlow_twins_custom') diff --git a/brainscore_vision/models/barlow_twins_custom/model.py b/brainscore_vision/models/barlow_twins_custom/model.py new file mode 100644 index 000000000..496376228 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/model.py @@ -0,0 +1,44 @@ +import torch +from pathlib import Path +from torchvision.models import resnet18 +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images +from brainscore_vision.model_helpers.check_submission import check_models +from collections import OrderedDict +from urllib.request import urlretrieve +import functools + +# Custom model loader +def get_model(name): + assert name == 'barlow_twins_custom' + + url = " https://www.dropbox.com/scl/fi/db5yp3hols5sucujanimx/barlow_twins_weights.pth?rlkey=nalge9jixfeqorazwu4xqdbd8&st=yqf3qkaj&dl=1" + fh = urlretrieve(url) + state_dict = torch.load(fh[0], map_location=torch.device("cpu")) + model = resnet18(pretrained=False) + model.load_state_dict(state_dict, strict=False) + print(model) + preprocessing = functools.partial(load_preprocess_images, image_size=224) + + activations_model = PytorchWrapper(identifier='barlow_twins_custom', model=model, preprocessing=preprocessing) + + + return ModelCommitment( + identifier='barlow_twins_custom', + activations_model=activations_model, + layers=['layer1', 'layer2', 'layer3', 'layer4', 'avgpool'] + ) + +def get_model_list(): + return ['barlow_twins_custom'] + +# Specify layers to test +def get_layers(name): + assert name == 'barlow_twins_custom' + return ['layer1', 'layer2', 'layer3', 'layer4', 'avgpool'] + + +if __name__ == '__main__': + + check_models.check_base_models(__name__) diff --git a/brainscore_vision/models/barlow_twins_custom/requirements.txt b/brainscore_vision/models/barlow_twins_custom/requirements.txt new file mode 100644 index 000000000..4feb6a389 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/requirements.txt @@ -0,0 +1,5 @@ +torch +torchvision +requests +pathlib +urlretrieve \ No newline at end of file diff --git a/brainscore_vision/models/barlow_twins_custom/test.py b/brainscore_vision/models/barlow_twins_custom/test.py new file mode 100644 index 000000000..d9576c644 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/test.py @@ -0,0 +1,12 @@ +import pytest +import brainscore_vision + +@pytest.mark.travis_slow +def test_barlow_twins_custom(): + model = brainscore_vision.load_model('barlow_twins_custom') + assert model.identifier == 'barlow_twins_custom' + + + +# AssertionError: No registrations found for resnet18_random +# ⚡ master ~/vision python -m brainscore_vision score --model_identifier='resnet50_tutorial' --benchmark_identifier='MajajHong2015public.IT-pls' \ No newline at end of file