From be83d71a3ab6dcba00c84922f92ce2269c8d5292 Mon Sep 17 00:00:00 2001 From: Jenkins Date: Tue, 3 Dec 2024 18:11:31 -0500 Subject: [PATCH] add barlow_twins_custom to models --- .../models/barlow_twins_custom/__init__.py | 5 ++ .../models/barlow_twins_custom/model.py | 58 +++++++++++++++++++ .../barlow_twins_custom/requirements.txt | 4 ++ .../models/barlow_twins_custom/test.py | 12 ++++ 4 files changed, 79 insertions(+) create mode 100644 brainscore_vision/models/barlow_twins_custom/__init__.py create mode 100644 brainscore_vision/models/barlow_twins_custom/model.py create mode 100644 brainscore_vision/models/barlow_twins_custom/requirements.txt create mode 100644 brainscore_vision/models/barlow_twins_custom/test.py diff --git a/brainscore_vision/models/barlow_twins_custom/__init__.py b/brainscore_vision/models/barlow_twins_custom/__init__.py new file mode 100644 index 000000000..e47ab5797 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/__init__.py @@ -0,0 +1,5 @@ +from brainscore_vision import model_registry +from .model import get_model + +# Register the Barlow Twins model with custom weights +model_registry['barlow_twins_custom'] = lambda: get_model('barlow_twins_custom') diff --git a/brainscore_vision/models/barlow_twins_custom/model.py b/brainscore_vision/models/barlow_twins_custom/model.py new file mode 100644 index 000000000..17d431d39 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/model.py @@ -0,0 +1,58 @@ +import torch +from pathlib import Path +from torchvision.models import resnet18 +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images +from collections import OrderedDict +from urllib.request import urlretrieve +import functools +import os + + +# Custom model loader +def get_model(name): + assert name == 'barlow_twins_custom' + url = " https://www.dropbox.com/scl/fi/c6b940qscjb43xhgda9om/barlow_twins-custom_dataset_3-685qxt9j-ep-399.ckpt?rlkey=poq82f01jen6u3t005689ge93&st=4u6t330l&dl=1" + fh, _ = urlretrieve(url) + print(f"Downloaded weights file: {fh}, Size: {os.path.getsize(fh)} bytes") + + checkpoint = torch.load(fh, map_location="cpu") + state_dict = checkpoint['state_dict'] + + backbone_state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items() if not k.startswith("projector.")} + model = resnet18(pretrained=False) + model.load_state_dict(backbone_state_dict, strict=False) + + + preprocessing = functools.partial(load_preprocess_images, image_size=224) + + activations_model = PytorchWrapper(identifier='barlow_twins_custom', model=model, preprocessing=preprocessing) + + + return ModelCommitment( + identifier='barlow_twins_custom', + activations_model=activations_model, + layers=['layer1', 'layer2', 'layer3', 'layer4', 'avgpool'] + ) + +def get_model_list(): + return ['barlow_twins_custom'] + +# Specify layers to test +def get_layers(name): + assert name == 'barlow_twins_custom' + return ['layer1', 'layer2', 'layer3', 'layer4', 'avgpool'] + +def get_bibtex(model_identifier): + return """ + @misc{resnet18_test_consistency, + title={ArtResNet18 Barlow Twins}, + author={Claudia Noche}, + year={2024}, + } + """ + +if __name__ == '__main__': + from brainscore_vision.model_helpers.check_submission import check_models + check_models.check_base_models(__name__) diff --git a/brainscore_vision/models/barlow_twins_custom/requirements.txt b/brainscore_vision/models/barlow_twins_custom/requirements.txt new file mode 100644 index 000000000..bd006e518 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/requirements.txt @@ -0,0 +1,4 @@ +torch +torchvision +requests +pathlib diff --git a/brainscore_vision/models/barlow_twins_custom/test.py b/brainscore_vision/models/barlow_twins_custom/test.py new file mode 100644 index 000000000..d9576c644 --- /dev/null +++ b/brainscore_vision/models/barlow_twins_custom/test.py @@ -0,0 +1,12 @@ +import pytest +import brainscore_vision + +@pytest.mark.travis_slow +def test_barlow_twins_custom(): + model = brainscore_vision.load_model('barlow_twins_custom') + assert model.identifier == 'barlow_twins_custom' + + + +# AssertionError: No registrations found for resnet18_random +# ⚡ master ~/vision python -m brainscore_vision score --model_identifier='resnet50_tutorial' --benchmark_identifier='MajajHong2015public.IT-pls' \ No newline at end of file