diff --git a/brainscore_vision/metrics/accuracy_distance/metric.py b/brainscore_vision/metrics/accuracy_distance/metric.py index fb31a7280..eb47e3bba 100644 --- a/brainscore_vision/metrics/accuracy_distance/metric.py +++ b/brainscore_vision/metrics/accuracy_distance/metric.py @@ -10,17 +10,52 @@ class AccuracyDistance(Metric): """ - Computes the accuracy distance using the relative distance between the source and target accuracies, adjusted - for the maximum possible difference between the two accuracies. + Computes the accuracy distance using the relative distance between the + source and target accuracies, adjusted for the maximum possible + difference between the two accuracies. By default, the distance is computed + from a single accuracy score on the entire BehavioralAssembly. However, + the distance can also be computed on a condition-wise basis using the + 'variables' argument. The advantage of the condition-wise approach is that + it can separate two models with identical overall accuracy if one exhibits a + more target-like pattern of performance across conditions. """ - def __call__(self, source: BehavioralAssembly, target: BehavioralAssembly) -> Score: + def __call__(self, source: BehavioralAssembly, target: + BehavioralAssembly, variables: tuple=()) -> Score: """Target should be the entire BehavioralAssembly, containing truth values.""" subjects = self.extract_subjects(target) subject_scores = [] for subject in subjects: subject_assembly = target.sel(subject=subject) - subject_score = self.compare_single_subject(source, subject_assembly) + + # compute single score across the entire dataset + if len(variables) == 0: + subject_score = self.compare_single_subject(source, subject_assembly) + + # compute scores for each condition, then average + else: + cond_scores = [] + + # get iterator across all combinations of variables + if len(variables) == 1: + conditions = set(subject_assembly[variables[0]].values) + conditions = [[c] for c in conditions] # to mimic itertools.product + else: + conditions = itertools.product( + *[set(subject_assembly[v].values) for v in variables]) + + # loop over conditions and compute scores + for cond in conditions: + indexers = {v: cond[i] for i, v in enumerate(variables)} + subject_cond_assembly = subject_assembly.sel(**indexers) + source_cond_assembly = source.sel(**indexers) + # to accomodate unbalanced designs, skip combinations of + # variables that don't exist in both assemblies + if len(subject_cond_assembly) and len(source_cond_assembly): + cond_scores.append(self.compare_single_subject( + source_cond_assembly, subject_cond_assembly)) + subject_score = Score(np.mean(cond_scores)) + subject_score = subject_score.expand_dims('subject') subject_score['subject'] = 'subject', [subject] subject_scores.append(subject_score) diff --git a/brainscore_vision/metrics/accuracy_distance/test.py b/brainscore_vision/metrics/accuracy_distance/test.py index 2fc15b792..d6414b790 100644 --- a/brainscore_vision/metrics/accuracy_distance/test.py +++ b/brainscore_vision/metrics/accuracy_distance/test.py @@ -12,6 +12,20 @@ def test_score(): assert score == approx(0.74074074) +def test_score_single_variable(): + assembly = _make_data() + metric = load_metric('accuracy_distance') + score = metric(assembly.sel(subject='C'), assembly, ('condition',)) + assert score == approx(0.55555556) + + +def test_score_multi_variable(): + assembly = _make_data() + metric = load_metric('accuracy_distance') + score = metric(assembly.sel(subject='C'), assembly, ('condition','animacy')) + assert score == approx(0.55555556) + + def test_has_error(): assembly = _make_data() metric = load_metric('accuracy_distance') @@ -38,5 +52,6 @@ def _make_data(): coords={'stimulus_id': ('presentation', np.resize(np.arange(9), 9 * 3)), 'truth': ('presentation', np.resize(['dog', 'cat', 'chair'], 9 * 3)), 'condition': ('presentation', np.resize([1, 1, 1, 2, 2, 2, 3, 3, 3], 9 * 3)), + 'animacy': ('presentation', np.resize(['animate', 'animate', 'inanimate'], 9 * 3)), 'subject': ('presentation', ['A'] * 9 + ['B'] * 9 + ['C'] * 9)}, dims=['presentation'])