Skip to content

Commit

Permalink
add skynet_srouce_code to models
Browse files Browse the repository at this point in the history
  • Loading branch information
Jenkins committed Sep 23, 2024
1 parent d02f31c commit 16a2ce9
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 0 deletions.
5 changes: 5 additions & 0 deletions brainscore_vision/models/skynet_srouce_code/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

model_registry['skynet_source_code'] = lambda: ModelCommitment(identifier='skynet_source_code', activations_model=get_model('skynet_source_code'), layers=get_layers('skynet_source_code'))
62 changes: 62 additions & 0 deletions brainscore_vision/models/skynet_srouce_code/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from brainscore_vision.model_helpers.check_submission import check_models
import functools
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images
import torch
import numpy as np
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment


# This is an example implementation for submitting custom model named my_custom_model

# Attention: It is important, that the wrapper identifier is unique per model!
# The results will otherwise be the same due to brain-scores internal result caching mechanism.
# Please load your pytorch model for usage in CPU. There won't be GPUs available for scoring your model.
# If the model requires a GPU, contact the brain-score team directly.


class MyCustomModel(torch.nn.Module):
def __init__(self):
super(MyCustomModel, self).__init__()
self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3)
self.relu1 = torch.nn.ReLU()
linear_input_size = np.power((224 - 3 + 2 * 0) / 1 + 1, 2) * 2
self.linear = torch.nn.Linear(int(linear_input_size), 1000)
self.relu2 = torch.nn.ReLU() # can't get named ReLU output otherwise

def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
x = self.relu2(x)
return x


def get_model_list():
return ['skynet_source_code']


def get_model(name):
assert name == 'skynet_source_code'
preprocessing = functools.partial(load_preprocess_images, image_size=224)
activations_model = PytorchWrapper(identifier='skynet_source_code', model=MyCustomModel(), preprocessing=preprocessing)
model = ModelCommitment(identifier='skynet_source_code', activations_model=activations_model,
# specify layers to consider
layers=['conv1', 'relu1', 'relu2'])
wrapper = PytorchWrapper(identifier='skynet_source_code', model=model, preprocessing=preprocessing)
wrapper.image_size = 224
return wrapper


def get_layers(name):
assert name == 'skynet_source_code'
return ['conv1', 'relu1', 'relu2']


def get_bibtex(model_identifier):
return """xx"""


if __name__ == '__main__':
check_models.check_base_models(__name__)
25 changes: 25 additions & 0 deletions brainscore_vision/models/skynet_srouce_code/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from setuptools import setup, find_packages

requirements = [ "torchvision",
"torch"
]

setup(
packages=find_packages(exclude=['tests']),
include_package_data=True,
install_requires=requirements,
license="MIT license",
zip_safe=False,
keywords='brain-score template',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Programming Language :: Python :: 3.7',
],
test_suite='tests',
)
1 change: 1 addition & 0 deletions brainscore_vision/models/skynet_srouce_code/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Left empty as part of 2023 models migration

0 comments on commit 16a2ce9

Please sign in to comment.