diff --git a/brainscore_vision/models/resnet50_wo_shadows_4/__init__.py b/brainscore_vision/models/resnet50_wo_shadows_4/__init__.py new file mode 100644 index 000000000..99aecf14a --- /dev/null +++ b/brainscore_vision/models/resnet50_wo_shadows_4/__init__.py @@ -0,0 +1,6 @@ + +from brainscore_vision import model_registry +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from .model import get_model, get_layers + +model_registry['resnet50_wo_shadows_iteration=4'] = lambda: ModelCommitment(identifier='resnet50_wo_shadows_iteration=4', activations_model=get_model('resnet50_wo_shadows_iteration=4'), layers=get_layers('resnet50_wo_shadows_iteration=4')) diff --git a/brainscore_vision/models/resnet50_wo_shadows_4/model.py b/brainscore_vision/models/resnet50_wo_shadows_4/model.py new file mode 100644 index 000000000..216547b5f --- /dev/null +++ b/brainscore_vision/models/resnet50_wo_shadows_4/model.py @@ -0,0 +1,200 @@ + +from brainscore_vision.model_helpers.check_submission import check_models +import functools +import numpy as np +import torch +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from PIL import Image +from torch import nn +import pytorch_lightning as pl +import torchvision.models as models +import gdown +import glob +import os +from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images + +def get_bibtex(model_identifier): + return 'VGG16' + +def get_model_list(): + return ['resnet50_wo_shadows_iteration=4'] + +def get_model(name): + keyword = 'wo_shadows' + iteration = 4 + network = 'resnet50' + url = 'https://eggerbernhard.ch/shreya/latest_resnet50/wo_shadows_4.ckpt' + output = 'resnet50_wo_shadows_iteration=4.ckpt' + gdown.download(url, output) + + + if keyword != 'imagenet_trained' and keyword != 'no_training': + lx_whole = [f"resnet50_wo_shadows_iteration=4.ckpt"] + if len(lx_whole) > 1: + lx_whole = [lx_whole[-1]] + elif keyword == 'imagenet_trained' or keyword == 'no_training': + print('keyword is imagenet') + lx_whole = ['x'] + + for model_ckpt in lx_whole: + print(model_ckpt) + last_module_name = None + last_module = None + layers = [] + if keyword == 'imagenet_trained' and network != 'clip': + model = torch.hub.load('pytorch/vision', network, pretrained=True) + for name, module in model.named_modules(): + last_module_name = name + last_module = module + layers.append(name) + else: + model = torch.hub.load('pytorch/vision', network, pretrained=False) + if model_ckpt != 'x': + ckpt = torch.load(model_ckpt, map_location='cpu') + if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + # Add more cases for other networks as needed + assert name == 'resnet50_wo_shadows_iteration=4' + url = 'https://eggerbernhard.ch/shreya/latest_resnet50/wo_shadows_4.ckpt' + output = 'resnet50_wo_shadows_iteration=4.ckpt' + gdown.download(url, output) + layers = [] + for name, module in model._modules.items(): + print(name, "->", module) + layers.append(name) + + preprocessing = functools.partial(load_preprocess_images, image_size=224) + activations_model = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing) + + return activations_model + +def get_layers(name): + keyword = 'wo_shadows' + iteration = 4 + network = 'resnet50' + url = 'https://eggerbernhard.ch/shreya/latest_resnet50/wo_shadows_4.ckpt' + output = 'resnet50_wo_shadows_iteration=4.ckpt' + gdown.download(url, output) + + + if keyword != 'imagenet_trained' and keyword != 'no_training': + lx_whole = [f"resnet50_wo_shadows_iteration=4.ckpt"] + if len(lx_whole) > 1: + lx_whole = [lx_whole[-1]] + elif keyword == 'imagenet_trained' or keyword == 'no_training': + print('keyword is imagenet') + lx_whole = ['x'] + + + for model_ckpt in lx_whole: + print(model_ckpt) + last_module_name = None + last_module = None + if keyword == 'imagenet_trained' and network != 'clip': + model = torch.hub.load('pytorch/vision', network, pretrained=True) + for name, module in model.named_modules(): + last_module_name = name + last_module = module + layers.append(name) + else: + model = torch.hub.load('pytorch/vision', network, pretrained=False) + if model_ckpt != 'x': + ckpt = torch.load(model_ckpt, map_location='cpu') + if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + # Add more cases for other networks as needed + layers = [] + for name, module in model._modules.items(): + print(name, "->", module) + layers.append(name) + return layers + +if __name__ == '__main__': + device = "cpu" + global model + global keyword + global network + global iteration + keyword = 'wo_shadows' + iteration = 4 + network = 'resnet50' + url = 'https://eggerbernhard.ch/shreya/latest_resnet50/wo_shadows_4.ckpt' + output = 'resnet50_wo_shadows_iteration=4.ckpt' + gdown.download(url, output) + + + if keyword != 'imagenet_trained' and keyword != 'no_training': + lx_whole = [f"resnet50_wo_shadows_iteration=4.ckpt"] + if len(lx_whole) > 1: + lx_whole = [lx_whole[-1]] + elif keyword == 'imagenet_trained' or keyword == 'no_training': + print('keyword is imagenet') + lx_whole = ['x'] + + for model_ckpt in lx_whole: + print(model_ckpt) + last_module_name = None + last_module = None + layers = [] + if keyword == 'imagenet_trained' and network != 'clip': + model = torch.hub.load('pytorch/vision', network, pretrained=True) + for name, module in model.named_modules(): + last_module_name = name + last_module = module + layers.append(name) + else: + model = torch.hub.load('pytorch/vision', network, pretrained=False) + if model_ckpt != 'x': + ckpt = torch.load(model_ckpt, map_location='cpu') + if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + # Add more cases for other networks as needed + check_models.check_base_models(__name__) diff --git a/brainscore_vision/models/resnet50_wo_shadows_4/setup.py b/brainscore_vision/models/resnet50_wo_shadows_4/setup.py new file mode 100644 index 000000000..64f80c7d6 --- /dev/null +++ b/brainscore_vision/models/resnet50_wo_shadows_4/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from setuptools import setup, find_packages + +requirements = [ + "torchvision", + "torch", + "gdown", + "pytorch_lightning", + "brainscore_vision" +] + +setup( + packages=find_packages(exclude=['tests']), + include_package_data=True, + install_requires=requirements, + license="MIT license", + zip_safe=False, + keywords='brain-score template', + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 3.7', + ], + test_suite='tests', +) diff --git a/brainscore_vision/models/resnet50_wo_shadows_4/test.py b/brainscore_vision/models/resnet50_wo_shadows_4/test.py new file mode 100644 index 000000000..d03a9a5bd --- /dev/null +++ b/brainscore_vision/models/resnet50_wo_shadows_4/test.py @@ -0,0 +1,3 @@ + +import pytest +