From d68fbad2dcf716fabbc98c715bd4721ebaaf564a Mon Sep 17 00:00:00 2001 From: Jenkins Date: Sat, 30 Nov 2024 09:15:08 -0500 Subject: [PATCH] add vgg16_less_variation_1 to models --- .../models/vgg16_less_variation_1/__init__.py | 6 ++ .../models/vgg16_less_variation_1/model.py | 93 +++++++++++++++++++ .../models/vgg16_less_variation_1/setup.py | 29 ++++++ .../models/vgg16_less_variation_1/test.py | 9 ++ 4 files changed, 137 insertions(+) create mode 100644 brainscore_vision/models/vgg16_less_variation_1/__init__.py create mode 100644 brainscore_vision/models/vgg16_less_variation_1/model.py create mode 100644 brainscore_vision/models/vgg16_less_variation_1/setup.py create mode 100644 brainscore_vision/models/vgg16_less_variation_1/test.py diff --git a/brainscore_vision/models/vgg16_less_variation_1/__init__.py b/brainscore_vision/models/vgg16_less_variation_1/__init__.py new file mode 100644 index 000000000..65ec87c9b --- /dev/null +++ b/brainscore_vision/models/vgg16_less_variation_1/__init__.py @@ -0,0 +1,6 @@ + +from brainscore_vision import model_registry +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from .model import get_model, get_layers + +model_registry[f'vgg16_less_variation_iteration=1'] = lambda: ModelCommitment(identifier=f'vgg16_less_variation_iteration=1', activations_model=get_model(f'vgg16_less_variation_iteration=1'), layers=get_layers(f'vgg16_less_variation_iteration=1')) diff --git a/brainscore_vision/models/vgg16_less_variation_1/model.py b/brainscore_vision/models/vgg16_less_variation_1/model.py new file mode 100644 index 000000000..17cf8d0cc --- /dev/null +++ b/brainscore_vision/models/vgg16_less_variation_1/model.py @@ -0,0 +1,93 @@ +from brainscore_vision.model_helpers.check_submission import check_models +import functools +import numpy as np +import torch +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from PIL import Image +from torch import nn +import pytorch_lightning as pl +import torchvision.models as models +import gdown +import glob +import os + +device = "cpu" +keyword = 'less_variation' +iteration = 1 +network = 'vgg16' +url = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt" +output = f"vgg16_less_variation_iteration=1.ckpt" +gdown.download(url, output) + +if keyword != 'imagenet_trained' and keyword != 'no_training': + lx_whole = list(f"vgg16_less_variation_iteration=1.ckpt") + if len(lx_whole) > 1: + lx_whole = [lx_whole[-1]] +elif keyword == 'imagenet_trained' or keyword == 'no_training': + print('keyword is imagenet') + lx_whole = ['x'] + +for model_ckpt in lx_whole: + print(model_ckpt) + last_module_name = None + last_module = None + layers = [] + if keyword == 'imagenet_trained' and network != 'clip': + model = torch.hub.load('pytorch/vision', network, pretrained=True) + for name, module in model.named_modules(): + last_module_name = name + last_module = module + layers.append(name) + else: + model = torch.hub.load('pytorch/vision', network, pretrained=False) + if model_ckpt != 'x': + ckpt = torch.load(model_ckpt, map_location='cpu') + if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained': + ckpt2 = {} + for keys in ckpt['state_dict']: + print(keys) + print(ckpt['state_dict'][keys].shape) + print('---') + k2 = keys.split('model.')[1] + ckpt2[k2] = ckpt['state_dict'][keys] + model.load_state_dict(ckpt2) + # Add more cases for other networks as needed + +def get_bibtex(model_identifier): + return "VGG16" + +def get_model_list(): + return [f'vgg16_less_variation_iteration=1'] + +def get_model(name): + assert name == f'vgg16_less_variation_iteration=1' + url = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt" + output = f"vgg16_less_variation_iteration=1.ckpt" + gdown.download(url, output) + + preprocessing = functools.partial(load_preprocess_images, image_size=224) + activations_model = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing) + + return activations_model + +def get_layers(name): + assert name == f'vgg16_less_variation_iteration=1.ckpt' + layers = [] + url = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt" + output = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt" + gdown.download(url, output) + for name, module in model.named_modules(): + layers.append(name) + return layers + +if __name__ == '__main__': + check_models.check_base_models(__name__) diff --git a/brainscore_vision/models/vgg16_less_variation_1/setup.py b/brainscore_vision/models/vgg16_less_variation_1/setup.py new file mode 100644 index 000000000..64f80c7d6 --- /dev/null +++ b/brainscore_vision/models/vgg16_less_variation_1/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from setuptools import setup, find_packages + +requirements = [ + "torchvision", + "torch", + "gdown", + "pytorch_lightning", + "brainscore_vision" +] + +setup( + packages=find_packages(exclude=['tests']), + include_package_data=True, + install_requires=requirements, + license="MIT license", + zip_safe=False, + keywords='brain-score template', + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 3.7', + ], + test_suite='tests', +) diff --git a/brainscore_vision/models/vgg16_less_variation_1/test.py b/brainscore_vision/models/vgg16_less_variation_1/test.py new file mode 100644 index 000000000..b10d7a7c2 --- /dev/null +++ b/brainscore_vision/models/vgg16_less_variation_1/test.py @@ -0,0 +1,9 @@ + +import pytest +import brainscore_vision + + +@pytest.mark.travis_slow +def test_has_identifier(): + model = brainscore_vision.load_model(f'vgg16_less_variation_iteration=1') + assert model.identifier == f'vgg16_less_variation_iteration=1'