diff --git a/brainscore_vision/benchmarks/majajhong2015/__init__.py b/brainscore_vision/benchmarks/majajhong2015/__init__.py index 24fe8651e..5ae8988fd 100644 --- a/brainscore_vision/benchmarks/majajhong2015/__init__.py +++ b/brainscore_vision/benchmarks/majajhong2015/__init__.py @@ -11,3 +11,8 @@ benchmark_registry['MajajHong2015public.V4-pls'] = MajajHongV4PublicBenchmark benchmark_registry['MajajHong2015public.IT-pls'] = MajajHongITPublicBenchmark + +# temporal +from .benchmark import MajajHongV4TemporalPublicBenchmark, MajajHongITTemporalPublicBenchmark +benchmark_registry['MajajHong2015public.V4-temporal-pls'] = lambda: MajajHongV4TemporalPublicBenchmark(time_interval=10) +benchmark_registry['MajajHong2015public.IT-temporal-pls'] = lambda: MajajHongITTemporalPublicBenchmark(time_interval=10) diff --git a/brainscore_vision/benchmarks/majajhong2015/benchmark.py b/brainscore_vision/benchmarks/majajhong2015/benchmark.py index 766f5c93f..5270ab7af 100644 --- a/brainscore_vision/benchmarks/majajhong2015/benchmark.py +++ b/brainscore_vision/benchmarks/majajhong2015/benchmark.py @@ -1,7 +1,8 @@ from brainscore_core import Metric from brainscore_vision import load_metric, Ceiling, load_ceiling, load_dataset -from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition +from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition, apply_keep_attrs +from brainscore_vision.model_helpers.brain_transformation.temporal import assembly_time_align VISUAL_DEGREES = 8 NUMBER_OF_TRIALS = 50 @@ -20,13 +21,14 @@ eprint = {https://www.jneurosci.org/content/35/39/13402.full.pdf}, journal = {Journal of Neuroscience}}""" -pls_metric = lambda: load_metric('pls', crossvalidation_kwargs=dict(stratification_coord='object_name')) - +crossvalidation_kwargs = dict(stratification_coord='object_name') +pls_metric = lambda: load_metric('pls', crossvalidation_kwargs=crossvalidation_kwargs) +spantime_pls_metric = lambda: load_metric('spantime_pls', crossvalidation_kwargs=crossvalidation_kwargs) def _DicarloMajajHong2015Region(region: str, access: str, identifier_metric_suffix: str, - similarity_metric: Metric, ceiler: Ceiling): - assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access) - assembly = load_assembly(average_repetitions=True, region=region, access=access) + similarity_metric: Metric, ceiler: Ceiling, time_interval: float = None): + assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access, time_interval=time_interval) + assembly = load_assembly(average_repetitions=True, region=region, access=access, time_interval=time_interval) benchmark_identifier = f'MajajHong2015.{region}' + ('.public' if access == 'public' else '') return NeuralBenchmark(identifier=f'{benchmark_identifier}-{identifier_metric_suffix}', version=3, assembly=assembly, similarity_metric=similarity_metric, @@ -60,13 +62,35 @@ def MajajHongITPublicBenchmark(): ceiler=load_ceiling('internal_consistency')) -def load_assembly(average_repetitions, region, access='private'): - assembly = load_dataset(f'MajajHong2015.{access}') +def MajajHongV4TemporalPublicBenchmark(time_interval: float = None): + return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls', + similarity_metric=spantime_pls_metric(), time_interval=time_interval, + ceiler=load_ceiling('internal_consistency_temporal')) + + +def MajajHongITTemporalPublicBenchmark(time_interval: float = None): + return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls', + similarity_metric=spantime_pls_metric(), time_interval=time_interval, + ceiler=load_ceiling('internal_consistency_temporal')) + + +def load_assembly(average_repetitions: bool, region: str, access: str = 'private', time_interval: float = None): + temporal = time_interval is not None + if not temporal: + assembly = load_dataset(f'MajajHong2015.{access}') + assembly = assembly.squeeze("time_bin") + else: + assembly = load_dataset(f'MajajHong2015.temporal.{access}') + assembly = assembly.__class__(assembly) + target_time_bins = [ + (t, t+time_interval) for t in range(0, assembly.time_bin_end.max().item()-time_interval, time_interval) + ] + assembly = apply_keep_attrs(assembly, lambda assembly: assembly_time_align(assembly, target_time_bins)) + assembly = assembly.sel(region=region) assembly['region'] = 'neuroid', [region] * len(assembly['neuroid']) - assembly = assembly.squeeze("time_bin") assembly.load() - assembly = assembly.transpose('presentation', 'neuroid') + assembly = assembly.transpose('presentation', 'neuroid', ...) if average_repetitions: assembly = average_repetition(assembly) return assembly diff --git a/brainscore_vision/metric_helpers/temporal.py b/brainscore_vision/metric_helpers/temporal.py new file mode 100644 index 000000000..0c110b9f2 --- /dev/null +++ b/brainscore_vision/metric_helpers/temporal.py @@ -0,0 +1,119 @@ +import xarray as xr +import numpy as np + +from brainscore_vision.benchmark_helpers.neural_common import Score +from brainscore_vision.metric_helpers.transformations import standard_error_of_the_mean + +from .xarray_utils import apply_over_dims, recursive_op + + +# take the mean of scores (medians of single neuron scores) over time + + +def average_over_presentation(score: Score) -> Score: + raw = score + score = raw.mean('presentation') + score.attrs['raw'] = raw + return score + + +# PerOps is applied to every slice/chunk of the xarray along the specified dimensions +class PerOps: + def __init__(self, callable, dims, check_coords=[]): + # for coordinate checking, they are supposed to be the same across assemblies + self.dims = dims + self.callable = callable + self.check_coords = check_coords + + def __call__(self, *asms): + for check_coord in self.check_coords: + asms = [asm.sortby(check_coord) for asm in asms] + for asm in asms[1:]: + assert (asm[check_coord].values == asms[0][check_coord].values).all() + ret = apply_over_dims(self.callable, *asms, dims=self.dims) + return ret + + +# SpanOps aggregates specified dimensions to one dimension +class SpanOps: + def __init__(self, callable, source_dims, aggregated_dim, resample=False): + # if resample, randomly choose samples from the aggregated dimension, + # whose size is the same as the assembly.sizes[aggregated_dim] + self.source_dims = source_dims + self.aggregated_dim = aggregated_dim + self.callable = callable + self.resample = resample + + def __call__(self, *asms): + asms = [self._stack(asm) for asm in asms] + return self.callable(*asms) + + def _stack(self, assembly): + assembly_type = type(assembly) + size = assembly.sizes[self.aggregated_dim] + assembly = xr.DataArray(assembly) # xarray cannot deal with stacking MultiIndex (pydata/xarray#1554) + assembly = assembly.reset_index(self.source_dims) + assembly = assembly.rename({dim:dim+"_" for dim in self.source_dims}) # we'll call stacked timebins "presentation" + assembly = assembly.stack({self.aggregated_dim : [dim+"_" for dim in self.source_dims]}) + if self.resample: + indices = np.random.randint(0, assembly.sizes[self.aggregated_dim], size) + assembly = assembly.isel({self.aggregated_dim: indices}) + return assembly_type(assembly) + +class PerTime(PerOps): + def __init__(self, callable, time_dim="time_bin", check_coord="time_bin_start", **kwargs): + self.time_bin = time_dim + super().__init__(callable, dims=[time_dim], check_coords=[check_coord], **kwargs) + +class PerPresentation(PerOps): + def __init__(self, callable, presentation_dim="presentation", check_coord="stimulus_id", **kwargs): + self.presentation_dim = presentation_dim + super().__init__(callable, dims=[presentation_dim], check_coords=[check_coord], **kwargs) + +class PerNeuroid(PerOps): + def __init__(self, callable, neuroid_dim="neuroid", check_coord="neuroid_id", **kwargs): + self.neuroid_dim = neuroid_dim + super().__init__(callable, dims=[neuroid_dim], check_coords=[check_coord], **kwargs) + +class SpanTime(SpanOps): + def __init__(self, callable, time_dim="time_bin", presentation_dim="presentation", resample=False): + self.time_dim = time_dim + self.presentation_dim = presentation_dim + source_dims = [self.time_dim, self.presentation_dim] + aggregated_dim = self.presentation_dim + super().__init__(callable, source_dims, aggregated_dim, resample=resample) + +class SpanTimeRegression: + """ + Fits a regression with weights shared across the time bins. + """ + + def __init__(self, regression): + self._regression = regression + + def fit(self, source, target): + assert (source['time_bin'].values == target['time_bin'].values).all() + SpanTime(self._regression.fit)(source, target) + + def predict(self, source): + return PerTime(self._regression.predict)(source) + +class PerTimeRegression: + """ + Fits a regression with different weights for each time bins. + """ + + def __init__(self, regression): + self._regression = regression + + def fit(self, source, target): + # Lazy fit until predict + assert (source['time_bin'].values == target['time_bin'].values).all() + self._train_source = source + self._train_target = target + + def predict(self, source): + def fit_predict(train_source, train_target, test_source): + self._regression.fit(train_source, train_target) + return self._regression.predict(test_source) + return PerTime(fit_predict)(self._train_source, self._train_target, source) \ No newline at end of file diff --git a/brainscore_vision/metric_helpers/xarray_utils.py b/brainscore_vision/metric_helpers/xarray_utils.py index ce67654ff..8998b6003 100644 --- a/brainscore_vision/metric_helpers/xarray_utils.py +++ b/brainscore_vision/metric_helpers/xarray_utils.py @@ -1,4 +1,5 @@ import numpy as np +import xarray as xr from brainio.assemblies import NeuroidAssembly, array_is_element, walk_coords from brainscore_vision.metric_helpers import Defaults @@ -90,3 +91,61 @@ def __call__(self, prediction, target): for coord, dims, values in walk_coords(target) if dims == neuroid_dims}, dims=neuroid_dims) return result + + +# ops that also applies to attrs (and attrs of attrs), which are xarrays +def recursive_op(*arrs, op=lambda x:x): + # the attrs structure of each arr must be the same + val = op(*arrs) + attrs = arrs[0].attrs + for attr in attrs: + attr_val = arrs[0].attrs[attr] + if isinstance(attr_val, xr.DataArray): + attr_arrs = [arr.attrs[attr] for arr in arrs] + attr_val = recursive_op(*attr_arrs, op=op) + val.attrs[attr] = attr_val + return val + + +# apply a callable to every slice of the xarray along the specified dimensions +def apply_over_dims(callable, *asms, dims, njobs=-1): + asms = [asm.transpose(*dims, ...) for asm in asms] + sizes = [asms[0].sizes[dim] for dim in dims] + + def apply_helper(sizes, dims, *asms): + xarr = [] + attrs = {} + size = sizes[0] + rsizes = sizes[1:] + dim = dims[0] + rdims = dims[1:] + + if len(sizes) == 1: + # parallel execution on the last applied dimension + from joblib import Parallel, delayed + results = Parallel(n_jobs=njobs)(delayed(callable)(*[asm.isel({dim:s}) for asm in asms]) for s in range(size)) + else: + results = [] + for s in range(size): + arr = apply_helper(rsizes, rdims, *[asm.isel({dim:s}) for asm in asms]) + results.append(arr) + + for arr in results: + if arr is not None: + for k,v in arr.attrs.items(): + assert isinstance(v, xr.DataArray) + attrs.setdefault(k, []).append(v.expand_dims(dim)) + xarr.append(arr) + + if not xarr: + return + else: + xarr = xr.concat(xarr, dim=dim) + attrs = {k: xr.concat(vs, dim=dim) for k,vs in attrs.items()} + xarr.coords[dim] = asms[0].coords[dim] + for k,v in attrs.items(): + attrs[k].coords[dim] = asms[0].coords[dim] + xarr.attrs[k] = attrs[k] + return xarr + + return apply_helper(sizes, dims, *asms) \ No newline at end of file diff --git a/brainscore_vision/metrics/internal_consistency/__init__.py b/brainscore_vision/metrics/internal_consistency/__init__.py index bd71776be..ae6a41ea6 100644 --- a/brainscore_vision/metrics/internal_consistency/__init__.py +++ b/brainscore_vision/metrics/internal_consistency/__init__.py @@ -1,4 +1,8 @@ from brainscore_vision import metric_registry from .ceiling import InternalConsistency +from brainscore_vision.metric_helpers.temporal import PerTime + + metric_registry['internal_consistency'] = InternalConsistency +metric_registry['internal_consistency_temporal'] = lambda *args, **kwargs: PerTime(InternalConsistency(*args, **kwargs)) \ No newline at end of file diff --git a/brainscore_vision/metrics/regression_correlation/__init__.py b/brainscore_vision/metrics/regression_correlation/__init__.py index 2f8019b3f..691e82685 100644 --- a/brainscore_vision/metrics/regression_correlation/__init__.py +++ b/brainscore_vision/metrics/regression_correlation/__init__.py @@ -11,6 +11,15 @@ metric_registry['linear_predictivity'] = lambda *args, **kwargs: CrossRegressedCorrelation( regression=linear_regression(), correlation=pearsonr_correlation(), *args, **kwargs) +# temporal metrics +from .metric import SpanTimeCrossRegressedCorrelation + +metric_registry['spantime_pls'] = lambda *args, **kwargs: SpanTimeCrossRegressedCorrelation( + regression=pls_regression(), correlation=pearsonr_correlation(), *args, **kwargs) +metric_registry['spantime_ridge'] = lambda *args, **kwargs: SpanTimeCrossRegressedCorrelation( + regression=ridge_regression(), correlation=pearsonr_correlation(), *args, **kwargs) + + BIBTEX = """@article{schrimpf2018brain, title={Brain-score: Which artificial neural network for object recognition is most brain-like?}, author={Schrimpf, Martin and Kubilius, Jonas and Hong, Ha and Majaj, Najib J and Rajalingham, Rishi and Issa, Elias B and Kar, Kohitij and Bashivan, Pouya and Prescott-Roy, Jonathan and Geiger, Franziska and others}, diff --git a/brainscore_vision/metrics/regression_correlation/metric.py b/brainscore_vision/metrics/regression_correlation/metric.py index 365f63868..a09ba03e0 100644 --- a/brainscore_vision/metrics/regression_correlation/metric.py +++ b/brainscore_vision/metrics/regression_correlation/metric.py @@ -8,6 +8,7 @@ from brainscore_core.metrics import Metric, Score from brainscore_vision.metric_helpers.transformations import CrossValidation from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation +from brainscore_vision.metric_helpers.temporal import SpanTimeRegression, PerTime class CrossRegressedCorrelation(Metric): @@ -65,6 +66,15 @@ def predict(self, X): return Ypred +# make the crc to consider time as a sample dimension +def SpanTimeCrossRegressedCorrelation(regression, correlation, *args, **kwargs): + return CrossRegressedCorrelation( + regression=SpanTimeRegression(regression), + correlation=PerTime(correlation), + *args, **kwargs + ) + + def pls_regression(regression_kwargs=None, xarray_kwargs=None): regression_defaults = dict(n_components=25, scale=False) regression_kwargs = {**regression_defaults, **(regression_kwargs or {})} diff --git a/tests/test_metric_helpers/test_temporal.py b/tests/test_metric_helpers/test_temporal.py new file mode 100644 index 000000000..64dffe8de --- /dev/null +++ b/tests/test_metric_helpers/test_temporal.py @@ -0,0 +1,80 @@ +import numpy as np +import scipy.stats +from pytest import approx +from sklearn.linear_model import LinearRegression + +from brainio.assemblies import NeuroidAssembly +from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation +from brainscore_vision.metric_helpers.temporal import PerTime, SpanTime, PerTimeRegression, SpanTimeRegression + + +class TestMetricHelpers: + def test_pertime_ops(self): + jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20), + coords={'stimulus_id': ('presentation', list(reversed(range(500)))), + 'image_meta': ('presentation', [0] * 500), + 'neuroid_id': ('neuroid', list(reversed(range(10)))), + 'neuroid_meta': ('neuroid', [0] * 10), + 'time_bin_start': ('time_bin', np.arange(0, 400, 20)), + 'time_bin_end': ('time_bin', np.arange(20, 420, 20))}, + dims=['presentation', 'neuroid', 'time_bin']) + mean_neuroid = lambda arr: arr.mean('neuroid') + pertime_mean_neuroid = PerTime(mean_neuroid) + output = pertime_mean_neuroid(jumbled_source) + output = output.transpose('presentation', 'time_bin') + target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('neuroid') + assert (output == approx(target)).all().item() + + def test_spantime_ops(self): + jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20), + coords={'stimulus_id': ('presentation', list(reversed(range(500)))), + 'image_meta': ('presentation', [0] * 500), + 'neuroid_id': ('neuroid', list(reversed(range(10)))), + 'neuroid_meta': ('neuroid', [0] * 10), + 'time_bin_start': ('time_bin', np.arange(0, 400, 20)), + 'time_bin_end': ('time_bin', np.arange(20, 420, 20))}, + dims=['presentation', 'neuroid', 'time_bin']) + mean_presentation = lambda arr: arr.mean("presentation") + spantime_mean_presentation = SpanTime(mean_presentation) + output = spantime_mean_presentation(jumbled_source) + output = output.transpose('neuroid') + target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('presentation').mean('time_bin') + assert (output == approx(target)).all().item() + + def test_pertime_regression(self): + jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20), + coords={'stimulus_id': ('presentation', list(reversed(range(500)))), + 'image_meta': ('presentation', [0] * 500), + 'neuroid_id': ('neuroid', list(reversed(range(10)))), + 'neuroid_meta': ('neuroid', [0] * 10), + 'time_bin_start': ('time_bin', np.arange(0, 400, 20)), + 'time_bin_end': ('time_bin', np.arange(20, 420, 20))}, + dims=['presentation', 'neuroid', 'time_bin']) + target = jumbled_source.sortby(['stimulus_id', 'neuroid_id']) + pertime_regression = PerTimeRegression(XarrayRegression(LinearRegression())) + pertime_regression.fit(jumbled_source, target) + prediction = pertime_regression.predict(jumbled_source) + prediction = prediction.transpose(*target.dims) + # do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata. + np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values, + target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values) + + + def test_spantime_regression(self): + jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20), + coords={'stimulus_id': ('presentation', list(reversed(range(500)))), + 'image_meta': ('presentation', [0] * 500), + 'neuroid_id': ('neuroid', list(reversed(range(10)))), + 'neuroid_meta': ('neuroid', [0] * 10), + 'time_bin_start': ('time_bin', np.arange(0, 400, 20)), + 'time_bin_end': ('time_bin', np.arange(20, 420, 20))}, + dims=['presentation', 'neuroid', 'time_bin']) + target = jumbled_source.sortby(['stimulus_id', 'neuroid_id']) + spantime_regression = SpanTimeRegression(XarrayRegression(LinearRegression())) + spantime_regression.fit(jumbled_source, target) + prediction = spantime_regression.predict(jumbled_source) + prediction = prediction.transpose(*target.dims) + # do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata. + np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values, + target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values) +