diff --git a/brainscore_vision/benchmark_helpers/screen.py b/brainscore_vision/benchmark_helpers/screen.py index 0582706d4..8cd1249ed 100644 --- a/brainscore_vision/benchmark_helpers/screen.py +++ b/brainscore_vision/benchmark_helpers/screen.py @@ -5,6 +5,7 @@ import logging import os import shutil +from typing import Union import numpy as np from PIL import Image @@ -19,7 +20,9 @@ _logger = logging.getLogger(__name__) -def place_on_screen(stimulus_set: StimulusSet, target_visual_degrees: int, source_visual_degrees: int = None): +def place_on_screen(stimulus_set: StimulusSet, + target_visual_degrees: Union[int, float], + source_visual_degrees: Union[int, float, None] = None): _logger.debug(f"Converting {stimulus_set.identifier} to {target_visual_degrees} degrees") assert source_visual_degrees or 'degrees' in stimulus_set, \ @@ -43,14 +46,16 @@ def _determine_visual_degrees(visual_degrees, stimulus_set): @store(identifier_ignore=['stimulus_set']) def _place_on_screen(stimuli_identifier: str, stimulus_set: StimulusSet, - target_visual_degrees: int, source_visual_degrees: int = None): - converted_stimuli_id = f"{stimuli_identifier}--target{target_visual_degrees}--source{source_visual_degrees}" + target_visual_degrees: Union[int, float], source_visual_degrees: Union[int, float, None] = None): + source_degrees_formatted = f"{source_visual_degrees}" if source_visual_degrees is None \ + else f"{source_visual_degrees:.2f}" # make sure we do not try to print a None with 2 decimal places + converted_stimuli_id = f"{stimuli_identifier}--target{target_visual_degrees:.2f}--source{source_degrees_formatted}" source_visual_degrees = _determine_visual_degrees(source_visual_degrees, stimulus_set) target_dir = root_path / converted_stimuli_id if os.path.exists(target_dir): shutil.rmtree(target_dir) - target_dir.mkdir(parents=True, exist_ok=False) + target_dir.mkdir(parents=True, exist_ok=False) image_converter = ImageConverter(target_dir=target_dir) converted_image_paths = {} @@ -72,7 +77,7 @@ class ImageConverter: def __init__(self, target_dir): self._target_dir = Path(target_dir) - def convert_image(self, image_path, source_degrees, target_degrees): + def convert_image(self, image_path, source_degrees: Union[int, float], target_degrees: Union[int, float]): if source_degrees == target_degrees: return image_path ratio = target_degrees / source_degrees @@ -118,4 +123,4 @@ def _center_on_background(self, center_image, background_size, background_color= return image def _write(self, image, target_path): - image.save(target_path) + image.save(target_path) \ No newline at end of file diff --git a/brainscore_vision/model_helpers/brain_transformation/behavior.py b/brainscore_vision/model_helpers/brain_transformation/behavior.py index 35a55a1e6..17f02453d 100644 --- a/brainscore_vision/model_helpers/brain_transformation/behavior.py +++ b/brainscore_vision/model_helpers/brain_transformation/behavior.py @@ -139,8 +139,25 @@ class LabelToImagenetIndices: motorbike_indices = [670, 665] bus_indices = [779, 874, 654] + # added from the Scialom2024 benchmark: + banana_indices = [954] + beanie_indices = [439, 452, 515, 808] + binoculars_indices = [447] + boot_indices = [514] + bowl_indices = [659, 809] + cup_indices = [968] + glasses_indices = [837] + lamp_indices = [470, 607, 818, 846] + pan_indices = [567] + sewingmachine_indices = [786] + shovel_indices = [792] + # truck indices used as defined by Geirhos et al., 2021. + @classmethod def label_to_indices(cls, label): + # for handling multi-word labels given by models or benchmarks + label = label.lower().replace(" ", "") + synset_indices = getattr(cls, f"{label}_indices") return synset_indices diff --git a/brainscore_vision/models/resnet18_imagenet21kP/__init__.py b/brainscore_vision/models/resnet18_imagenet21kP/__init__.py index be444ccee..4b3207d97 100644 --- a/brainscore_vision/models/resnet18_imagenet21kP/__init__.py +++ b/brainscore_vision/models/resnet18_imagenet21kP/__init__.py @@ -1,8 +1,6 @@ from brainscore_vision import model_registry from .model import get_model -MODEL_NAME = "resnet18_imagenet21kP" - -model_registry["resnet18_imagenet21kP-abdulkadir.gokce@epfl.ch"] = lambda: get_model( - MODEL_NAME +model_registry["resnet18_imagenet21kP"] = lambda: get_model( + "resnet18_imagenet21kP" )