Skip to content

Commit

Permalink
Add mobilevit_small - take 2 (#1051)
Browse files Browse the repository at this point in the history
* Add mobilevit_small - take 2

* Update brainscore_vision/models/mobilevit_small/model.py

Co-authored-by: Martin Schrimpf <[email protected]>

---------

Co-authored-by: Martin Schrimpf <[email protected]>
Co-authored-by: Michael Ferguson <[email protected]>
  • Loading branch information
3 people authored Jul 18, 2024
1 parent 149401c commit 394bcee
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
7 changes: 7 additions & 0 deletions brainscore_vision/models/mobilevit_small/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from .model import get_model, get_layers

model_registry['mobilevit_small'] = lambda: ModelCommitment(identifier='mobilevit_small',
activations_model=get_model('mobilevit_small'),
layers=get_layers('mobilevit_small'))
49 changes: 49 additions & 0 deletions brainscore_vision/models/mobilevit_small/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.activations.pytorch import load_images, load_preprocess_images
from brainscore_vision.model_helpers.check_submission import check_models
import ssl
from transformers import MobileViTForImageClassification
import functools

ssl._create_default_https_context = ssl._create_unverified_context

'''
Can be found on huggingface: https://huggingface.co/apple/mobilevit-small
'''

def get_model(name):
assert name == 'mobilevit_small'
model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-small")
preprocessing = functools.partial(load_preprocess_images, image_size=256)
wrapper = PytorchWrapper(identifier='mobilevit_small', model=model,
preprocessing=preprocessing,
batch_size=4)
wrapper.image_size = 256
return wrapper


def get_layers(name):
assert name == 'mobilevit_small'
layer_names = ["mobilevit.encoder.layer.0", "mobilevit.encoder.layer.1", "mobilevit.encoder.layer.2",
"mobilevit.encoder.layer.2.fusion.activation",
"mobilevit.encoder.layer.3", "mobilevit.encoder.layer.3.fusion.activation",
"mobilevit.encoder.layer.4", "mobilevit.encoder.layer.4.fusion.activation"]

return layer_names


def get_bibtex(model_identifier):
"""
A method returning the bibtex reference of the requested model as a string.
"""
return """@inproceedings{vision-transformer,
title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
author = {Sachin Mehta and Mohammad Rastegari},
year = {2022},
URL = {https://arxiv.org/abs/2110.02178}
}
"""


if __name__ == '__main__':
check_models.check_base_models(__name__)
3 changes: 3 additions & 0 deletions brainscore_vision/models/mobilevit_small/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
torchvision
torch
transformers
8 changes: 8 additions & 0 deletions brainscore_vision/models/mobilevit_small/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
import brainscore_vision


@pytest.mark.travis_slow
def test_has_identifier():
model = brainscore_vision.load_model('mobilevit_small')
assert model.identifier == 'mobilevit_small'

0 comments on commit 394bcee

Please sign in to comment.