Skip to content

Commit

Permalink
add alexnet_wo_shadows_5 to models
Browse files Browse the repository at this point in the history
  • Loading branch information
Jenkins committed Dec 1, 2024
1 parent 6c208b7 commit 900d9c3
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 0 deletions.
6 changes: 6 additions & 0 deletions brainscore_vision/models/alexnet_wo_shadows_5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

model_registry['alexnet_wo_shadows_iteration=5'] = lambda: ModelCommitment(identifier='alexnet_wo_shadows_iteration=5', activations_model=get_model('alexnet_wo_shadows_iteration=5'), layers=get_layers('alexnet_wo_shadows_iteration=5'))
200 changes: 200 additions & 0 deletions brainscore_vision/models/alexnet_wo_shadows_5/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@

from brainscore_vision.model_helpers.check_submission import check_models
import functools
import numpy as np
import torch
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from PIL import Image
from torch import nn
import pytorch_lightning as pl
import torchvision.models as models
import gdown
import glob
import os
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images

def get_bibtex(model_identifier):
return 'VGG16'

def get_model_list():
return ['alexnet_wo_shadows_iteration=5']

def get_model(name):
keyword = 'wo_shadows'
iteration = 5
network = 'alexnet'
url = 'https://eggerbernhard.ch/shreya/latest_alexnet/wo_shadows_5.ckpt'
output = 'alexnet_wo_shadows_iteration=5.ckpt'
gdown.download(url, output)


if keyword != 'imagenet_trained' and keyword != 'no_training':
lx_whole = [f"alexnet_wo_shadows_iteration=5.ckpt"]
if len(lx_whole) > 1:
lx_whole = [lx_whole[-1]]
elif keyword == 'imagenet_trained' or keyword == 'no_training':
print('keyword is imagenet')
lx_whole = ['x']

for model_ckpt in lx_whole:
print(model_ckpt)
last_module_name = None
last_module = None
layers = []
if keyword == 'imagenet_trained' and network != 'clip':
model = torch.hub.load('pytorch/vision', network, pretrained=True)
for name, module in model.named_modules():
last_module_name = name
last_module = module
layers.append(name)
else:
model = torch.hub.load('pytorch/vision', network, pretrained=False)
if model_ckpt != 'x':
ckpt = torch.load(model_ckpt, map_location='cpu')
if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
# Add more cases for other networks as needed
assert name == 'alexnet_wo_shadows_iteration=5'
url = 'https://eggerbernhard.ch/shreya/latest_alexnet/wo_shadows_5.ckpt'
output = 'alexnet_wo_shadows_iteration=5.ckpt'
gdown.download(url, output)
layers = []
for name, module in model._modules.items():
print(name, "->", module)
layers.append(name)

preprocessing = functools.partial(load_preprocess_images, image_size=224)
activations_model = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)

return activations_model

def get_layers(name):
keyword = 'wo_shadows'
iteration = 5
network = 'alexnet'
url = 'https://eggerbernhard.ch/shreya/latest_alexnet/wo_shadows_5.ckpt'
output = 'alexnet_wo_shadows_iteration=5.ckpt'
gdown.download(url, output)


if keyword != 'imagenet_trained' and keyword != 'no_training':
lx_whole = [f"alexnet_wo_shadows_iteration=5.ckpt"]
if len(lx_whole) > 1:
lx_whole = [lx_whole[-1]]
elif keyword == 'imagenet_trained' or keyword == 'no_training':
print('keyword is imagenet')
lx_whole = ['x']


for model_ckpt in lx_whole:
print(model_ckpt)
last_module_name = None
last_module = None
if keyword == 'imagenet_trained' and network != 'clip':
model = torch.hub.load('pytorch/vision', network, pretrained=True)
for name, module in model.named_modules():
last_module_name = name
last_module = module
layers.append(name)
else:
model = torch.hub.load('pytorch/vision', network, pretrained=False)
if model_ckpt != 'x':
ckpt = torch.load(model_ckpt, map_location='cpu')
if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
# Add more cases for other networks as needed
layers = []
for name, module in model._modules.items():
print(name, "->", module)
layers.append(name)
return layers

if __name__ == '__main__':
device = "cpu"
global model
global keyword
global network
global iteration
keyword = 'wo_shadows'
iteration = 5
network = 'alexnet'
url = 'https://eggerbernhard.ch/shreya/latest_alexnet/wo_shadows_5.ckpt'
output = 'alexnet_wo_shadows_iteration=5.ckpt'
gdown.download(url, output)


if keyword != 'imagenet_trained' and keyword != 'no_training':
lx_whole = [f"alexnet_wo_shadows_iteration=5.ckpt"]
if len(lx_whole) > 1:
lx_whole = [lx_whole[-1]]
elif keyword == 'imagenet_trained' or keyword == 'no_training':
print('keyword is imagenet')
lx_whole = ['x']

for model_ckpt in lx_whole:
print(model_ckpt)
last_module_name = None
last_module = None
layers = []
if keyword == 'imagenet_trained' and network != 'clip':
model = torch.hub.load('pytorch/vision', network, pretrained=True)
for name, module in model.named_modules():
last_module_name = name
last_module = module
layers.append(name)
else:
model = torch.hub.load('pytorch/vision', network, pretrained=False)
if model_ckpt != 'x':
ckpt = torch.load(model_ckpt, map_location='cpu')
if model_ckpt != 'x' and network == 'alexnet' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
if model_ckpt != 'x' and network == 'vgg16' and keyword != 'imagenet_trained':
ckpt2 = {}
for keys in ckpt['state_dict']:
print(keys)
print(ckpt['state_dict'][keys].shape)
print('---')
k2 = keys.split('model.')[1]
ckpt2[k2] = ckpt['state_dict'][keys]
model.load_state_dict(ckpt2)
# Add more cases for other networks as needed
check_models.check_base_models(__name__)
29 changes: 29 additions & 0 deletions brainscore_vision/models/alexnet_wo_shadows_5/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from setuptools import setup, find_packages

requirements = [
"torchvision",
"torch",
"gdown",
"pytorch_lightning",
"brainscore_vision"
]

setup(
packages=find_packages(exclude=['tests']),
include_package_data=True,
install_requires=requirements,
license="MIT license",
zip_safe=False,
keywords='brain-score template',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Programming Language :: Python :: 3.7',
],
test_suite='tests',
)
3 changes: 3 additions & 0 deletions brainscore_vision/models/alexnet_wo_shadows_5/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

import pytest

0 comments on commit 900d9c3

Please sign in to comment.