-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add yudixie_resnet18_240719_11 to models
- Loading branch information
AutoJenkins
committed
Jul 19, 2024
1 parent
394bcee
commit bee99c4
Showing
4 changed files
with
97 additions
and
0 deletions.
There are no files selected for viewing
11 changes: 11 additions & 0 deletions
11
brainscore_vision/models/yudixie_resnet18_240719_11/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from brainscore_vision import model_registry | ||
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment | ||
from .model import get_model, get_layers | ||
|
||
|
||
def commit_model(identifier): | ||
return ModelCommitment(identifier=identifier, | ||
activations_model=get_model(identifier), | ||
layers=get_layers(identifier)) | ||
|
||
model_registry['yudixie_resnet18_random_0_240719'] = lambda: commit_model('yudixie_resnet18_random_0_240719') |
60 changes: 60 additions & 0 deletions
60
brainscore_vision/models/yudixie_resnet18_240719_11/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import os | ||
from pathlib import Path | ||
import functools | ||
from urllib.request import urlretrieve | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from torchvision.models import resnet18 | ||
|
||
from brainscore_vision.model_helpers.check_submission import check_models | ||
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment | ||
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper, load_preprocess_images | ||
|
||
|
||
# Please load your pytorch model for usage in CPU. There won't be GPUs available for scoring your model. | ||
# If the model requires a GPU, contact the brain-score team directly. | ||
|
||
|
||
def get_model(name): | ||
pytorch_device = torch.device('cpu') | ||
|
||
weigth_url = f'https://yudi-brainscore-models.s3.amazonaws.com/{name}.pth' | ||
fh = urlretrieve(weigth_url, f'{name}_weights.pth') | ||
load_path = fh[0] | ||
|
||
pytorch_model = resnet18() | ||
pytorch_model.fc = nn.Linear(pytorch_model.fc.in_features, 674) | ||
pytorch_model = pytorch_model.to(pytorch_device) | ||
|
||
# load model from saved weights | ||
saved_state_dict = torch.load(load_path, map_location=pytorch_device) | ||
state_dict = {} | ||
for k, v in saved_state_dict.items(): | ||
if k.startswith('_orig_mod.'): | ||
# for compiled models | ||
state_dict[k[10:]] = v | ||
else: | ||
state_dict[k] = v | ||
pytorch_model.load_state_dict(state_dict, strict=True) | ||
print(f'Loaded model from {load_path}') | ||
|
||
preprocessing = functools.partial(load_preprocess_images, image_size=224) | ||
wrapper = PytorchWrapper(identifier=name, | ||
model=pytorch_model, | ||
preprocessing=preprocessing) | ||
wrapper.image_size = 224 | ||
return wrapper | ||
|
||
|
||
def get_layers(name): | ||
return ['conv1','layer1', 'layer2', 'layer3', 'layer4', 'fc'] | ||
|
||
|
||
def get_bibtex(model_identifier): | ||
return """xx""" | ||
|
||
|
||
if __name__ == '__main__': | ||
check_models.check_base_models(__name__) |
25 changes: 25 additions & 0 deletions
25
brainscore_vision/models/yudixie_resnet18_240719_11/setup.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
from setuptools import setup, find_packages | ||
|
||
requirements = [ "torchvision", | ||
"torch" | ||
] | ||
|
||
setup( | ||
packages=find_packages(exclude=['tests']), | ||
include_package_data=True, | ||
install_requires=requirements, | ||
license="MIT license", | ||
zip_safe=False, | ||
keywords='brain-score template', | ||
classifiers=[ | ||
'Development Status :: 2 - Pre-Alpha', | ||
'Intended Audience :: Developers', | ||
'License :: OSI Approved :: MIT License', | ||
'Natural Language :: English', | ||
'Programming Language :: Python :: 3.7', | ||
], | ||
test_suite='tests', | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Left empty as part of 2023 models migration |