-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
brain-score.org submission (user:413) | (public:False) (#1077)
* add yudixie_resnet18_240719_8 to models * Add yudixie_resnet18_object_class_0_240719.json to region_layer_map for model yudixie_resnet18_object_class_0_240719 --------- Co-authored-by: AutoJenkins <[email protected]> Co-authored-by: Kartik Pradeepan <[email protected]> Co-authored-by: KartikP <[email protected]>
- Loading branch information
1 parent
d543308
commit ab2e2ad
Showing
5 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
11 changes: 11 additions & 0 deletions
11
brainscore_vision/models/yudixie_resnet18_240719_8/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from brainscore_vision import model_registry | ||
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment | ||
from .model import get_model, get_layers | ||
|
||
|
||
def commit_model(identifier): | ||
return ModelCommitment(identifier=identifier, | ||
activations_model=get_model(identifier), | ||
layers=get_layers(identifier)) | ||
|
||
model_registry['yudixie_resnet18_object_class_0_240719'] = lambda: commit_model('yudixie_resnet18_object_class_0_240719') |
60 changes: 60 additions & 0 deletions
60
brainscore_vision/models/yudixie_resnet18_240719_8/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
from pathlib import Path | ||
import functools | ||
from urllib.request import urlretrieve | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from torchvision.models import resnet18 | ||
|
||
from brainscore_vision.model_helpers.check_submission import check_models | ||
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment | ||
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper, load_preprocess_images | ||
|
||
|
||
# Please load your pytorch model for usage in CPU. There won't be GPUs available for scoring your model. | ||
# If the model requires a GPU, contact the brain-score team directly. | ||
|
||
|
||
def get_model(name): | ||
pytorch_device = torch.device('cpu') | ||
|
||
weigth_url = f'https://yudi-brainscore-models.s3.amazonaws.com/{name}.pth' | ||
fh = urlretrieve(weigth_url, f'{name}_weights.pth') | ||
load_path = fh[0] | ||
|
||
pytorch_model = resnet18() | ||
pytorch_model.fc = nn.Linear(pytorch_model.fc.in_features, 674) | ||
pytorch_model = pytorch_model.to(pytorch_device) | ||
|
||
# load model from saved weights | ||
saved_state_dict = torch.load(load_path, map_location=pytorch_device) | ||
state_dict = {} | ||
for k, v in saved_state_dict.items(): | ||
if k.startswith('_orig_mod.'): | ||
# for compiled models | ||
state_dict[k[10:]] = v | ||
else: | ||
state_dict[k] = v | ||
pytorch_model.load_state_dict(state_dict, strict=True) | ||
print(f'Loaded model from {load_path}') | ||
|
||
preprocessing = functools.partial(load_preprocess_images, image_size=224) | ||
wrapper = PytorchWrapper(identifier=name, | ||
model=pytorch_model, | ||
preprocessing=preprocessing) | ||
wrapper.image_size = 224 | ||
return wrapper | ||
|
||
|
||
def get_layers(name): | ||
return ['conv1','layer1', 'layer2', 'layer3', 'layer4', 'fc'] | ||
|
||
|
||
def get_bibtex(model_identifier): | ||
return """xx""" | ||
|
||
|
||
if __name__ == '__main__': | ||
check_models.check_base_models(__name__) |
6 changes: 6 additions & 0 deletions
6
...ls/yudixie_resnet18_240719_8/region_layer_map/yudixie_resnet18_object_class_0_240719.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"V1": "layer2", | ||
"V2": "layer2", | ||
"V4": "layer2", | ||
"IT": "layer3" | ||
} |
25 changes: 25 additions & 0 deletions
25
brainscore_vision/models/yudixie_resnet18_240719_8/setup.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from setuptools import setup, find_packages | ||
|
||
requirements = [ "torchvision", | ||
"torch" | ||
] | ||
|
||
setup( | ||
packages=find_packages(exclude=['tests']), | ||
include_package_data=True, | ||
install_requires=requirements, | ||
license="MIT license", | ||
zip_safe=False, | ||
keywords='brain-score template', | ||
classifiers=[ | ||
'Development Status :: 2 - Pre-Alpha', | ||
'Intended Audience :: Developers', | ||
'License :: OSI Approved :: MIT License', | ||
'Natural Language :: English', | ||
'Programming Language :: Python :: 3.7', | ||
], | ||
test_suite='tests', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Left empty as part of 2023 models migration |