Skip to content

Commit

Permalink
add vgg16_less_variation_1 to models
Browse files Browse the repository at this point in the history
  • Loading branch information
Jenkins committed Nov 29, 2024
1 parent 100e698 commit aa3a7e7
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 0 deletions.
4 changes: 4 additions & 0 deletions brainscore_vision/models/vgg16_less_variation_1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
model_registry['vgg16_less_variation_iteration=1'] = lambda: ModelCommitment(identifier='vgg16_less_variation_iteration=1')
97 changes: 97 additions & 0 deletions brainscore_vision/models/vgg16_less_variation_1/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from brainscore_vision.model_helpers.check_submission import check_models
import functools
import numpy as np
import torch
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from PIL import Image
import open_clip
from torch import nn
import pytorch_lightning as pl
import torchvision.models as models
import gdown
import glob
import os

device = "cpu"
keyword = 'less_variation'
iteration = 1
network = 'vgg16'
url = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt"
output = f"vgg16_less_variation_iteration=1.ckpt"
gdown.download(url, output)

if keyword != 'imagenet_trained' and keyword != 'no_training':
lx_whole = list(glob.glob(f"vgg16_less_variation_iteration=1.ckpt"))
print(lx_whole)
if len(lx_whole) == 0:
continue
if len(lx_whole) > 1:
lx_whole = [lx_whole[-1]]
elif keyword == 'imagenet_trained' or keyword == 'no_training':
print('keyword is imagenet')
lx_whole = ['x']

for model_ckpt in lx_whole:
print(model_ckpt)
last_module_name = None
last_module = None
layers = []
if keyword == 'imagenet_trained' and network != 'clip':
model = torch.hub.load('pytorch/vision', network, pretrained=True)
for name, module in model.named_modules():
last_module_name = name
last_module = module
layers.append(name)
else:
model = torch.hub.load('pytorch/vision', network, pretrained=False)
if model_ckpt != 'x':
ckpt = torch.load(model_ckpt, map_location='cpu')
if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
# Add more cases for other networks as needed

def get_bibtex(model_identifier):
return "VGG16"

def get_model_list():
return [f'vgg16_less_variation_iteration=1']

def get_model(name):
assert name == f'vgg16_less_variation_iteration=1'
url = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt"
output = f"vgg16_less_variation_iteration=1.ckpt"
gdown.download(url, output)

preprocessing = functools.partial(load_preprocess_images, image_size=224)
activations_model = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)

return activations_model

def get_layers(name):
assert name == f'vgg16_less_variation_iteration=1.ckpt'
layers = []
url = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt"
output = f"https://eggerbernhard.ch/shreya/vgg16_less_variation_iteration=1.ckpt"
gdown.download(url, output)
for name, module in model.named_modules():
layers.append(name)
return layers

if __name__ == '__main__':
check_models.check_base_models(__name__)
29 changes: 29 additions & 0 deletions brainscore_vision/models/vgg16_less_variation_1/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from setuptools import setup, find_packages

requirements = [
"torchvision",
"torch",
"gdown",
"pytorch_lightning",
"open_clip",
]

setup(
packages=find_packages(exclude=['tests']),
include_package_data=True,
install_requires=requirements,
license="MIT license",
zip_safe=False,
keywords='brain-score template',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Programming Language :: Python :: 3.7',
],
test_suite='tests',
)
1 change: 1 addition & 0 deletions brainscore_vision/models/vgg16_less_variation_1/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
import pytest

0 comments on commit aa3a7e7

Please sign in to comment.