-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Jenkins
committed
Nov 5, 2024
1 parent
7af75c3
commit c50c328
Showing
4 changed files
with
155 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from brainscore_vision import model_registry | ||
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment | ||
from .model import get_model, get_layers | ||
|
||
|
||
model_registry['cvt_cvt-13-224-in1k_4'] = \ | ||
lambda: ModelCommitment(identifier='cvt_cvt-13-224-in1k_4', | ||
activations_model=get_model('cvt_cvt-13-224-in1k_4'), | ||
layers=get_layers('cvt_cvt-13-224-in1k_4')) |
134 changes: 134 additions & 0 deletions
134
brainscore_vision/models/cvt_cvt_13_224_in1k_4/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from brainscore_vision.model_helpers.check_submission import check_models | ||
import functools | ||
from transformers import AutoFeatureExtractor, CvtForImageClassification | ||
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper | ||
from PIL import Image | ||
import numpy as np | ||
import torch | ||
|
||
""" | ||
Template module for a base model submission to brain-score | ||
""" | ||
|
||
|
||
def get_model(name): | ||
assert name == 'cvt_cvt-13-224-in1k_4' | ||
# https://huggingface.co/models?sort=downloads&search=cvt | ||
image_size = 224 | ||
processor = AutoFeatureExtractor.from_pretrained('microsoft/cvt-13') | ||
model = CvtForImageClassification.from_pretrained('microsoft/cvt-13') | ||
preprocessing = functools.partial(load_preprocess_images, processor=processor, image_size=image_size) | ||
wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing) | ||
wrapper.image_size = image_size | ||
|
||
return wrapper | ||
|
||
|
||
def get_layers(name): | ||
assert name == 'cvt_cvt-13-224-in1k_4' | ||
layers = [] | ||
layers += [f'cvt.encoder.stages.0.layers.{i}' for i in range(1)] | ||
layers += [f'cvt.encoder.stages.1.layers.{i}' for i in range(2)] | ||
layers += [f'cvt.encoder.stages.2.layers.{i}' for i in range(10)] | ||
layers += ['layernorm'] | ||
return layers | ||
|
||
|
||
def get_bibtex(model_identifier): | ||
""" | ||
A method returning the bibtex reference of the requested model as a string. | ||
""" | ||
return '' | ||
|
||
|
||
def load_preprocess_images(image_filepaths, image_size, processor=None, **kwargs): | ||
images = load_images(image_filepaths) | ||
# images = [<PIL.Image.Image image mode=RGB size=400x400 at 0x7F8654B2AC10>, ...] | ||
images = [image.resize((image_size, image_size)) for image in images] | ||
if processor is not None: | ||
images = [processor(images=image, return_tensors="pt", **kwargs) for image in images] | ||
if len(images[0].keys()) != 1: | ||
raise NotImplementedError(f'unknown processor for getting model {processor}') | ||
assert list(images[0].keys())[0] == 'pixel_values' | ||
images = [image['pixel_values'] for image in images] | ||
images = torch.cat(images) | ||
images = images.cpu().numpy() | ||
else: | ||
images = preprocess_images(images, image_size=image_size, **kwargs) | ||
return images | ||
|
||
|
||
def load_images(image_filepaths): | ||
return [load_image(image_filepath) for image_filepath in image_filepaths] | ||
|
||
|
||
def load_image(image_filepath): | ||
with Image.open(image_filepath) as pil_image: | ||
if 'L' not in pil_image.mode.upper() and 'A' not in pil_image.mode.upper() \ | ||
and 'P' not in pil_image.mode.upper(): # not binary and not alpha and not palletized | ||
# work around to https://github.com/python-pillow/Pillow/issues/1144, | ||
# see https://stackoverflow.com/a/30376272/2225200 | ||
return pil_image.copy() | ||
else: # make sure potential binary images are in RGB | ||
rgb_image = Image.new("RGB", pil_image.size) | ||
rgb_image.paste(pil_image) | ||
return rgb_image | ||
|
||
|
||
def preprocess_images(images, image_size, **kwargs): | ||
preprocess = torchvision_preprocess_input(image_size, **kwargs) | ||
images = [preprocess(image) for image in images] | ||
images = np.concatenate(images) | ||
return images | ||
|
||
|
||
def torchvision_preprocess_input(image_size, **kwargs): | ||
from torchvision import transforms | ||
return transforms.Compose([ | ||
transforms.Resize((image_size, image_size)), | ||
torchvision_preprocess(**kwargs), | ||
]) | ||
|
||
|
||
def torchvision_preprocess(normalize_mean=(0.485, 0.456, 0.406), normalize_std=(0.229, 0.224, 0.225)): | ||
from torchvision import transforms | ||
return transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize(mean=normalize_mean, std=normalize_std), | ||
lambda img: img.unsqueeze(0) | ||
]) | ||
|
||
|
||
def create_static_video(image, num_frames, normalize_0to1=False, channel_dim=3): | ||
''' | ||
Create a static video with the same image in all frames. | ||
Args: | ||
image (PIL.Image.Image): Input image. | ||
num_frames (int): Number of frames in the video. | ||
Returns: | ||
result (np.ndarray): np array of frames of shape (num_frames, height, width, 3). | ||
''' | ||
frames = [] | ||
for _ in range(num_frames): | ||
frame = np.array(image) | ||
if normalize_0to1: | ||
frame = frame / 255. | ||
if channel_dim == 1: | ||
frame = frame.transpose(2, 0, 1) | ||
frames.append(frame) | ||
return np.stack(frames) | ||
|
||
|
||
if __name__ == '__main__': | ||
# Use this method to ensure the correctness of the BaseModel implementations. | ||
# It executes a mock run of brain-score benchmarks. | ||
check_models.check_base_models(__name__) | ||
|
||
""" | ||
Notes on the error: | ||
- 'channel_x' key error: | ||
# 'embeddings.patch_embeddings.projection', | ||
https://github.com/search?q=repo%3Abrain-score%2Fmodel-tools%20channel_x&type=code | ||
""" |
4 changes: 4 additions & 0 deletions
4
brainscore_vision/models/cvt_cvt_13_224_in1k_4/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
numpy | ||
torch | ||
transformers==4.30.2 | ||
pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
import pytest | ||
import brainscore_vision | ||
|
||
|
||
@pytest.mark.travis_slow | ||
def test_has_identifier(): | ||
model = brainscore_vision.load_model('cvt_cvt-13-224-in1k_4') | ||
assert model.identifier == 'cvt_cvt-13-224-in1k_4' |