Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated accuracy distance metric to accommodate condition-wise approaches #1217

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions brainscore_vision/metrics/accuracy_distance/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,52 @@

class AccuracyDistance(Metric):
"""
Computes the accuracy distance using the relative distance between the source and target accuracies, adjusted
for the maximum possible difference between the two accuracies.
Computes the accuracy distance using the relative distance between the
source and target accuracies, adjusted for the maximum possible
difference between the two accuracies. By default, the distance is computed
from a single accuracy score on the entire BehavioralAssembly. However,
the distance can also be computed on a condition-wise basis using the
'variables' argument. The advantage of the condition-wise approach is that
it can separate two models with identical overall accuracy if one exhibits a
more target-like pattern of performance across conditions.
"""
def __call__(self, source: BehavioralAssembly, target: BehavioralAssembly) -> Score:
def __call__(self, source: BehavioralAssembly, target:
BehavioralAssembly, variables: tuple=()) -> Score:
"""Target should be the entire BehavioralAssembly, containing truth values."""

subjects = self.extract_subjects(target)
subject_scores = []
for subject in subjects:
subject_assembly = target.sel(subject=subject)
subject_score = self.compare_single_subject(source, subject_assembly)

# compute single score across the entire dataset
if len(variables) == 0:
subject_score = self.compare_single_subject(source, subject_assembly)

# compute scores for each condition, then average
else:
cond_scores = []

# get iterator across all combinations of variables
if len(variables) == 1:
conditions = set(subject_assembly[variables[0]].values)
conditions = [[c] for c in conditions] # to mimic itertools.product
else:
conditions = itertools.product(
*[set(subject_assembly[v].values) for v in variables])

# loop over conditions and compute scores
for cond in conditions:
indexers = {v: cond[i] for i, v in enumerate(variables)}
subject_cond_assembly = subject_assembly.sel(**indexers)
source_cond_assembly = source.sel(**indexers)
# to accommodate unbalanced designs, skip combinations of
# variables that don't exist in both assemblies
if len(subject_cond_assembly) and len(source_cond_assembly):
cond_scores.append(self.compare_single_subject(
source_cond_assembly, subject_cond_assembly))
subject_score = Score(np.mean(cond_scores))

subject_score = subject_score.expand_dims('subject')
subject_score['subject'] = 'subject', [subject]
subject_scores.append(subject_score)
Expand Down
15 changes: 15 additions & 0 deletions brainscore_vision/metrics/accuracy_distance/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ def test_score():
assert score == approx(0.74074074)


def test_score_single_variable():
    """Score subject C against the pooled assembly, grouping by one variable.

    Passing ``('condition',)`` makes the metric average per-condition
    accuracy distances instead of computing one global score.
    """
    data = _make_data()
    accuracy_distance = load_metric('accuracy_distance')
    result = accuracy_distance(data.sel(subject='C'), data, ('condition',))
    assert result == approx(0.55555556)


def test_score_multi_variable():
    """Score subject C condition-wise over the cross-product of two variables.

    With ``('condition', 'animacy')`` the metric iterates all combinations of
    both coordinates; combinations absent from either assembly are skipped.
    """
    data = _make_data()
    accuracy_distance = load_metric('accuracy_distance')
    result = accuracy_distance(data.sel(subject='C'), data, ('condition','animacy'))
    assert result == approx(0.55555556)


def test_has_error():
assembly = _make_data()
metric = load_metric('accuracy_distance')
Expand All @@ -38,5 +52,6 @@ def _make_data():
coords={'stimulus_id': ('presentation', np.resize(np.arange(9), 9 * 3)),
'truth': ('presentation', np.resize(['dog', 'cat', 'chair'], 9 * 3)),
'condition': ('presentation', np.resize([1, 1, 1, 2, 2, 2, 3, 3, 3], 9 * 3)),
'animacy': ('presentation', np.resize(['animate', 'animate', 'inanimate'], 9 * 3)),
'subject': ('presentation', ['A'] * 9 + ['B'] * 9 + ['C'] * 9)},
dims=['presentation'])
Loading