THINGS behavioral benchmark and odd-one-out model_helper (#434)
* integrate @linus-md's model-tools PR#70 brain-score/model-tools#70
* Use new load functions
* Update
* Include working notebook
* Add TODOs
* More TODOs
* Clean up imports
* refactor to pass list
* Implement ``calculate_similarity_matrix()``
* Implement preliminary ``calculate_choices()``
* Move benchmark draft to this PR
* Update
* Update benchmark
* Update benchmark.py to 2.0 standards
* Update hebart2023/test.py to 2.0 standards
* Update hebart2023/__init__.py to 2.0 standards
* Update benchmark.py
* Package triplets in assembly
* Add triplet test
* Make data compatible with interface
* Update sample.ipynb
* Update sample.ipynb
* Package choices
* Update benchmark.py
* Update test.py
* Update test.py
* Update behavior.py
* Update test_behavior.py
* Update draft
* Update
* Update behavior.py
* Fix similarity_matrix indexing
* Scores for 3333 triplets
* Sort stimuli
* Update sorting
* Fix stimulus_paths
* Running benchmark
* Delete draft.ipynb
* Update benchmark.py
* Update test.py
* Update test_behavior.py
* Update behavior.py
* Add sample triplet
* Add vectorized numpy choice function to use with full stimulus set
* Revert "add vectorized numpy choice function to use with full stimulus set" (reverts commit 5131544)
* Add tests
* Update benchmark and tests
* Finalize behavior tests
* Fix benchmark test
* Remove slow test
* Fix typo in tutorial
* Speed up alexnet test
* Trigger CI
* Update brainscore_vision/benchmarks/hebart2023/test.py (co-authored-by: Martin Schrimpf <[email protected]>)
* Update brainscore_vision/benchmarks/hebart2023/test.py (co-authored-by: Martin Schrimpf <[email protected]>)
* Update brainscore_vision/benchmarks/hebart2023/benchmark.py (co-authored-by: Martin Schrimpf <[email protected]>)
* Update __init__.py
* Update benchmark.py
* Update test.py
* Explain noise ceiling
* Fix ``test_benchmark_registry()``
* Update ceiling
* Update test.py
* Fix typo
* All tests passing again
* Update benchmark.py
* Update test.py
* Add missing import
* Load assembly and stimulus set inside test methods

Co-authored-by: Linus Sommer <[email protected]>
Co-authored-by: linus-md <[email protected]>
1 parent 5241295, commit 47ff835
Showing 9 changed files with 281 additions and 21 deletions.
brainscore_vision/benchmarks/hebart2023/__init__.py
@@ -0,0 +1,5 @@
from brainscore_vision import benchmark_registry
from .benchmark import Hebart2023Match

benchmark_registry['Hebart2023-match'] = Hebart2023Match
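Once this registration runs, the benchmark is available through the plugin registry. A minimal usage sketch, assuming the Hebart2023 data are accessible locally (the tests below mark them as private_access):

from brainscore_vision import load_benchmark

benchmark = load_benchmark('Hebart2023-match')  # resolved via the benchmark_registry entry above
print(benchmark.ceiling)                        # noise ceiling defined in benchmark.py (0.6767)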
brainscore_vision/benchmarks/hebart2023/benchmark.py
@@ -0,0 +1,76 @@
import numpy as np
import pandas as pd
from brainio.stimuli import StimulusSet
from brainscore_vision import load_dataset, load_stimulus_set
from brainscore_vision.benchmarks import BenchmarkBase
from brainscore_vision.benchmark_helpers.screen import place_on_screen
from brainscore_vision.model_interface import BrainModel
from brainscore_vision.metrics import Score

BIBTEX = """@article{10.7554/eLife.82580,
    author = {Hebart, Martin N and Contier, Oliver and Teichmann, Lina and Rockter, Adam H and Zheng, Charles Y and Kidder, Alexis and Corriveau, Anna and Vaziri-Pashkam, Maryam and Baker, Chris I},
    journal = {eLife},
    month = {feb},
    pages = {e82580},
    title = {THINGS-data, a multimodal collection of large-scale datasets for investigating object representations in human brain and behavior},
    volume = 12,
    year = 2023
}"""


class Hebart2023Match(BenchmarkBase):
    def __init__(self, similarity_measure='dot'):
        self._visual_degrees = 8
        self._number_of_trials = 1
        self._assembly = load_dataset('Hebart2023')
        self._stimulus_set = load_stimulus_set('Hebart2023')

        # The noise ceiling was computed by averaging the percentage of participants
        # who made the same choice for a given triplet. See the paper for more detail.
        super().__init__(
            identifier=f'Hebart2023Match_{similarity_measure}', version=1,
            ceiling_func=lambda: Score(0.6767),
            parent='Hebart2023',
            bibtex=BIBTEX
        )

    def set_number_of_triplets(self, n):
        self._assembly = self._assembly[:n]

    def __call__(self, candidate: BrainModel):
        # Create the new StimulusSet
        self.triplets = np.array([
            self._assembly.coords["image_1"].values,
            self._assembly.coords["image_2"].values,
            self._assembly.coords["image_3"].values
        ]).T.reshape(-1, 1)

        stimuli_data = [self._stimulus_set.loc[stim] for stim in self.triplets]
        stimuli = pd.concat(stimuli_data)
        stimuli.columns = self._stimulus_set.columns

        stimuli = StimulusSet(stimuli)
        stimuli.identifier = 'Hebart2023'
        stimuli.stimulus_paths = self._stimulus_set.stimulus_paths
        stimuli['stimulus_id'] = stimuli['stimulus_id'].astype(int)

        # Prepare the stimuli
        candidate.start_task(BrainModel.Task.odd_one_out)
        stimuli = place_on_screen(
            stimulus_set=stimuli,
            target_visual_degrees=candidate.visual_degrees(),
            source_visual_degrees=self._visual_degrees
        )

        # Run the model
        choices = candidate.look_at(stimuli, self._number_of_trials)

        # Score the model
        # We chose not to compute error estimates, but you could compute them
        # by splitting the data into five folds and computing the standard deviation.
        correct_choices = choices.values == self._assembly.coords["image_3"].values
        raw_score = np.sum(correct_choices) / len(choices)
        score = (raw_score - 1/3) / (self.ceiling - 1/3)
        score = max(0, score)
        score.attrs['raw'] = raw_score
        score.attrs['ceiling'] = self.ceiling
        return score
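The scoring comment above leaves the error estimate unimplemented. A minimal sketch of the five-fold approach it describes, with an illustrative helper name that is not part of the committed code:

import numpy as np

def fold_error_estimate(correct_choices: np.ndarray, num_folds: int = 5) -> float:
    # Split the per-triplet correctness flags into folds and report the spread
    # of per-fold accuracies as a rough error bar.
    folds = np.array_split(correct_choices.astype(float), num_folds)
    return float(np.std([fold.mean() for fold in folds]))

Inside __call__ one could then attach it, for example as score.attrs['error'] = fold_error_estimate(correct_choices).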
brainscore_vision/benchmarks/hebart2023/test.py
@@ -0,0 +1,17 @@
import pytest

from brainscore_vision import load_benchmark, load_model


@pytest.mark.private_access
def test_ceiling():
    benchmark = load_benchmark('Hebart2023-match')
    ceiling = benchmark.ceiling
    assert ceiling == pytest.approx(0.6767, abs=0.0001)


@pytest.mark.private_access
def test_alexnet_consistency():
    benchmark = load_benchmark('Hebart2023-match')
    benchmark.set_number_of_triplets(n=1000)
    model = load_model('alexnet')
    score = benchmark(model)
    assert score == pytest.approx(0.38, abs=0.02)
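For intuition on what the asserted value corresponds to, the ceiling normalization from benchmark.py can be inverted; the numbers below simply apply that formula to the expected AlexNet score:

ceiling = 0.6767
chance = 1 / 3
normalized_score = 0.38  # value asserted in test_alexnet_consistency
raw_accuracy = chance + normalized_score * (ceiling - chance)
print(round(raw_accuracy, 3))  # ~0.464, i.e. roughly 46% of triplets answered consistently with the human choices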
@@ -1,38 +1,42 @@
 import numpy as np
 import pytest
 
-import brainscore
 from brainio.stimuli import StimulusSet
+from brainscore_vision import load_stimulus_set, load_dataset
+from brainscore_vision.model_interface import BehavioralAssembly
 
 
 @pytest.mark.memory_intense
 @pytest.mark.private_access
 class TestHebart2023:
-    assembly = brainscore.get_assembly('Hebart2023')
-    stimulus_set = brainscore.get_stimulus_set("Hebart2023")
-
     def test_assembly(self):
-        stimulus_id = self.assembly.coords["stimulus_id"]
-        triplet_id = self.assembly.coords["triplet_id"]
+        assembly = load_dataset('Hebart2023')
+
+        stimulus_id = assembly.coords["stimulus_id"]
+        triplet_id = assembly.coords["triplet_id"]
         assert len(stimulus_id) == len(triplet_id) == 453642
         assert len(np.unique(stimulus_id)) == 1854
 
-        image_1 = self.assembly.coords["image_1"]
-        image_2 = self.assembly.coords["image_2"]
-        image_3 = self.assembly.coords["image_3"]
+        image_1 = assembly.coords["image_1"]
+        image_2 = assembly.coords["image_2"]
+        image_3 = assembly.coords["image_3"]
         assert len(image_1) == len(image_2) == len(image_3) == 453642
 
     def test_assembly_stimulusset_ids_match(self):
-        stimulusset_ids = self.stimulus_set['stimulus_id']
+        stimulus_set = load_stimulus_set("Hebart2023")
+        assembly = load_dataset('Hebart2023')
+
+        stimulusset_ids = stimulus_set['stimulus_id']
         for assembly_stimulusid in ['image_1', 'image_2', 'image_3']:
-            assembly_values = self.assembly[assembly_stimulusid].values
+            assembly_values = assembly[assembly_stimulusid].values
             assert set(assembly_values) == set(stimulusset_ids), \
                 f"Assembly stimulus id reference '{assembly_stimulusid}' does not match stimulus_set"
 
     def test_stimulus_set(self):
-        assert len(self.stimulus_set) == 1854
+        stimulus_set = load_stimulus_set("Hebart2023")
+        assert len(stimulus_set) == 1854
         assert {'unique_id', 'stimulus_id', 'filename',
                 'WordNet_ID', 'Wordnet_ID2', 'Wordnet_ID3', 'Wordnet_ID4', 'WordNet_synonyms',
                 'freq_1', 'freq_2', 'top_down_1', 'top_down_2', 'bottom_up', 'word_freq', 'word_freq_online',
-                'example_image', 'dispersion', 'dominant_part', 'rank'} == set(self.stimulus_set.columns)
-        assert isinstance(self.stimulus_set, StimulusSet)
+                'example_image', 'dispersion', 'dominant_part', 'rank'} == set(stimulus_set.columns)
+        assert isinstance(stimulus_set, StimulusSet)
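The odd-one-out model helper referenced in the commit title (``calculate_similarity_matrix()`` / ``calculate_choices()``) is not rendered in this view. As a rough sketch of the decision rule such a helper could implement, assuming dot-product similarity as the benchmark's default ``similarity_measure='dot'`` suggests (function and variable names are illustrative, not the committed API):

import numpy as np

def odd_one_out_choice(triplet_features: np.ndarray) -> int:
    # triplet_features: shape (3, n_features), one row of model activations per image.
    # The most similar pair is kept together; the remaining image is the odd one out.
    similarity = triplet_features @ triplet_features.T  # dot-product similarity, shape (3, 3)
    pairs = [(0, 1), (0, 2), (1, 2)]
    most_similar_pair = max(pairs, key=lambda pair: similarity[pair])
    return ({0, 1, 2} - set(most_similar_pair)).pop()

Example: odd_one_out_choice(np.random.rand(3, 128)) returns 0, 1, or 2, indexing the chosen image within the triplet.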