From 394bcee06ad34b57d86652ba2879b09adc71a660 Mon Sep 17 00:00:00 2001 From: Kartik Pradeepan Date: Thu, 18 Jul 2024 10:59:40 -0400 Subject: [PATCH] Add mobilevit_small - take 2 (#1051) * Add mobilevit_small - take 2 * Update brainscore_vision/models/mobilevit_small/model.py Co-authored-by: Martin Schrimpf --------- Co-authored-by: Martin Schrimpf Co-authored-by: Michael Ferguson --- .../models/mobilevit_small/__init__.py | 7 +++ .../models/mobilevit_small/model.py | 49 +++++++++++++++++++ .../models/mobilevit_small/requirements.txt | 3 ++ .../models/mobilevit_small/test.py | 8 +++ 4 files changed, 67 insertions(+) create mode 100644 brainscore_vision/models/mobilevit_small/__init__.py create mode 100644 brainscore_vision/models/mobilevit_small/model.py create mode 100644 brainscore_vision/models/mobilevit_small/requirements.txt create mode 100644 brainscore_vision/models/mobilevit_small/test.py diff --git a/brainscore_vision/models/mobilevit_small/__init__.py b/brainscore_vision/models/mobilevit_small/__init__.py new file mode 100644 index 000000000..f0a457c5c --- /dev/null +++ b/brainscore_vision/models/mobilevit_small/__init__.py @@ -0,0 +1,7 @@ +from brainscore_vision import model_registry +from brainscore_vision.model_helpers.brain_transformation import ModelCommitment +from .model import get_model, get_layers + +model_registry['mobilevit_small'] = lambda: ModelCommitment(identifier='mobilevit_small', + activations_model=get_model('mobilevit_small'), + layers=get_layers('mobilevit_small')) \ No newline at end of file diff --git a/brainscore_vision/models/mobilevit_small/model.py b/brainscore_vision/models/mobilevit_small/model.py new file mode 100644 index 000000000..f5c9324f2 --- /dev/null +++ b/brainscore_vision/models/mobilevit_small/model.py @@ -0,0 +1,49 @@ +from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper +from brainscore_vision.model_helpers.activations.pytorch import load_images, load_preprocess_images +from brainscore_vision.model_helpers.check_submission import check_models +import ssl +from transformers import MobileViTForImageClassification +import functools + +ssl._create_default_https_context = ssl._create_unverified_context + +''' +Can be found on huggingface: https://huggingface.co/apple/mobilevit-small +''' + +def get_model(name): + assert name == 'mobilevit_small' + model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-small") + preprocessing = functools.partial(load_preprocess_images, image_size=256) + wrapper = PytorchWrapper(identifier='mobilevit_small', model=model, + preprocessing=preprocessing, + batch_size=4) + wrapper.image_size = 256 + return wrapper + + +def get_layers(name): + assert name == 'mobilevit_small' + layer_names = ["mobilevit.encoder.layer.0", "mobilevit.encoder.layer.1", "mobilevit.encoder.layer.2", + "mobilevit.encoder.layer.2.fusion.activation", + "mobilevit.encoder.layer.3", "mobilevit.encoder.layer.3.fusion.activation", + "mobilevit.encoder.layer.4", "mobilevit.encoder.layer.4.fusion.activation"] + + return layer_names + + +def get_bibtex(model_identifier): + """ + A method returning the bibtex reference of the requested model as a string. + """ + return """@inproceedings{vision-transformer, + title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer}, + author = {Sachin Mehta and Mohammad Rastegari}, + year = {2022}, + URL = {https://arxiv.org/abs/2110.02178} + } + """ + + +if __name__ == '__main__': + check_models.check_base_models(__name__) \ No newline at end of file diff --git a/brainscore_vision/models/mobilevit_small/requirements.txt b/brainscore_vision/models/mobilevit_small/requirements.txt new file mode 100644 index 000000000..1d28db41f --- /dev/null +++ b/brainscore_vision/models/mobilevit_small/requirements.txt @@ -0,0 +1,3 @@ +torchvision +torch +transformers \ No newline at end of file diff --git a/brainscore_vision/models/mobilevit_small/test.py b/brainscore_vision/models/mobilevit_small/test.py new file mode 100644 index 000000000..b41ed0480 --- /dev/null +++ b/brainscore_vision/models/mobilevit_small/test.py @@ -0,0 +1,8 @@ +import pytest +import brainscore_vision + + +@pytest.mark.travis_slow +def test_has_identifier(): + model = brainscore_vision.load_model('mobilevit_small') + assert model.identifier == 'mobilevit_small' \ No newline at end of file