Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/brain-score/vision
Browse files Browse the repository at this point in the history
  • Loading branch information
YingtianDt committed May 14, 2024
2 parents 85f5217 + b04b329 commit faf7fdd
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
17 changes: 11 additions & 6 deletions brainscore_vision/benchmark_helpers/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import shutil
from typing import Union

import numpy as np
from PIL import Image
Expand All @@ -19,7 +20,9 @@
_logger = logging.getLogger(__name__)


def place_on_screen(stimulus_set: StimulusSet, target_visual_degrees: int, source_visual_degrees: int = None):
def place_on_screen(stimulus_set: StimulusSet,
target_visual_degrees: Union[int, float],
source_visual_degrees: Union[int, float, None] = None):
_logger.debug(f"Converting {stimulus_set.identifier} to {target_visual_degrees} degrees")

assert source_visual_degrees or 'degrees' in stimulus_set, \
Expand All @@ -43,14 +46,16 @@ def _determine_visual_degrees(visual_degrees, stimulus_set):

@store(identifier_ignore=['stimulus_set'])
def _place_on_screen(stimuli_identifier: str, stimulus_set: StimulusSet,
target_visual_degrees: int, source_visual_degrees: int = None):
converted_stimuli_id = f"{stimuli_identifier}--target{target_visual_degrees}--source{source_visual_degrees}"
target_visual_degrees: Union[int, float], source_visual_degrees: Union[int, float, None] = None):
source_degrees_formatted = f"{source_visual_degrees}" if source_visual_degrees is None \
else f"{source_visual_degrees:.2f}" # make sure we do not try to print a None with 2 decimal places
converted_stimuli_id = f"{stimuli_identifier}--target{target_visual_degrees:.2f}--source{source_degrees_formatted}"
source_visual_degrees = _determine_visual_degrees(source_visual_degrees, stimulus_set)

target_dir = root_path / converted_stimuli_id
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
target_dir.mkdir(parents=True, exist_ok=False)
target_dir.mkdir(parents=True, exist_ok=False)
image_converter = ImageConverter(target_dir=target_dir)

converted_image_paths = {}
Expand All @@ -72,7 +77,7 @@ class ImageConverter:
def __init__(self, target_dir):
self._target_dir = Path(target_dir)

def convert_image(self, image_path, source_degrees, target_degrees):
def convert_image(self, image_path, source_degrees: Union[int, float], target_degrees: Union[int, float]):
if source_degrees == target_degrees:
return image_path
ratio = target_degrees / source_degrees
Expand Down Expand Up @@ -118,4 +123,4 @@ def _center_on_background(self, center_image, background_size, background_color=
return image

def _write(self, image, target_path):
image.save(target_path)
image.save(target_path)
17 changes: 17 additions & 0 deletions brainscore_vision/model_helpers/brain_transformation/behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,25 @@ class LabelToImagenetIndices:
motorbike_indices = [670, 665]
bus_indices = [779, 874, 654]

# added from the Scialom2024 benchmark:
banana_indices = [954]
beanie_indices = [439, 452, 515, 808]
binoculars_indices = [447]
boot_indices = [514]
bowl_indices = [659, 809]
cup_indices = [968]
glasses_indices = [837]
lamp_indices = [470, 607, 818, 846]
pan_indices = [567]
sewingmachine_indices = [786]
shovel_indices = [792]
# truck indices used as defined by Geirhos et al., 2021.

@classmethod
def label_to_indices(cls, label):
# for handling multi-word labels given by models or benchmarks
label = label.lower().replace(" ", "")

synset_indices = getattr(cls, f"{label}_indices")
return synset_indices

Expand Down
6 changes: 2 additions & 4 deletions brainscore_vision/models/resnet18_imagenet21kP/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from brainscore_vision import model_registry
from .model import get_model

MODEL_NAME = "resnet18_imagenet21kP"

model_registry["[email protected]"] = lambda: get_model(
MODEL_NAME
model_registry["resnet18_imagenet21kP"] = lambda: get_model(
"resnet18_imagenet21kP"
)

0 comments on commit faf7fdd

Please sign in to comment.